fix scaling laws scripts after the bigram embeddings were removed
This commit is contained in:
+10
-9
@@ -24,7 +24,7 @@ RESULTS_FILE="$RESULTS_DIR/results.csv"
|
||||
|
||||
# Write CSV header only if file doesn't exist
|
||||
if [ ! -f "$RESULTS_FILE" ]; then
|
||||
echo "flops_budget,depth,model_dim,params_wte,params_bigram_embed,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
echo "flops_budget,depth,model_dim,params_wte,params_value_embeds,params_lm_head,params_transformer,params_scalars,params_total,num_iterations,tokens_trained,val_bpb,core_score,train_time_sec" > "$RESULTS_FILE"
|
||||
fi
|
||||
|
||||
log() {
|
||||
@@ -86,13 +86,14 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
LOG_FILE="$RESULTS_DIR/${TAG}_train.log"
|
||||
|
||||
# Extract detailed parameter counts (for scaling law analysis with different conventions)
|
||||
PARAMS_WTE=$(grep "wte:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_BIGRAM=$(grep "bigram_embed:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_VE=$(grep "value_embeds:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_LM=$(grep "lm_head:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TRANSFORMER=$(grep "transformer_matrices:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_SCALARS=$(grep "scalars:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TOTAL=$(grep "total:" "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
# Note: the log format is padded, e.g. "wte : 25,165,824"
|
||||
# so we grep for "^key " (key at start of line followed by space) to avoid false matches
|
||||
PARAMS_WTE=$(grep "^wte " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_VE=$(grep "^value_embeds " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_LM=$(grep "^lm_head " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TRANSFORMER=$(grep "^transformer_matrices " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_SCALARS=$(grep "^scalars " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
PARAMS_TOTAL=$(grep "^total " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||
|
||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||
# Calculate tokens trained (iterations * batch_size, default 524288)
|
||||
@@ -112,7 +113,7 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
||||
log " Params: $PARAMS_TOTAL (transformer: $PARAMS_TRANSFORMER), Iters: $NUM_ITERS, Val BPB: $VAL_BPB, CORE: $CORE_SCORE"
|
||||
|
||||
# Append to CSV
|
||||
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_BIGRAM,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
echo "$flops,$d,$MODEL_DIM,$PARAMS_WTE,$PARAMS_VE,$PARAMS_LM,$PARAMS_TRANSFORMER,$PARAMS_SCALARS,$PARAMS_TOTAL,$NUM_ITERS,$TOKENS_TRAINED,$VAL_BPB,$CORE_SCORE,$TRAIN_TIME" >> "$RESULTS_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
Reference in New Issue
Block a user