themblem/emblem5/ai/train2.py
Fam Zheng 3715bafcb4 Increase image resolution to 512x512 in training and verification
- Update train2.py and verify2.py to use 512x512 image resize instead of 224x224
- Add grep command permission to Claude settings
2026-01-24 12:25:45 +00:00

744 lines
31 KiB
Python
Executable File

#!/usr/bin/env python3
'''
all children of data/scans are scan_ids
in each scan_id, there is a file called "metadata.json"
labels key has 'pos' or 'neg'
it also has 'code' key which is a string
A random 10% of scan_ids is used as the initial validation set
every scan whose (code, labels) pair also appears in the validation set is moved to validation as well
the rest are training set
preprocess the sbs.jpg for both train and validation:
1. split left and right half
2. crop 3x3 of left
3. crop 3x3 of right
4. for each pair of crop, concat into grid-i-j.jpg, and apply some colorjitter
5. all the grid-i-j.jpg are used with the label
load a ResNet18-based model checkpoint, and train it on the training set
train for the configured number of epochs (default 100), and print accuracy on validation set each epoch
save the model in the end.
'''
import os
import json
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
from collections import defaultdict
import multiprocessing as mp
from functools import partial
from datetime import datetime
from common import *
import wandb
def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
    """Create the 3x3 grid crop files for one scan and return their metadata.

    The scan's sbs.jpg is split into left/right halves, each half is cut into
    a 3x3 grid, and each (left, right) crop pair is concatenated side by side
    into grids/grid-i-j.jpg. Color jitter is applied to the left crop only.
    If all nine grid files already exist, only metadata is returned.

    Args:
        scan_item: (scan_id, metadata) tuple; metadata['labels'] contains
            'pos' and/or 'neg'.
        data_dir: Root directory containing the 'scans' tree.
        hue_jitter: Hue parameter for the ColorJitter applied to left crops.

    Returns:
        (sample_metadata, reason): list of per-grid metadata dicts, and None
        on success or a short string describing the failure.
    """
    scan_id, metadata = scan_item
    sample_metadata = []
    sbs_path = os.path.join(data_dir, 'scans', scan_id, 'sbs.jpg')
    if not os.path.exists(sbs_path):
        return sample_metadata, 'sbs.jpg missing'
    # Create grid directory if it doesn't exist
    grid_dir = os.path.join(data_dir, 'scans', scan_id, 'grids')
    os.makedirs(grid_dir, exist_ok=True)
    # The label is shared by every grid produced from this scan.
    label = 1 if 'pos' in metadata['labels'] else 0
    # If every grid file already exists, skip image processing entirely and
    # just return metadata for the cached files.
    all_grids_exist = all(
        os.path.exists(os.path.join(grid_dir, f'grid-{i}-{j}.jpg'))
        for i in range(3) for j in range(3)
    )
    if all_grids_exist:
        for i in range(3):
            for j in range(3):
                sample_metadata.append({
                    'scan_id': scan_id,
                    'grid_path': os.path.join(grid_dir, f'grid-{i}-{j}.jpg'),
                    'grid_i': i,
                    'grid_j': j,
                    'label': label
                })
        return sample_metadata, None
    # Load the side-by-side image
    try:
        sbs_img = Image.open(sbs_path).convert('RGB')
    except Exception as e:
        return sample_metadata, f'sbs.jpg unreadable: {str(e)}'
    width, height = sbs_img.size
    # Check if image is too small
    if width < 6 or height < 3:
        return sample_metadata, f'sbs.jpg too small: {width}x{height}'
    # Calculate crop dimensions: each half is 3 columns wide, so width/2/3.
    crop_width = width // 6
    crop_height = height // 3
    if crop_width <= 0 or crop_height <= 0:
        return sample_metadata, f'invalid crop dimensions: {crop_width}x{crop_height}'
    # Hoisted out of the loop: the transform object is loop-invariant (each
    # call to it still samples fresh random jitter per crop).
    color_jitter = transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=hue_jitter
    )
    # Generate all 3x3 grid combinations
    try:
        for i in range(3):
            for j in range(3):
                # Calculate crop positions directly from original image
                left_x = i * crop_width
                right_x = (i + 3) * crop_width  # Skip middle section
                y = j * crop_height
                # Crop directly from original image
                left_crop = sbs_img.crop((left_x, y, left_x + crop_width, y + crop_height))
                right_crop = sbs_img.crop((right_x, y, right_x + crop_width, y + crop_height))
                # Apply color jitter only to left crop
                left_crop = color_jitter(left_crop)
                # Concatenate left and right crops horizontally
                grid_img = Image.new('RGB', (crop_width * 2, crop_height))
                grid_img.paste(left_crop, (0, 0))
                grid_img.paste(right_crop, (crop_width, 0))
                # Save grid image
                grid_path = os.path.join(grid_dir, f'grid-{i}-{j}.jpg')
                grid_img.save(grid_path, 'JPEG', quality=95)
                # Store metadata
                sample_metadata.append({
                    'scan_id': scan_id,
                    'grid_path': grid_path,
                    'grid_i': i,
                    'grid_j': j,
                    'label': label
                })
    except Exception as e:
        return sample_metadata, f'error creating grids: {str(e)}'
    return sample_metadata, None
class GridDataset(Dataset):
    """Dataset of pre-generated grid crop images with binary pos/neg labels.

    Construction eagerly materializes all grid-i-j.jpg files (in parallel)
    via process_scan_grid; __getitem__ only loads the saved JPEGs.
    """

    def __init__(self, scan_data, data_dir='data', transform=None, num_workers=None, hue_jitter=0.1):
        self.scan_data = scan_data
        self.data_dir = data_dir
        self.transform = transform
        self.sample_metadata = []
        self.hue_jitter = hue_jitter
        if num_workers is None:
            num_workers = min(mp.cpu_count(), 8)  # cap at 8 workers
        print(f"Preprocessing {len(scan_data)} scans to create grid files using {num_workers} workers...")
        # Fan the per-scan grid generation out across a process pool.
        worker = partial(process_scan_grid, data_dir=self.data_dir, hue_jitter=self.hue_jitter)
        with mp.Pool(processes=num_workers) as pool:
            results = list(tqdm(
                pool.imap(worker, scan_data.items()),
                total=len(scan_data),
                desc="Creating grid files"
            ))
        # Gather sample metadata and tally failure reasons.
        failure_counts = defaultdict(int)
        for entry in results:
            if isinstance(entry, tuple) and len(entry) == 2:
                samples, failure = entry
                self.sample_metadata.extend(samples)
                if failure is not None:
                    failure_counts[failure] += 1
            else:
                # Legacy return format: a bare list with no failure reason.
                self.sample_metadata.extend(entry)
        print(f"Created {len(self.sample_metadata)} grid files")
        if failure_counts:
            print("\nStatistics on why grid files were not created:")
            for failure, count in sorted(failure_counts.items(), key=lambda item: -item[1]):
                print(f" {failure}: {count} scans")

    def __len__(self):
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        record = self.sample_metadata[idx]
        # Load the pre-saved grid image straight from disk.
        image = Image.open(record['grid_path']).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(record['label'], dtype=torch.long)
def load_scan_data(data_dir):
    """Read metadata.json for every scan directory under data_dir/scans.

    Scans without a metadata.json are silently skipped; scans whose file
    cannot be parsed are reported and skipped.

    Returns:
        dict mapping scan_id -> parsed metadata dict.

    Raises:
        FileNotFoundError: if data_dir/scans does not exist.
    """
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")
    scan_ids = [entry for entry in os.listdir(scans_dir)
                if os.path.isdir(os.path.join(scans_dir, entry))]
    print(f"Loading metadata for {len(scan_ids)} scans...")
    scan_data = {}
    for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            continue
        try:
            with open(metadata_path, 'r') as f:
                scan_data[scan_id] = json.load(f)
        except Exception as e:
            print(f"Error loading metadata for {scan_id}: {e}")
    return scan_data
def split_train_val(scan_data, avoid_overlap=True):
    """Split scans into train and validation dictionaries.

    A random 10% of scan_ids seeds the validation set. When avoid_overlap is
    True, every remaining scan whose (code, labels) pair also occurs in the
    seed validation set is moved into validation too, so train and validation
    never share a (code, labels) combination.

    Args:
        scan_data: Dictionary of scan_id -> metadata.
        avoid_overlap: Enable the (code, labels) overlap-avoidance move.

    Returns:
        (train_data, val_data) dictionaries of scan_id -> metadata.
    """
    all_ids = list(scan_data.keys())
    print("Splitting data into train and validation sets...")
    # Seed the validation set with a random 10% of scan ids.
    val_size = int(len(all_ids) * 0.1)
    seed_val_ids = random.sample(all_ids, val_size)
    print(f"Initial random split: {len(all_ids) - val_size} train, {val_size} validation")
    if avoid_overlap:
        # Every (code, labels) pair present in the seed validation set.
        held_out_pairs = set()
        for sid in tqdm(seed_val_ids, desc="Collecting validation code-label combinations"):
            if sid in scan_data:
                code = scan_data[sid].get('code')
                labels = tuple(sorted(scan_data[sid].get('labels', [])))
                if code:
                    held_out_pairs.add((code, labels))
        print(f"Found {len(held_out_pairs)} unique (code, labels) combinations in validation set")
        # Move every training scan with a matching pair into validation.
        moved_ids = set()
        candidate_ids = set(all_ids) - set(seed_val_ids)
        for sid in tqdm(candidate_ids, desc="Finding scans with matching code-label combinations"):
            if sid not in scan_data:
                continue
            code = scan_data[sid].get('code')
            labels = tuple(sorted(scan_data[sid].get('labels', [])))
            if code and (code, labels) in held_out_pairs:
                moved_ids.add(sid)
        val_ids = set(seed_val_ids) | moved_ids
    else:
        # Plain random split; overlap between sets is tolerated.
        val_ids = set(seed_val_ids)
    train_ids = set(scan_data.keys()) - val_ids
    train_data = {sid: scan_data[sid] for sid in train_ids}
    val_data = {sid: scan_data[sid] for sid in val_ids}
    print(f"Final split:")
    print(f" Total scans: {len(scan_data)}")
    print(f" Training scans: {len(train_data)}")
    print(f" Validation scans: {len(val_data)}")
    return train_data, val_data
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata=None, num_epochs=10):
    """Train for num_epochs with mixed precision, validating each epoch.

    Tracks overall and per-class (pos/neg) accuracy, steps the LR scheduler
    on validation accuracy, logs every epoch to wandb, and saves a checkpoint
    whenever validation accuracy improves.

    Args:
        model: Classifier producing 2-way logits.
        train_loader / val_loader: DataLoaders of (image, label) batches.
        criterion: Per-sample loss (reduced here with .mean()).
        optimizer, scheduler: Optimizer and ReduceLROnPlateau-style scheduler.
        device: torch.device to train on ('cuda' or 'cpu').
        transform: Preprocessing transform, persisted alongside checkpoints.
        args: Parsed CLI args (args.quick selects the checkpoint suffix).
        train_metadata: Extra metadata stored with saved models.
        num_epochs: Number of training epochs.

    Returns:
        (model, val_acc, val_pos_acc, val_neg_acc) from the final epoch.
    """
    best_val_acc = 0.0
    # Fix: key amp off the actual device instead of hard-coding 'cuda', so
    # the CPU fallback selected in main() still runs.
    scaler = torch.amp.GradScaler(device.type)
    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_pos_correct = 0
        train_pos_total = 0
        train_neg_correct = 0
        train_neg_total = 0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # Mixed-precision forward pass.
            with torch.amp.autocast(device.type):
                output = model(data)
                loss = criterion(output, target).mean()
            # Scale loss and backpropagate through the grad scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.detach().item()
            pred = output.argmax(dim=1, keepdim=True)
            correct_mask = pred.eq(target.view_as(pred)).view(-1)
            train_correct += correct_mask.sum().item()
            train_total += target.size(0)
            # Per-class counts, vectorized (the original per-sample Python
            # loop forced a GPU sync per element).
            pos_mask = target == 1
            train_pos_total += pos_mask.sum().item()
            train_pos_correct += (correct_mask & pos_mask).sum().item()
            train_neg_total += (~pos_mask).sum().item()
            train_neg_correct += (correct_mask & ~pos_mask).sum().item()
            # Drop per-batch tensors promptly to keep memory flat.
            del data, target, output, loss, pred, correct_mask, pos_mask
            # Per-class accuracies for the progress bar (guard against /0).
            train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
            train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
            train_pbar.set_postfix({
                'Loss': f'{train_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*train_correct/train_total:.2f}%',
                'Pos': f'{train_pos_acc:.2f}%',
                'Neg': f'{train_neg_acc:.2f}%'
            })
        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_pos_correct = 0
        val_pos_total = 0
        val_neg_correct = 0
        val_neg_total = 0
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch_idx, (data, target) in enumerate(val_pbar):
                data, target = data.to(device), target.to(device)
                # Mixed precision for validation too.
                with torch.amp.autocast(device.type):
                    output = model(data)
                    loss = criterion(output, target).mean()
                val_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                correct_mask = pred.eq(target.view_as(pred)).view(-1)
                val_correct += correct_mask.sum().item()
                val_total += target.size(0)
                # Per-class counts, vectorized as in the training loop.
                pos_mask = target == 1
                val_pos_total += pos_mask.sum().item()
                val_pos_correct += (correct_mask & pos_mask).sum().item()
                val_neg_total += (~pos_mask).sum().item()
                val_neg_correct += (correct_mask & ~pos_mask).sum().item()
                del data, target, output, loss, pred, correct_mask, pos_mask
                val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
                val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)
                val_pbar.set_postfix({
                    'Loss': f'{val_loss/(batch_idx+1):.4f}',
                    'Acc': f'{100.*val_correct/val_total:.2f}%',
                    'Pos': f'{val_pos_acc:.2f}%',
                    'Neg': f'{val_neg_acc:.2f}%'
                })
        # ---- Epoch summary ----
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
        train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
        val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
        val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f' Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% (Pos: {train_pos_acc:.2f}%, Neg: {train_neg_acc:.2f}%)')
        print(f' Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}% (Pos: {val_pos_acc:.2f}%, Neg: {val_neg_acc:.2f}%)')
        print(f' Train samples - Pos: {train_pos_total}, Neg: {train_neg_total}')
        print(f' Val samples - Pos: {val_pos_total}, Neg: {val_neg_total}')
        # Scheduler keys off validation accuracy (caller uses mode='max').
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        print(f' Current learning rate: {current_lr:.6f}')
        # Log epoch metrics to wandb.
        wandb.log({
            'epoch': epoch + 1,
            'train/loss': train_loss / len(train_loader),
            'train/acc': train_acc,
            'train/pos_acc': train_pos_acc,
            'train/neg_acc': train_neg_acc,
            'train/pos_samples': train_pos_total,
            'train/neg_samples': train_neg_total,
            'val/loss': val_loss / len(val_loader),
            'val/acc': val_acc,
            'val/pos_acc': val_pos_acc,
            'val/neg_acc': val_neg_acc,
            'val/pos_samples': val_pos_total,
            'val/neg_samples': val_neg_total,
            'learning_rate': current_lr,
        })
        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            mode_suffix = "_quick" if args.quick else ""
            models_dir = os.path.expanduser('~/emblem/models')
            best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1:03d}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
            save_model(model, transform, best_model_path, train_metadata)
            print(f' New best validation accuracy: {val_acc:.2f}%')
            print(f' Best model saved: {best_model_path}')
            wandb.log({'best_val_acc': best_val_acc, 'best_epoch': epoch + 1})
    return model, val_acc, val_pos_acc, val_neg_acc
def main():
    """Parse CLI args, prepare the grid dataset, and run classifier training.

    Pipeline: load scan metadata -> optional scan-id filtering / quick-mode
    subsampling -> train/val split (optionally code-overlap-free) -> grid
    preprocessing -> load checkpoint -> train with FocalLoss + Adam +
    ReduceLROnPlateau, logging to wandb -> save final model.
    """
    parser = argparse.ArgumentParser(description='Train ResNet18 model on grid data')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size (increased for better GPU utilization)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--num-workers', type=int, default=None, help='Number of workers for preprocessing (default: auto)')
    # Help text fixed to match the actual 0.5% sampling fraction used below.
    parser.add_argument('--quick', action='store_true', help='Quick mode: use only 0.5%% of scans for faster testing')
    parser.add_argument('--hue-jitter', type=float, default=0.1, help='Hue jitter parameter for ColorJitter (default: 0.1)')
    parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
    parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
    parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
    parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
    parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
    parser.add_argument('--wandb-project', type=str, default='themblem', help='W&B project name (default: themblem)')
    parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name (default: auto-generated)')
    args = parser.parse_args()
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    print(f'Using hue jitter: {args.hue_jitter}')
    # Checkpoints live under ~/emblem/models; create it once up front.
    models_dir = os.path.expanduser('~/emblem/models')
    os.makedirs(models_dir, exist_ok=True)
    # Load full scan data (needed for finetune mode to find old codes)
    print("Loading scan data...")
    full_scan_data = load_scan_data(args.data_dir)
    if not full_scan_data:
        print("No scan data found!")
        return
    # Start with full scan data, then filter
    scan_data = full_scan_data.copy()
    # Filter by scan IDs if provided
    if args.scan_ids:
        ranges = parse_ranges(args.scan_ids)
        filtered_scan_data = {}
        for scan_id, metadata in scan_data.items():
            try:
                scan_id_int = int(scan_id)
                for val_range in ranges:
                    if in_range(scan_id_int, val_range):
                        filtered_scan_data[scan_id] = metadata
                        break
            except ValueError:
                # Skip non-numeric scan IDs
                continue
        if not filtered_scan_data:
            print(f"Error: No scan IDs found in range '{args.scan_ids}'")
            return
        scan_data = filtered_scan_data
        print(f"Filtered to {len(scan_data)} scans matching range '{args.scan_ids}'")
    # Quick mode: subsample the scans for a fast smoke-test run.
    if args.quick:
        original_count = len(scan_data)
        scan_ids = list(scan_data.keys())
        quick_frac = 0.005
        # Take at least 10 scans so training remains meaningful.
        quick_count = max(10, int(len(scan_ids) * quick_frac))
        selected_scan_ids = random.sample(scan_ids, quick_count)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in selected_scan_ids}
        # Message fixed: it previously claimed "1%" while sampling 0.5%.
        print(f"Quick mode enabled: Using {len(scan_data)} scans out of {original_count} ({quick_frac:.1%})")
    # Split into train and validation
    print("Splitting data into train and validation...")
    # In finetune mode (when model is specified), don't avoid overlap between train and val
    avoid_overlap = args.model is None
    train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
    # In finetune mode, add old codes to prevent forgetting
    old_codes_info = None
    if args.model:
        # Get distinct codes from current finetune dataset
        finetune_codes = set()
        for scan_id, metadata in scan_data.items():
            code = metadata.get('code')
            if code:
                finetune_codes.add(code)
        num_finetune_codes = len(finetune_codes)
        print(f"Finetune mode: Found {num_finetune_codes} distinct codes in finetune dataset")
        if num_finetune_codes > 0:
            # Find old codes (codes in full dataset but not in finetune dataset)
            old_codes = set()
            for scan_id, metadata in full_scan_data.items():
                code = metadata.get('code')
                if code and code not in finetune_codes:
                    old_codes.add(code)
            print(f"Found {len(old_codes)} distinct old codes in full dataset")
            if len(old_codes) > 0:
                # Select the same number of distinct random old codes
                num_old_codes_to_select = min(num_finetune_codes, len(old_codes))
                selected_old_codes = random.sample(list(old_codes), num_old_codes_to_select)
                print(f"Selected {num_old_codes_to_select} old codes to prevent forgetting: {selected_old_codes}")
                # Add all scans for selected old codes to training set
                old_scans_added = 0
                for scan_id, metadata in full_scan_data.items():
                    code = metadata.get('code')
                    if code and code in selected_old_codes:
                        # Only add if not already in train_data (avoid duplicates)
                        if scan_id not in train_data:
                            train_data[scan_id] = metadata
                            old_scans_added += 1
                print(f"Added {old_scans_added} scans from old codes to training set")
                print(f"Training set size: {len(train_data)} scans (including {old_scans_added} from old codes)")
                old_codes_info = {
                    'num_old_codes': num_old_codes_to_select,
                    'old_codes': selected_old_codes,
                    'old_scans_added': old_scans_added
                }
            else:
                print("No old codes found in full dataset")
        else:
            print("No codes found in finetune dataset, skipping old code selection")
    # Shared preprocessing transform (ImageNet normalization, 512x512 input).
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Create datasets
    print("Creating training dataset...")
    train_dataset = GridDataset(train_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Training samples: {len(train_dataset)}")
    print("Creating validation dataset...")
    val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Validation samples: {len(val_dataset)}")
    # Validate datasets have samples
    if len(train_dataset) == 0:
        print("Error: Training dataset is empty. Cannot proceed with training.")
        print("This usually means all training scans are missing sbs.jpg files.")
        return
    if len(val_dataset) == 0:
        print("Error: Validation dataset is empty. Cannot proceed with training.")
        print("This usually means all validation scans are missing sbs.jpg files.")
        return
    # Create data loaders
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
    # Create model
    print("Creating ResNet18 model...")
    if args.model:
        # Load specified model for finetuning
        model_path = os.path.expanduser(args.model)
        if not os.path.exists(model_path):
            print(f"Error: Model file not found: {model_path}")
            return
        print(f"Loading model from {model_path}...")
        model, _ = load_model(model_path)
    else:
        # Load default model
        default_model_path = os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
        print(f"Loading default model from {default_model_path}...")
        model, _ = load_model(default_model_path)
    # Unwrap compiled model if it exists (torch.compile can cause issues when loading)
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    model = model.to(device)
    print(f"Model loaded and moved to {device}")
    # Define loss function and optimizer
    print("Setting up loss function and optimizer...")
    # Use FocalLoss with 0.99:0.01 weights for positive/negative classes
    pos_weight = 0.99
    criterion = FocalLoss(0.25, weight=torch.Tensor([1.0 - pos_weight, pos_weight]).to(device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Create ReduceLROnPlateau scheduler
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=args.scheduler_factor,
        patience=args.scheduler_patience,
        min_lr=args.scheduler_min_lr
    )
    print(f"Using Adam optimizer with lr={args.lr}")
    print(f"Using FocalLoss with weights: Negative={1.0 - pos_weight:.2f}, Positive={pos_weight:.2f}")
    print(f"Using ReduceLROnPlateau scheduler with factor={args.scheduler_factor}, patience={args.scheduler_patience}, min_lr={args.scheduler_min_lr}")
    # Collect training metadata to persist alongside checkpoints.
    train_scan_ids = list(train_data.keys())
    train_codes = {metadata.get('code') for metadata in train_data.values() if metadata.get('code')}
    train_metadata = {
        'train_scan_ids': train_scan_ids,
        'train_codes': list(train_codes),
        'hue_jitter': args.hue_jitter,
        'quick_mode': args.quick
    }
    print(f"Training metadata:")
    print(f" Train scan IDs: {len(train_scan_ids)}")
    print(f" Train codes: {len(train_codes)}")
    print(f" Hue jitter: {args.hue_jitter}")
    print(f" Quick mode: {args.quick}")
    # SECURITY FIX: the W&B API key was previously hard-coded here. Supply it
    # via the WANDB_API_KEY environment variable or a prior `wandb login`;
    # never commit credentials to source control. (The leaked key must be
    # rotated regardless.)
    wandb.login()
    wandb_config = {
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        'hue_jitter': args.hue_jitter,
        'scheduler_patience': args.scheduler_patience,
        'scheduler_factor': args.scheduler_factor,
        'scheduler_min_lr': args.scheduler_min_lr,
        'quick_mode': args.quick,
        'train_scans': len(train_scan_ids),
        'val_scans': len(val_data),
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'model_path': args.model if args.model else 'default',
        'finetune_mode': args.model is not None,
    }
    if old_codes_info:
        wandb_config['old_codes_count'] = old_codes_info['num_old_codes']
        wandb_config['old_scans_added'] = old_codes_info['old_scans_added']
    wandb_name = args.wandb_name or f"train2-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    wandb.init(
        project=args.wandb_project,
        name=wandb_name,
        config=wandb_config,
    )
    print(f"W&B logging enabled: project={args.wandb_project}, run={wandb_name}")
    # Train the model
    print("Starting training...")
    model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs)
    # Save final model
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    mode_suffix = "_quick" if args.quick else ""
    final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs:03d}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
    save_model(model, transform, final_model_path, train_metadata)
    print(f"Training completed! Final model saved: {final_model_path}")
    # Finish wandb run
    wandb.finish()
# Standard entry guard: allows importing this module (e.g. for its dataset
# and helper functions) without triggering a training run.
if __name__ == '__main__':
    main()