rename checkpoint_dir to checkpoints_dir for consistency.
This commit is contained in:
@@ -94,11 +94,11 @@ def build_model(checkpoint_dir, step, device, phase):
|
|||||||
return model, tokenizer, meta_data
|
return model, tokenizer, meta_data
|
||||||
|
|
||||||
|
|
||||||
def find_largest_model(checkpoint_dir):
|
def find_largest_model(checkpoints_dir):
|
||||||
# attempt to guess the model tag: take the biggest model available
|
# attempt to guess the model tag: take the biggest model available
|
||||||
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
|
model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
|
||||||
if not model_tags:
|
if not model_tags:
|
||||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
|
||||||
# 1) normally all model tags are of the form d<number>, try that first:
|
# 1) normally all model tags are of the form d<number>, try that first:
|
||||||
candidates = []
|
candidates = []
|
||||||
for model_tag in model_tags:
|
for model_tag in model_tags:
|
||||||
@@ -110,7 +110,7 @@ def find_largest_model(checkpoint_dir):
|
|||||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||||
return candidates[0][1]
|
return candidates[0][1]
|
||||||
# 2) if that failed, take the most recently updated model:
|
# 2) if that failed, take the most recently updated model:
|
||||||
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
|
model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
|
||||||
return model_tags[0]
|
return model_tags[0]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user