#!/usr/bin/env python3
'''
all children of data/scans are scan_ids
in each scan_id, there is a file called "metadata.json"
labels key has 'pos' or 'neg'
it also has 'code' key which is a string
For the largest 20% scan_ids, we use as validation set
for all codes appearing in the validation set in all scans, we also use as validation set
the rest are training set
preprocess the sbs.jpg for both train and validation:
1. split left and right half
2. crop 3x3 of left
3. crop 3x3 of right
4. for each pair of crop, concat into grid-i-j.jpg, and apply some colorjitter
5. all the grid-i-j.jpg are used with the label
load a resnet18 model, and train it on the training set
train for 10 epochs, and print accuracy on validation set each epoch
save the model in the end.

NOTE(review): the implementation below deviates from this spec — confirm which
is intended: split_train_val() draws a *random 10%* (not the largest 20%), and
the default epoch count is 100 (see --epochs), not 10.
'''
import os
import json
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
from collections import defaultdict
import multiprocessing as mp
from functools import partial
from datetime import datetime
from common import *


def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
    """Create the 3x3 grid crops for a single scan's side-by-side image.

    The sbs.jpg is split into a left and a right half; each half is cut into a
    3x3 grid, and the matching (left, right) crop pair is concatenated
    horizontally into grids/grid-i-j.jpg. Color jitter is applied to the left
    crop only. Existing grid files are reused without re-rendering.

    Args:
        scan_item: (scan_id, metadata) tuple; metadata must carry a 'labels'
            list where the presence of 'pos' marks a positive sample.
        data_dir: root directory containing the 'scans' tree.
        hue_jitter: hue parameter forwarded to transforms.ColorJitter.

    Returns:
        (sample_metadata, reason): sample_metadata is a list of dicts, one per
        grid cell; reason is None on success, or a short string describing the
        failure (missing/unreadable/too-small sbs.jpg, crop or save error).
    """
    scan_id, metadata = scan_item
    sample_metadata = []
    sbs_path = os.path.join(data_dir, 'scans', scan_id, 'sbs.jpg')
    if not os.path.exists(sbs_path):
        return sample_metadata, 'sbs.jpg missing'

    # Create grid directory if it doesn't exist
    grid_dir = os.path.join(data_dir, 'scans', scan_id, 'grids')
    os.makedirs(grid_dir, exist_ok=True)

    # Check if all grid files already exist
    all_grids_exist = True
    for i in range(3):
        for j in range(3):
            grid_filename = f'grid-{i}-{j}.jpg'
            grid_path = os.path.join(grid_dir, grid_filename)
            if not os.path.exists(grid_path):
                all_grids_exist = False
                break
        if not all_grids_exist:
            break

    # Label is scan-level, identical for all 9 cells — compute it once.
    label = 1 if 'pos' in metadata['labels'] else 0

    # If all grid files exist, just return metadata
    if all_grids_exist:
        for i in range(3):
            for j in range(3):
                grid_filename = f'grid-{i}-{j}.jpg'
                grid_path = os.path.join(grid_dir, grid_filename)
                sample_metadata.append({
                    'scan_id': scan_id,
                    'grid_path': grid_path,
                    'grid_i': i,
                    'grid_j': j,
                    'label': label
                })
        return sample_metadata, None

    # Load the side-by-side image
    try:
        sbs_img = Image.open(sbs_path).convert('RGB')
    except Exception as e:
        return sample_metadata, f'sbs.jpg unreadable: {str(e)}'

    width, height = sbs_img.size

    # Check if image is too small
    if width < 6 or height < 3:
        return sample_metadata, f'sbs.jpg too small: {width}x{height}'

    # Calculate crop dimensions
    crop_width = width // 6  # width/2 / 3
    crop_height = height // 3
    if crop_width <= 0 or crop_height <= 0:
        return sample_metadata, f'invalid crop dimensions: {crop_width}x{crop_height}'

    # The jitter transform is stateless between calls (it re-samples its random
    # parameters each __call__), so build it once instead of 9 times per scan.
    color_jitter = transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=hue_jitter
    )

    # Generate all 3x3 grid combinations
    try:
        for i in range(3):
            for j in range(3):
                # Calculate crop positions directly from original image
                left_x = i * crop_width
                right_x = (i + 3) * crop_width  # Skip middle section
                y = j * crop_height

                # Crop directly from original image
                left_crop = sbs_img.crop((left_x, y, left_x + crop_width, y + crop_height))
                right_crop = sbs_img.crop((right_x, y, right_x + crop_width, y + crop_height))

                # Apply color jitter only to left crop
                left_crop = color_jitter(left_crop)

                # Concatenate left and right crops horizontally
                grid_img = Image.new('RGB', (crop_width * 2, crop_height))
                grid_img.paste(left_crop, (0, 0))
                grid_img.paste(right_crop, (crop_width, 0))

                # Save grid image
                grid_filename = f'grid-{i}-{j}.jpg'
                grid_path = os.path.join(grid_dir, grid_filename)
                grid_img.save(grid_path, 'JPEG', quality=95)

                # Store metadata
                sample_metadata.append({
                    'scan_id': scan_id,
                    'grid_path': grid_path,
                    'grid_i': i,
                    'grid_j': j,
                    'label': label
                })
    except Exception as e:
        return sample_metadata, f'error creating grids: {str(e)}'

    return sample_metadata, None
class GridDataset(Dataset):
    """Dataset of pre-rendered grid-i-j.jpg crops with binary labels.

    On construction, runs process_scan_grid over every scan in parallel to
    make sure the grid files exist, then indexes the resulting samples.
    """

    def __init__(self, scan_data, data_dir='data', transform=None, num_workers=None, hue_jitter=0.1):
        self.scan_data = scan_data
        self.data_dir = data_dir
        self.transform = transform
        self.sample_metadata = []
        self.hue_jitter = hue_jitter

        if num_workers is None:
            num_workers = min(mp.cpu_count(), 8)  # Limit to 8 workers max

        print(f"Preprocessing {len(scan_data)} scans to create grid files using {num_workers} workers...")

        # Fan the per-scan grid rendering out over a process pool.
        worker = partial(process_scan_grid, data_dir=self.data_dir, hue_jitter=self.hue_jitter)
        with mp.Pool(processes=num_workers) as pool:
            results = list(tqdm(
                pool.imap(worker, scan_data.items()),
                total=len(scan_data),
                desc="Creating grid files"
            ))

        # Gather sample metadata plus per-reason failure counts.
        stats = defaultdict(int)
        for result in results:
            if isinstance(result, tuple) and len(result) == 2:
                metadata_list, reason = result
                self.sample_metadata.extend(metadata_list)
                if reason is not None:
                    stats[reason] += 1
            else:
                # Backward compatibility - old format without reason
                self.sample_metadata.extend(result)

        print(f"Created {len(self.sample_metadata)} grid files")
        if stats:
            print("\nStatistics on why grid files were not created:")
            for reason, count in sorted(stats.items(), key=lambda x: -x[1]):
                print(f" {reason}: {count} scans")

    def __len__(self):
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        # Load the pre-saved grid image for this sample and apply transforms.
        sample = self.sample_metadata[idx]
        grid_img = Image.open(sample['grid_path']).convert('RGB')
        if self.transform:
            grid_img = self.transform(grid_img)
        return grid_img, torch.tensor(sample['label'], dtype=torch.long)


def load_scan_data(data_dir):
    """Load every scan's metadata.json into a {scan_id: metadata} dict."""
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")

    scan_ids = [d for d in os.listdir(scans_dir)
                if os.path.isdir(os.path.join(scans_dir, d))]
    print(f"Loading metadata for {len(scan_ids)} scans...")

    scan_data = {}
    for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            continue
        try:
            with open(metadata_path, 'r') as f:
                scan_data[scan_id] = json.load(f)
        except Exception as e:
            print(f"Error loading metadata for {scan_id}: {e}")
    return scan_data


def split_train_val(scan_data):
    """Split data into train and validation sets.

    A random 10% of scans seeds the validation set; every other scan sharing
    a (code, labels) combination seen in that seed is pulled into validation
    too, so no code/label combination leaks across the split.
    """
    scan_ids = list(scan_data.keys())
    print("Splitting data into train and validation sets...")

    # Select 10% random scan_ids for initial validation set
    val_size = int(len(scan_ids) * 0.1)
    initial_val_scan_ids = random.sample(scan_ids, val_size)
    print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")

    # Get all (code, labels) combinations that appear in the initial validation set
    val_code_labels = set()
    for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
        if scan_id in scan_data:
            code = scan_data[scan_id].get('code')
            labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
            if code:
                val_code_labels.add((code, labels))
    print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")

    # Find all scans in train that have matching (code, labels) combinations and move them to validation
    additional_val_scan_ids = set()
    train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
    for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
        if scan_id in scan_data:
            code = scan_data[scan_id].get('code')
            labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
            # If (code, labels) combination matches, move to validation
            if code and (code, labels) in val_code_labels:
                additional_val_scan_ids.add(scan_id)

    # Combine validation sets
    all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
    all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids

    # Create train and validation data dictionaries
    train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
    val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids}

    print(f"Final split:")
    print(f" Total scans: {len(scan_data)}")
    print(f" Training scans: {len(train_data)}")
    print(f" Validation scans: {len(val_data)}")
    return train_data, val_data
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata=None, num_epochs=10):
    """Train the model for the given number of epochs with mixed precision.

    Each epoch runs a training pass and a validation pass, tracking overall
    and per-class (pos/neg) accuracy, steps the plateau scheduler on
    validation accuracy, and saves a checkpoint whenever validation accuracy
    improves.

    Args:
        model: the network to train (already on `device`).
        train_loader / val_loader: DataLoaders yielding (image, label) batches.
        criterion: per-sample loss (its output is .mean()-reduced here).
        optimizer / scheduler: optimizer and ReduceLROnPlateau scheduler.
        device: torch.device the batches are moved to.
        transform: inference transform, stored alongside saved checkpoints.
        args: parsed CLI args (used for the `quick` filename suffix).
        train_metadata: optional dict persisted with each checkpoint.
        num_epochs: number of epochs to run.

    Returns:
        (model, val_acc, val_pos_acc, val_neg_acc) from the final epoch.
    """
    best_val_acc = 0.0

    # Key AMP objects off the actual device instead of hardcoding 'cuda',
    # so the loop also runs on CPU-only hosts (the original always requested
    # CUDA scaling regardless of `device`).
    scaler = torch.amp.GradScaler(device.type)

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_pos_correct = 0
        train_pos_total = 0
        train_neg_correct = 0
        train_neg_total = 0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Mixed-precision forward pass
            with torch.amp.autocast(device.type):
                output = model(data)
                loss = criterion(output, target).mean()

            # Scale loss and backpropagate
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.detach().item()
            pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(target.view_as(pred)).sum().item()
            train_total += target.size(0)

            # Track per-class accuracy
            for i in range(target.size(0)):
                if target[i] == 1:  # Positive class
                    train_pos_total += 1
                    if pred[i] == target[i]:
                        train_pos_correct += 1
                else:  # Negative class
                    train_neg_total += 1
                    if pred[i] == target[i]:
                        train_neg_correct += 1

            # Clear memory after each batch
            del data, target, output, loss, pred

            # Running per-class accuracies for the progress bar
            train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
            train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
            train_pbar.set_postfix({
                'Loss': f'{train_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*train_correct/train_total:.2f}%',
                'Pos': f'{train_pos_acc:.2f}%',
                'Neg': f'{train_neg_acc:.2f}%'
            })

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_pos_correct = 0
        val_pos_total = 0
        val_neg_correct = 0
        val_neg_total = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch_idx, (data, target) in enumerate(val_pbar):
                data, target = data.to(device), target.to(device)

                # Use mixed precision for validation too
                with torch.amp.autocast(device.type):
                    output = model(data)
                    loss = criterion(output, target).mean()

                val_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                val_correct += pred.eq(target.view_as(pred)).sum().item()
                val_total += target.size(0)

                # Track per-class accuracy
                for i in range(target.size(0)):
                    if target[i] == 1:  # Positive class
                        val_pos_total += 1
                        if pred[i] == target[i]:
                            val_pos_correct += 1
                    else:  # Negative class
                        val_neg_total += 1
                        if pred[i] == target[i]:
                            val_neg_correct += 1

                # Clear memory after each batch
                del data, target, output, loss, pred

                val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
                val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)
                val_pbar.set_postfix({
                    'Loss': f'{val_loss/(batch_idx+1):.4f}',
                    'Acc': f'{100.*val_correct/val_total:.2f}%',
                    'Pos': f'{val_pos_acc:.2f}%',
                    'Neg': f'{val_neg_acc:.2f}%'
                })

        # ---- Epoch summary ----
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
        train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
        val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
        val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f' Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% (Pos: {train_pos_acc:.2f}%, Neg: {train_neg_acc:.2f}%)')
        print(f' Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}% (Pos: {val_pos_acc:.2f}%, Neg: {val_neg_acc:.2f}%)')
        print(f' Train samples - Pos: {train_pos_total}, Neg: {train_neg_total}')
        print(f' Val samples - Pos: {val_pos_total}, Neg: {val_neg_total}')

        # Step the scheduler with validation accuracy
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        print(f' Current learning rate: {current_lr:.6f}')

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            mode_suffix = "_quick" if args.quick else ""
            best_model_path = f'models/best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
            save_model(model, transform, best_model_path, train_metadata)
            print(f' New best validation accuracy: {val_acc:.2f}%')
            print(f' Best model saved: {best_model_path}')

    return model, val_acc, val_pos_acc, val_neg_acc
def main():
    """CLI entry point: load scans, split, build datasets, train, and save."""
    parser = argparse.ArgumentParser(description='Train ResNet18 model on grid data')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size (increased for better GPU utilization)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--num-workers', type=int, default=None, help='Number of workers for preprocessing (default: auto)')
    parser.add_argument('--quick', action='store_true', help='Quick mode: use only 1%% of scans for faster testing')
    parser.add_argument('--hue-jitter', type=float, default=0.1, help='Hue jitter parameter for ColorJitter (default: 0.1)')
    parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
    parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
    parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
    args = parser.parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    print(f'Using hue jitter: {args.hue_jitter}')

    # Create models directory if it doesn't exist
    os.makedirs('models', exist_ok=True)

    # Load scan data
    print("Loading scan data...")
    scan_data = load_scan_data(args.data_dir)
    if not scan_data:
        print("No scan data found!")
        return

    # Quick mode: use only 1% of scans for faster testing
    if args.quick:
        original_count = len(scan_data)
        scan_ids = list(scan_data.keys())
        # Take 1% of scans, but at least 10 scans for meaningful training.
        # (Was 0.005, which contradicted the documented and printed "1%".)
        quick_count = max(10, int(len(scan_ids) * 0.01))
        selected_scan_ids = random.sample(scan_ids, quick_count)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in selected_scan_ids}
        print(f"Quick mode enabled: Using {len(scan_data)} scans out of {original_count} (1%)")

    # Split into train and validation
    print("Splitting data into train and validation...")
    train_data, val_data = split_train_val(scan_data)

    # Define transforms (ImageNet normalization, matching the ResNet backbone)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    print("Creating training dataset...")
    train_dataset = GridDataset(train_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Training samples: {len(train_dataset)}")

    print("Creating validation dataset...")
    val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Validation samples: {len(val_dataset)}")

    # Create data loaders
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=8, pin_memory=False, persistent_workers=True, prefetch_factor=2
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=8, pin_memory=False, persistent_workers=True, prefetch_factor=2
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Create model.
    # NOTE(review): fine-tunes from a hardcoded checkpoint via common.load_model
    # (which presumably restores the resnet18 with its 2-class head) instead of
    # torchvision's ImageNet weights — consider making this path a CLI flag.
    print("Creating ResNet18 model...")
    model, _ = load_model('models/gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
    model = model.to(device)
    print(f"Model created and moved to {device}")

    # Define loss function and optimizer
    print("Setting up loss function and optimizer...")
    # Use FocalLoss with 0.99:0.01 weights for positive/negative classes
    pos_weight = 0.99
    criterion = FocalLoss(0.25, weight=torch.Tensor([1.0 - pos_weight, pos_weight]).to(device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Create ReduceLROnPlateau scheduler (mode='max': it watches val accuracy)
    scheduler = ReduceLROnPlateau(
        optimizer, mode='max',
        factor=args.scheduler_factor,
        patience=args.scheduler_patience,
        min_lr=args.scheduler_min_lr
    )
    print(f"Using Adam optimizer with lr={args.lr}")
    print(f"Using FocalLoss with weights: Negative={1.0 - pos_weight:.2f}, Positive={pos_weight:.2f}")
    print(f"Using ReduceLROnPlateau scheduler with factor={args.scheduler_factor}, patience={args.scheduler_patience}, min_lr={args.scheduler_min_lr}")

    # Collect training metadata (persisted with every checkpoint)
    train_scan_ids = list(train_data.keys())
    train_codes = set()
    for scan_id, metadata in train_data.items():
        code = metadata.get('code')
        if code:
            train_codes.add(code)
    train_metadata = {
        'train_scan_ids': train_scan_ids,
        'train_codes': list(train_codes),
        'hue_jitter': args.hue_jitter,
        'quick_mode': args.quick
    }
    print(f"Training metadata:")
    print(f" Train scan IDs: {len(train_scan_ids)}")
    print(f" Train codes: {len(train_codes)}")
    print(f" Hue jitter: {args.hue_jitter}")
    print(f" Quick mode: {args.quick}")

    # Train the model
    print("Starting training...")
    model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler,
        device, transform, args, train_metadata, args.epochs
    )

    # Save final model
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    mode_suffix = "_quick" if args.quick else ""
    final_model_path = f'models/final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
    save_model(model, transform, final_model_path, train_metadata)
    print(f"Training completed! Final model saved: {final_model_path}")


if __name__ == '__main__':
    main()