From 32380a8082dff3d824137684d3a3c621b4416aa7 Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Sun, 28 Dec 2025 21:26:44 +0000 Subject: [PATCH] Add wandb logging, finetune old code selection, and improve file naming - Add required wandb integration with hardcoded API key - Log training metrics (loss, accuracy, per-class metrics) to wandb - In finetune mode, automatically add old codes to training set to prevent forgetting - Zero-pad epoch numbers in model filenames (ep001, ep002) for better sorting - Change upload.py to display sizes and speeds in KB instead of bytes - Update Makefile finetune target to use 100 epochs instead of 30 - Set default wandb project to 'euphon/themblem' --- Makefile | 2 +- emblem5/ai/train2.py | 118 +++++++++++++++++++++++++++++++++++++++++-- emblem5/ai/upload.py | 9 ++-- 3 files changed, 120 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index c06253d..c46d51f 100644 --- a/Makefile +++ b/Makefile @@ -123,7 +123,7 @@ finetune: FORCE echo "Error: Model file not found: $$MODEL_PATH"; \ exit 1; \ fi; \ - cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --scan-ids "$(SCAN_IDS)" --model "$$MODEL_PATH" --epochs 30 + cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --scan-ids "$(SCAN_IDS)" --model "$$MODEL_PATH" --epochs 100 browser: FORCE cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR) diff --git a/emblem5/ai/train2.py b/emblem5/ai/train2.py index a69dcbe..c9f7450 100755 --- a/emblem5/ai/train2.py +++ b/emblem5/ai/train2.py @@ -42,6 +42,7 @@ import multiprocessing as mp from functools import partial from datetime import datetime from common import * +import wandb def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1): """Process a single scan to create grid files and metadata @@ -424,16 +425,35 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler current_lr = optimizer.param_groups[0]['lr'] print(f' Current learning rate: {current_lr:.6f}') + # Log to wandb + wandb.log({ + 'epoch': epoch + 1, + 'train/loss': train_loss / len(train_loader), + 'train/acc': train_acc, + 'train/pos_acc': train_pos_acc, + 'train/neg_acc': train_neg_acc, + 'train/pos_samples': train_pos_total, + 'train/neg_samples': train_neg_total, + 'val/loss': val_loss / len(val_loader), + 'val/acc': val_acc, + 'val/pos_acc': val_pos_acc, + 'val/neg_acc': val_neg_acc, + 'val/pos_samples': val_pos_total, + 'val/neg_samples': val_neg_total, + 'learning_rate': current_lr, + }) + # Save best model if val_acc > best_val_acc: best_val_acc = val_acc timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') mode_suffix = "_quick" if args.quick else "" models_dir = os.path.expanduser('~/emblem/models') - best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt') + best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1:03d}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt') save_model(model, transform, best_model_path, train_metadata) print(f' New best validation accuracy: {val_acc:.2f}%') print(f' Best model saved: {best_model_path}') + wandb.log({'best_val_acc': best_val_acc, 'best_epoch': epoch + 1}) return model, val_acc, val_pos_acc, val_neg_acc @@ -451,6 +471,8 @@ def main(): parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for 
ReduceLROnPlateau scheduler (default: 1e-6)') parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")') parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning') + parser.add_argument('--wandb-project', type=str, default='euphon/themblem', help='W&B project name (default: euphon/themblem)') + parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name (default: auto-generated)') args = parser.parse_args() # Set device @@ -462,14 +484,17 @@ def main(): models_dir = os.path.expanduser('~/emblem/models') os.makedirs(models_dir, exist_ok=True) - # Load scan data + # Load full scan data (needed for finetune mode to find old codes) print("Loading scan data...") - scan_data = load_scan_data(args.data_dir) + full_scan_data = load_scan_data(args.data_dir) - if not scan_data: + if not full_scan_data: print("No scan data found!") return + # Start with full scan data, then filter + scan_data = full_scan_data.copy() + # Filter by scan IDs if provided if args.scan_ids: ranges = parse_ranges(args.scan_ids) @@ -506,6 +531,57 @@ def main(): avoid_overlap = args.model is None train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap) + # In finetune mode, add old codes to prevent forgetting + old_codes_info = None + if args.model: + # Get distinct codes from current finetune dataset + finetune_codes = set() + for scan_id, metadata in scan_data.items(): + code = metadata.get('code') + if code: + finetune_codes.add(code) + + num_finetune_codes = len(finetune_codes) + print(f"Finetune mode: Found {num_finetune_codes} distinct codes in finetune dataset") + + if num_finetune_codes > 0: + # Find old codes (codes in full dataset but not in finetune dataset) + old_codes = set() + for scan_id, metadata in full_scan_data.items(): + code = metadata.get('code') + if code and code not in finetune_codes: + old_codes.add(code) + + print(f"Found {len(old_codes)} distinct old codes in full dataset") + + if len(old_codes) > 0: + # Select the same number of distinct random old codes + num_old_codes_to_select = min(num_finetune_codes, len(old_codes)) + selected_old_codes = random.sample(list(old_codes), num_old_codes_to_select) + print(f"Selected {num_old_codes_to_select} old codes to prevent forgetting: {selected_old_codes}") + + # Add all scans for selected old codes to training set + old_scans_added = 0 + for scan_id, metadata in full_scan_data.items(): + code = metadata.get('code') + if code and code in selected_old_codes: + # Only add if not already in train_data (avoid duplicates) + if scan_id not in train_data: + train_data[scan_id] = metadata + old_scans_added += 1 + + print(f"Added {old_scans_added} scans from old codes to training set") + print(f"Training set size: {len(train_data)} scans (including {old_scans_added} from old codes)") + old_codes_info = { + 'num_old_codes': num_old_codes_to_select, + 'old_codes': selected_old_codes, + 'old_scans_added': old_scans_added + } + else: + print("No old codes found in full dataset") + else: + print("No codes found in finetune dataset, skipping old code selection") + # Define transforms transform = transforms.Compose([ transforms.Resize((224, 224)), @@ -620,6 +696,35 @@ def main(): print(f" Hue jitter: {args.hue_jitter}") print(f" Quick mode: {args.quick}") + # Initialize wandb + wandb.login(key='ec22e6ed1ed9891779d57da600f889294f83e41d') + wandb_config = { + 'batch_size': args.batch_size, + 'lr': 
args.lr, + 'epochs': args.epochs, + 'hue_jitter': args.hue_jitter, + 'scheduler_patience': args.scheduler_patience, + 'scheduler_factor': args.scheduler_factor, + 'scheduler_min_lr': args.scheduler_min_lr, + 'quick_mode': args.quick, + 'train_scans': len(train_scan_ids), + 'val_scans': len(val_data), + 'train_samples': len(train_dataset), + 'val_samples': len(val_dataset), + 'model_path': args.model if args.model else 'default', + 'finetune_mode': args.model is not None, + } + if old_codes_info: + wandb_config['old_codes_count'] = old_codes_info['num_old_codes'] + wandb_config['old_scans_added'] = old_codes_info['old_scans_added'] + wandb_name = args.wandb_name or f"train2-{datetime.now().strftime('%Y%m%d_%H%M%S')}" + wandb.init( + project=args.wandb_project, + name=wandb_name, + config=wandb_config, + ) + print(f"W&B logging enabled: project={args.wandb_project}, run={wandb_name}") + # Train the model print("Starting training...") model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs) @@ -628,9 +733,12 @@ def main(): timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') mode_suffix = "_quick" if args.quick else "" models_dir = os.path.expanduser('~/emblem/models') - final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt') + final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs:03d}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt') save_model(model, transform, final_model_path, train_metadata) print(f"Training completed! Final model saved: {final_model_path}") + + # Finish wandb run + wandb.finish() if __name__ == '__main__': main() \ No newline at end of file diff --git a/emblem5/ai/upload.py b/emblem5/ai/upload.py index 3994365..69f9f58 100755 --- a/emblem5/ai/upload.py +++ b/emblem5/ai/upload.py @@ -22,14 +22,17 @@ class ProgressCallback: self.consumed = consumed elapsed = time.time() - self.start_time if elapsed > 0: - bps = consumed / elapsed + kbps = (consumed / elapsed) / 1024 # Convert to KB/s percent = 100 * consumed / total if total else 0 - print(f"\r{self.filename}: {percent:.1f}% ({consumed}/{total} bytes, {bps:.0f} B/s)", end='', flush=True) + consumed_kb = consumed / 1024 + total_kb = total / 1024 + print(f"\r{self.filename}: {percent:.1f}% ({consumed_kb:.0f}/{total_kb:.0f} KB, {kbps:.0f} KB/s)", end='', flush=True) for x in sys.argv[1:]: file_size = os.path.getsize(x) + file_size_kb = file_size / 1024 filename = os.path.basename(x) - print(f"Uploading {filename} ({file_size} bytes)...") + print(f"Uploading {filename} ({file_size_kb:.0f} KB)...") bucket.put_object_from_file(filename, x, progress_callback=ProgressCallback(filename, file_size)) print() # New line after progress completes
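
Below is a minimal standalone sketch (not part of the patch) of the old-code re-sampling idea this change adds for finetune mode: for the distinct codes present in the finetune subset, the same number of previously seen codes is sampled from the full dataset and all of their scans are folded back into the training set so the model does not forget them. The dict shapes mirror train2.py's scan metadata ({scan_id: {'code': ...}}); the function name and the example data are illustrative only.

import random

def add_old_codes(train_data, finetune_data, full_data):
    """Return train_data augmented with all scans of randomly chosen old codes."""
    # Codes seen in the finetune subset vs. codes only present in the full dataset.
    finetune_codes = {m['code'] for m in finetune_data.values() if m.get('code')}
    old_codes = {m['code'] for m in full_data.values()
                 if m.get('code') and m['code'] not in finetune_codes}
    if not finetune_codes or not old_codes:
        return dict(train_data)
    # Pick as many old codes as there are finetune codes (capped by availability).
    k = min(len(finetune_codes), len(old_codes))
    selected = set(random.sample(sorted(old_codes), k))
    # Add every scan belonging to a selected old code, skipping duplicates.
    augmented = dict(train_data)
    for scan_id, meta in full_data.items():
        if meta.get('code') in selected and scan_id not in augmented:
            augmented[scan_id] = meta
    return augmented

# Example: two finetune codes pull the scans of up to two old codes back into training.
full = {1: {'code': 'A'}, 2: {'code': 'B'}, 3: {'code': 'C'}, 4: {'code': 'D'}}
finetune = {3: {'code': 'C'}, 4: {'code': 'D'}}
print(add_old_codes(dict(finetune), finetune, full))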