Manually control the overactive garbage collector, saving a few minutes on a typical run
This commit is contained in:
@@ -11,6 +11,7 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex
|
|||||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
import argparse
|
import argparse
|
||||||
@@ -429,8 +430,19 @@ while True:
|
|||||||
wandb_run.log(log_data)
|
wandb_run.log(log_data)
|
||||||
|
|
||||||
# state update
|
# state update
|
||||||
|
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
|
# The garbage collector is unfortunately somewhat overactive: for some poorly understood reason,
|
||||||
|
# it frequently spends ~500ms scanning for cycles, only to reclaim a handful of tiny objects each time.
|
||||||
|
# So we manage it manually here: collect and freeze once on the first step, disable automatic GC, then collect only occasionally.
|
||||||
|
if first_step_of_run:
|
||||||
|
gc.collect() # manually collect a lot of garbage from setup
|
||||||
|
gc.freeze() # immediately freeze all currently surviving objects and exclude them from GC
|
||||||
|
gc.disable() # nuclear intervention here: disable GC entirely except:
|
||||||
|
elif step % 5000 == 0: # every 5000 steps...
|
||||||
|
gc.collect() # manually collect, just to be safe for very, very long runs
|
||||||
|
|
||||||
# print a few more stats
|
# print a few more stats
|
||||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||||
|
|||||||
Reference in New Issue
Block a user