Fix undefined variable in chat_rl after recent refactor

* Fix undefined variable

* Remove unused import

Remove unused import 're' from chat_rl.py
This commit is contained in:
Adria Blancafort
2026-01-07 18:08:57 +01:00
committed by GitHub
parent ae0bf52529
commit 1b5de29e71
+1 -2
View File
@@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
import argparse import argparse
import os import os
import itertools import itertools
import re
import wandb import wandb
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -174,7 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
tokens = tokenizer.render_for_completion(conversation) tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens) prefix_length = len(tokens)
# Generate k samples using batched generation inside the Engine # Generate k samples using batched generation inside the Engine
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
generated_token_sequences, masks = engine.generate_batch( generated_token_sequences, masks = engine.generate_batch(
tokens, tokens,
num_samples=num_samples, num_samples=num_samples,