Fix undefined variable in chat_rl after recent refactor
* Fix undefined variable * Remove unused import Remove unused import 're' from chat_rl.py
This commit is contained in:
+1
-2
@@ -19,7 +19,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import itertools
|
import itertools
|
||||||
import re
|
|
||||||
import wandb
|
import wandb
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -174,7 +173,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||||||
tokens = tokenizer.render_for_completion(conversation)
|
tokens = tokenizer.render_for_completion(conversation)
|
||||||
prefix_length = len(tokens)
|
prefix_length = len(tokens)
|
||||||
# Generate k samples using batched generation inside the Engine
|
# Generate k samples using batched generation inside the Engine
|
||||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
assert num_samples <= args.device_batch_size # usually this is true. we can add a loop if not...
|
||||||
generated_token_sequences, masks = engine.generate_batch(
|
generated_token_sequences, masks = engine.generate_batch(
|
||||||
tokens,
|
tokens,
|
||||||
num_samples=num_samples,
|
num_samples=num_samples,
|
||||||
|
|||||||
Reference in New Issue
Block a user