Restore completion-only loss masking in SFT dataloader (#582)

* printing steps count

* adding reply only loss for chat

* using the mask by render_conversation function of tokeniser

* undoing some changes

* putting back the comment which got removed accidently, no functionality change
This commit is contained in:
Anish
2026-03-03 06:07:47 +05:30
committed by GitHub
parent c7ba252142
commit 83dccc20ae
+17 -5
View File
@@ -197,7 +197,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
row_capacity = args.max_seq_len + 1 # +1 for target at last position row_capacity = args.max_seq_len + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id() bos_token = tokenizer.get_bos_token_id()
# Conversation buffer: list of token lists # Conversation buffer: list of (token_ids, loss_mask) tuples
conv_buffer = [] conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching) cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering consumed = ddp_rank # Track actual consumption separately from buffering
@@ -208,8 +208,8 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
nonlocal cursor, epoch nonlocal cursor, epoch
while len(conv_buffer) < buffer_size: while len(conv_buffer) < buffer_size:
conversation = dataset[cursor] conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation) ids, mask = tokenizer.render_conversation(conversation)
conv_buffer.append(ids) conv_buffer.append((ids, mask))
cursor += ddp_world_size cursor += ddp_world_size
if cursor >= dataset_size: if cursor >= dataset_size:
cursor = cursor % dataset_size cursor = cursor % dataset_size
@@ -218,9 +218,11 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
while True: while True:
rows = [] rows = []
mask_rows = []
row_lengths = [] # Track actual content length (excluding padding) for each row row_lengths = [] # Track actual content length (excluding padding) for each row
for _ in range(args.device_batch_size): for _ in range(args.device_batch_size):
row = [] row = []
mask_row = []
padded = False padded = False
while len(row) < row_capacity: while len(row) < row_capacity:
# Ensure buffer has conversations # Ensure buffer has conversations
@@ -232,7 +234,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
# Find largest conversation that fits entirely # Find largest conversation that fits entirely
best_idx = -1 best_idx = -1
best_len = 0 best_len = 0
for i, conv in enumerate(conv_buffer): for i, (conv, _) in enumerate(conv_buffer):
conv_len = len(conv) conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len: if conv_len <= remaining and conv_len > best_len:
best_idx = i best_idx = i
@@ -240,14 +242,16 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
if best_idx >= 0: if best_idx >= 0:
# Found a conversation that fits - use it entirely # Found a conversation that fits - use it entirely
conv = conv_buffer.pop(best_idx) conv, conv_mask = conv_buffer.pop(best_idx)
row.extend(conv) row.extend(conv)
mask_row.extend(conv_mask)
consumed += ddp_world_size # Track actual consumption consumed += ddp_world_size # Track actual consumption
else: else:
# No conversation fits - pad the remainder instead of cropping # No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens # This ensures we never discard any tokens
content_len = len(row) content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens row.extend([bos_token] * remaining) # Pad with BOS tokens
mask_row.extend([0] * remaining)
padded = True padded = True
break # Row is now full (with padding) break # Row is now full (with padding)
@@ -257,6 +261,7 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
else: else:
row_lengths.append(row_capacity) row_lengths.append(row_capacity)
rows.append(row[:row_capacity]) rows.append(row[:row_capacity])
mask_rows.append(mask_row[:row_capacity])
# Stopping condition to respect num_iterations, if given # Stopping condition to respect num_iterations, if given
it += 1 it += 1
@@ -280,6 +285,13 @@ def sft_data_generator_bos_bestfit(split, buffer_size=100):
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda) inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda) targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
# with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
mask_targets = mask_tensor[:, 1:].to(device=device)
targets[mask_targets == 0] = -1
# Mask out padding positions in targets (set to -1 = ignore_index) # Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked # For each row, positions >= (content_length - 1) in targets should be masked
for i, content_len in enumerate(row_lengths): for i, content_len in enumerate(row_lengths):