Restore completion-only loss masking in SFT dataloader (#582)
* printing steps count * adding reply only loss for chat * using the mask by render_conversation function of tokeniser * undoing some changes * putting back the comment which got removed accidently, no functionality change
This commit is contained in:
+17
-5
@@ -197,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||||
bos_token = tokenizer.get_bos_token_id()
|
bos_token = tokenizer.get_bos_token_id()
|
||||||
|
|
||||||
# Conversation buffer: list of token lists
|
# Conversation buffer: list of (token_ids, loss_mask) tuples
|
||||||
conv_buffer = []
|
conv_buffer = []
|
||||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||||
@@ -208,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
nonlocal cursor, epoch
|
nonlocal cursor, epoch
|
||||||
while len(conv_buffer) < buffer_size:
|
while len(conv_buffer) < buffer_size:
|
||||||
conversation = dataset[cursor]
|
conversation = dataset[cursor]
|
||||||
ids, _ = tokenizer.render_conversation(conversation)
|
ids, mask = tokenizer.render_conversation(conversation)
|
||||||
conv_buffer.append(ids)
|
conv_buffer.append((ids, mask))
|
||||||
cursor += ddp_world_size
|
cursor += ddp_world_size
|
||||||
if cursor >= dataset_size:
|
if cursor >= dataset_size:
|
||||||
cursor = cursor % dataset_size
|
cursor = cursor % dataset_size
|
||||||
@@ -218,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
rows = []
|
rows = []
|
||||||
|
mask_rows = []
|
||||||
row_lengths = [] # Track actual content length (excluding padding) for each row
|
row_lengths = [] # Track actual content length (excluding padding) for each row
|
||||||
for _ in range(args.device_batch_size):
|
for _ in range(args.device_batch_size):
|
||||||
row = []
|
row = []
|
||||||
|
mask_row = []
|
||||||
padded = False
|
padded = False
|
||||||
while len(row) < row_capacity:
|
while len(row) < row_capacity:
|
||||||
# Ensure buffer has conversations
|
# Ensure buffer has conversations
|
||||||
@@ -232,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
# Find largest conversation that fits entirely
|
# Find largest conversation that fits entirely
|
||||||
best_idx = -1
|
best_idx = -1
|
||||||
best_len = 0
|
best_len = 0
|
||||||
for i, conv in enumerate(conv_buffer):
|
for i, (conv, _) in enumerate(conv_buffer):
|
||||||
conv_len = len(conv)
|
conv_len = len(conv)
|
||||||
if conv_len <= remaining and conv_len > best_len:
|
if conv_len <= remaining and conv_len > best_len:
|
||||||
best_idx = i
|
best_idx = i
|
||||||
@@ -240,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
|
|
||||||
if best_idx >= 0:
|
if best_idx >= 0:
|
||||||
# Found a conversation that fits - use it entirely
|
# Found a conversation that fits - use it entirely
|
||||||
conv = conv_buffer.pop(best_idx)
|
conv, conv_mask = conv_buffer.pop(best_idx)
|
||||||
row.extend(conv)
|
row.extend(conv)
|
||||||
|
mask_row.extend(conv_mask)
|
||||||
consumed += ddp_world_size # Track actual consumption
|
consumed += ddp_world_size # Track actual consumption
|
||||||
else:
|
else:
|
||||||
# No conversation fits - pad the remainder instead of cropping
|
# No conversation fits - pad the remainder instead of cropping
|
||||||
# This ensures we never discard any tokens
|
# This ensures we never discard any tokens
|
||||||
content_len = len(row)
|
content_len = len(row)
|
||||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
||||||
|
mask_row.extend([0] * remaining)
|
||||||
padded = True
|
padded = True
|
||||||
break # Row is now full (with padding)
|
break # Row is now full (with padding)
|
||||||
|
|
||||||
@@ -257,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
else:
|
else:
|
||||||
row_lengths.append(row_capacity)
|
row_lengths.append(row_capacity)
|
||||||
rows.append(row[:row_capacity])
|
rows.append(row[:row_capacity])
|
||||||
|
mask_rows.append(mask_row[:row_capacity])
|
||||||
|
|
||||||
# Stopping condition to respect num_iterations, if given
|
# Stopping condition to respect num_iterations, if given
|
||||||
it += 1
|
it += 1
|
||||||
@@ -280,6 +285,13 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
|||||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
|
||||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
|
||||||
|
|
||||||
|
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
|
||||||
|
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
|
||||||
|
# with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
|
||||||
|
mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
|
||||||
|
mask_targets = mask_tensor[:, 1:].to(device=device)
|
||||||
|
targets[mask_targets == 0] = -1
|
||||||
|
|
||||||
# Mask out padding positions in targets (set to -1 = ignore_index)
|
# Mask out padding positions in targets (set to -1 = ignore_index)
|
||||||
# For each row, positions >= (content_length - 1) in targets should be masked
|
# For each row, positions >= (content_length - 1) in targets should be masked
|
||||||
for i, content_len in enumerate(row_lengths):
|
for i, content_len in enumerate(row_lengths):
|
||||||
|
|||||||
Reference in New Issue
Block a user