fix small bug with params logging and batch size
This commit is contained in:
+14
-3
@@ -8,7 +8,7 @@ FLOPS_BUDGETS=(
|
|||||||
4.64e18
|
4.64e18
|
||||||
1e19
|
1e19
|
||||||
)
|
)
|
||||||
DEPTHS=(8 10 12 14 16 18 20)
|
DEPTHS=(10 12 14 16 18 20)
|
||||||
|
|
||||||
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
|
||||||
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
|
WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}"
|
||||||
@@ -60,6 +60,15 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
|||||||
# Unique tag for this run
|
# Unique tag for this run
|
||||||
TAG="scaling_${flops}_d${d}"
|
TAG="scaling_${flops}_d${d}"
|
||||||
|
|
||||||
|
# Reduce --device-batch-size to avoid OOM at larger depths.
# The chosen flag is stored in DEVICE_BATCH_SIZE_ARG. Thresholds are checked
# from deepest to shallowest so the first match wins. Arithmetic tests are
# used instead of `[ $d -ge N ]` so an empty/unset d does not produce a
# "unary operator expected" error (it is treated as 0 and takes the default).
if (( d >= 28 )); then
  DEVICE_BATCH_SIZE_ARG="--device-batch-size=8"
elif (( d >= 20 )); then
  DEVICE_BATCH_SIZE_ARG="--device-batch-size=16"
else
  DEVICE_BATCH_SIZE_ARG="--device-batch-size=32"
fi
|
||||||
|
|
||||||
# Record start time
|
# Record start time
|
||||||
START_TIME=$(date +%s)
|
START_TIME=$(date +%s)
|
||||||
|
|
||||||
@@ -77,6 +86,7 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
|||||||
--core-metric-max-per-task=-1 \
|
--core-metric-max-per-task=-1 \
|
||||||
--sample-every=-1 \
|
--sample-every=-1 \
|
||||||
--save-every=-1 \
|
--save-every=-1 \
|
||||||
|
$DEVICE_BATCH_SIZE_ARG \
|
||||||
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
2>&1 | tee "$RESULTS_DIR/${TAG}_train.log"
|
||||||
|
|
||||||
END_TIME=$(date +%s)
|
END_TIME=$(date +%s)
|
||||||
@@ -96,8 +106,9 @@ for flops in "${FLOPS_BUDGETS[@]}"; do
|
|||||||
PARAMS_TOTAL=$(grep "^total " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
PARAMS_TOTAL=$(grep "^total " "$LOG_FILE" | tail -1 | grep -oP '[\d,]+' | tr -d ',')
|
||||||
|
|
||||||
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
NUM_ITERS=$(grep "Calculated number of iterations" "$LOG_FILE" | tail -1 | sed 's/.*: //' | tr -d ',')
|
||||||
# Calculate tokens trained (iterations * batch_size, default 524288)
|
# Extract actual batch size from log (auto-computed, varies by model size)
|
||||||
TOKENS_TRAINED=$((NUM_ITERS * 524288))
|
# Take the number following the last "Total batch size" entry in the log,
# with thousands separators removed.
BATCH_SIZE=$(grep "Total batch size" "$LOG_FILE" | tail -1 | grep -oP 'Total batch size \K[\d,]+' | tr -d ',')
# If the entry is missing, the arithmetic below would quietly evaluate the
# empty value as 0. Keep that result but make it explicit and say so on
# stderr, so a zero token count is not mistaken for a real measurement.
if [[ -z "$BATCH_SIZE" ]]; then
  echo "WARNING: could not extract 'Total batch size' from ${LOG_FILE}; TOKENS_TRAINED will be 0" >&2
  BATCH_SIZE=0
fi
# Tokens trained = iterations * tokens per batch.
TOKENS_TRAINED=$((NUM_ITERS * BATCH_SIZE))
|
||||||
# Model dim
|
# Model dim
|
||||||
MODEL_DIM=$((d * 64))
|
MODEL_DIM=$((d * 64))
|
||||||
# Val BPB from final eval
|
# Val BPB from final eval
|
||||||
|
|||||||
Reference in New Issue
Block a user