Add wandb logging, finetune old code selection, and improve file naming

- Add required wandb integration with hardcoded API key
- Log training metrics (loss, accuracy, per-class metrics) to wandb
- In finetune mode, automatically add old codes to training set to prevent forgetting
- Zero-pad epoch numbers in model filenames (ep001, ep002) for better sorting
- Change upload.py to display sizes and speeds in KB instead of bytes
- Update Makefile finetune target to use 100 epochs instead of 30
- Set default wandb project to 'euphon/themblem'
This commit is contained in:
Fam Zheng 2025-12-28 21:26:44 +00:00
parent 90043c2541
commit 32380a8082
3 changed files with 120 additions and 9 deletions

View File

@ -123,7 +123,7 @@ finetune: FORCE
echo "Error: Model file not found: $$MODEL_PATH"; \
exit 1; \
fi; \
cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --scan-ids "$(SCAN_IDS)" --model "$$MODEL_PATH" --epochs 30
cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --scan-ids "$(SCAN_IDS)" --model "$$MODEL_PATH" --epochs 100
browser: FORCE
cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR)

View File

@ -42,6 +42,7 @@ import multiprocessing as mp
from functools import partial
from datetime import datetime
from common import *
import wandb
def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
"""Process a single scan to create grid files and metadata
@ -424,16 +425,35 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler
current_lr = optimizer.param_groups[0]['lr']
print(f' Current learning rate: {current_lr:.6f}')
# Log to wandb
wandb.log({
'epoch': epoch + 1,
'train/loss': train_loss / len(train_loader),
'train/acc': train_acc,
'train/pos_acc': train_pos_acc,
'train/neg_acc': train_neg_acc,
'train/pos_samples': train_pos_total,
'train/neg_samples': train_neg_total,
'val/loss': val_loss / len(val_loader),
'val/acc': val_acc,
'val/pos_acc': val_pos_acc,
'val/neg_acc': val_neg_acc,
'val/pos_samples': val_pos_total,
'val/neg_samples': val_neg_total,
'learning_rate': current_lr,
})
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
mode_suffix = "_quick" if args.quick else ""
models_dir = os.path.expanduser('~/emblem/models')
best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1:03d}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
save_model(model, transform, best_model_path, train_metadata)
print(f' New best validation accuracy: {val_acc:.2f}%')
print(f' Best model saved: {best_model_path}')
wandb.log({'best_val_acc': best_val_acc, 'best_epoch': epoch + 1})
return model, val_acc, val_pos_acc, val_neg_acc
@ -451,6 +471,8 @@ def main():
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
parser.add_argument('--wandb-project', type=str, default='euphon/themblem', help='W&B project name (default: euphon/themblem)')
parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name (default: auto-generated)')
args = parser.parse_args()
# Set device
@ -462,14 +484,17 @@ def main():
models_dir = os.path.expanduser('~/emblem/models')
os.makedirs(models_dir, exist_ok=True)
# Load scan data
# Load full scan data (needed for finetune mode to find old codes)
print("Loading scan data...")
scan_data = load_scan_data(args.data_dir)
full_scan_data = load_scan_data(args.data_dir)
if not scan_data:
if not full_scan_data:
print("No scan data found!")
return
# Start with full scan data, then filter
scan_data = full_scan_data.copy()
# Filter by scan IDs if provided
if args.scan_ids:
ranges = parse_ranges(args.scan_ids)
@ -506,6 +531,57 @@ def main():
avoid_overlap = args.model is None
train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
# In finetune mode, add old codes to prevent forgetting
old_codes_info = None
if args.model:
# Get distinct codes from current finetune dataset
finetune_codes = set()
for scan_id, metadata in scan_data.items():
code = metadata.get('code')
if code:
finetune_codes.add(code)
num_finetune_codes = len(finetune_codes)
print(f"Finetune mode: Found {num_finetune_codes} distinct codes in finetune dataset")
if num_finetune_codes > 0:
# Find old codes (codes in full dataset but not in finetune dataset)
old_codes = set()
for scan_id, metadata in full_scan_data.items():
code = metadata.get('code')
if code and code not in finetune_codes:
old_codes.add(code)
print(f"Found {len(old_codes)} distinct old codes in full dataset")
if len(old_codes) > 0:
# Select the same number of distinct random old codes
num_old_codes_to_select = min(num_finetune_codes, len(old_codes))
selected_old_codes = random.sample(list(old_codes), num_old_codes_to_select)
print(f"Selected {num_old_codes_to_select} old codes to prevent forgetting: {selected_old_codes}")
# Add all scans for selected old codes to training set
old_scans_added = 0
for scan_id, metadata in full_scan_data.items():
code = metadata.get('code')
if code and code in selected_old_codes:
# Only add if not already in train_data (avoid duplicates)
if scan_id not in train_data:
train_data[scan_id] = metadata
old_scans_added += 1
print(f"Added {old_scans_added} scans from old codes to training set")
print(f"Training set size: {len(train_data)} scans (including {old_scans_added} from old codes)")
old_codes_info = {
'num_old_codes': num_old_codes_to_select,
'old_codes': selected_old_codes,
'old_scans_added': old_scans_added
}
else:
print("No old codes found in full dataset")
else:
print("No codes found in finetune dataset, skipping old code selection")
# Define transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
@ -620,6 +696,35 @@ def main():
print(f" Hue jitter: {args.hue_jitter}")
print(f" Quick mode: {args.quick}")
# Initialize wandb
wandb.login(key='ec22e6ed1ed9891779d57da600f889294f83e41d')
wandb_config = {
'batch_size': args.batch_size,
'lr': args.lr,
'epochs': args.epochs,
'hue_jitter': args.hue_jitter,
'scheduler_patience': args.scheduler_patience,
'scheduler_factor': args.scheduler_factor,
'scheduler_min_lr': args.scheduler_min_lr,
'quick_mode': args.quick,
'train_scans': len(train_scan_ids),
'val_scans': len(val_data),
'train_samples': len(train_dataset),
'val_samples': len(val_dataset),
'model_path': args.model if args.model else 'default',
'finetune_mode': args.model is not None,
}
if old_codes_info:
wandb_config['old_codes_count'] = old_codes_info['num_old_codes']
wandb_config['old_scans_added'] = old_codes_info['old_scans_added']
wandb_name = args.wandb_name or f"train2-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
wandb.init(
project=args.wandb_project,
name=wandb_name,
config=wandb_config,
)
print(f"W&B logging enabled: project={args.wandb_project}, run={wandb_name}")
# Train the model
print("Starting training...")
model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs)
@ -628,9 +733,12 @@ def main():
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
mode_suffix = "_quick" if args.quick else ""
models_dir = os.path.expanduser('~/emblem/models')
final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs:03d}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
save_model(model, transform, final_model_path, train_metadata)
print(f"Training completed! Final model saved: {final_model_path}")
# Finish wandb run
wandb.finish()
if __name__ == '__main__':
main()

View File

@ -22,14 +22,17 @@ class ProgressCallback:
self.consumed = consumed
elapsed = time.time() - self.start_time
if elapsed > 0:
bps = consumed / elapsed
kbps = (consumed / elapsed) / 1024 # Convert to KB/s
percent = 100 * consumed / total if total else 0
print(f"\r{self.filename}: {percent:.1f}% ({consumed}/{total} bytes, {bps:.0f} B/s)", end='', flush=True)
consumed_kb = consumed / 1024
total_kb = total / 1024
print(f"\r{self.filename}: {percent:.1f}% ({consumed_kb:.0f}/{total_kb:.0f} KB, {kbps:.0f} KB/s)", end='', flush=True)
for x in sys.argv[1:]:
file_size = os.path.getsize(x)
file_size_kb = file_size / 1024
filename = os.path.basename(x)
print(f"Uploading {filename} ({file_size} bytes)...")
print(f"Uploading {filename} ({file_size_kb:.0f} KB)...")
bucket.put_object_from_file(filename, x, progress_callback=ProgressCallback(filename, file_size))
print() # New line after progress completes