update to new os env var to get rid of deprecation warning
This commit is contained in:
@@ -12,7 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 -
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
|
||||
Reference in New Issue
Block a user