#!/usr/bin/env python3
'''
All children of data/scans are scan_ids.

In each scan_id there is a file called "metadata.json":
- the 'labels' key has 'pos' or 'neg'
- the 'code' key is a string

For the largest 20% scan_ids, we use as validation set;
for all codes appearing in the validation set in all scans, we also use as validation set.
The rest are the training set.

Preprocess the sbs.jpg for both train and validation:
1. split left and right half
2. crop 3x3 of left
3. crop 3x3 of right
4. for each pair of crops, concat into grid-i-j.jpg, and apply some colorjitter
5. all the grid-i-j.jpg are used with the label

Load a resnet18 model, and train it on the training set.
Train for 10 epochs, and print accuracy on validation set each epoch.

Save the model in the end.
'''
|
|
|
|
import argparse
import json
import multiprocessing as mp
import os
import random
from collections import defaultdict
from datetime import datetime
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from common import *
|
|
|
|
def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
    """Create the nine grid-i-j.jpg crops for one scan and return their metadata.

    Splits sbs.jpg into left/right halves, crops each half into a 3x3 grid,
    applies color jitter to the left crop only, and saves each left+right pair
    concatenated horizontally as grids/grid-<i>-<j>.jpg under the scan dir.
    If all nine grid files already exist, image processing is skipped and
    only the metadata entries are rebuilt.

    Args:
        scan_item: (scan_id, metadata) tuple; metadata must carry a 'labels'
            list containing 'pos' and/or 'neg'.
        data_dir: root data directory containing 'scans/<scan_id>/'.
        hue_jitter: hue parameter forwarded to torchvision ColorJitter.

    Returns:
        (sample_metadata, reason) where sample_metadata is a list of dicts
        with keys scan_id/grid_path/grid_i/grid_j/label, and reason is None
        on success or a short string describing the failure.
    """
    scan_id, metadata = scan_item
    sample_metadata = []

    sbs_path = os.path.join(data_dir, 'scans', scan_id, 'sbs.jpg')
    if not os.path.exists(sbs_path):
        return sample_metadata, 'sbs.jpg missing'

    # Create grid directory if it doesn't exist
    grid_dir = os.path.join(data_dir, 'scans', scan_id, 'grids')
    os.makedirs(grid_dir, exist_ok=True)

    # Fast path: if all nine grid files already exist, skip image processing
    # entirely and just rebuild the metadata entries. (Replaces the original
    # double-loop-with-double-break existence check.)
    all_grids_exist = all(
        os.path.exists(os.path.join(grid_dir, f'grid-{i}-{j}.jpg'))
        for i in range(3) for j in range(3)
    )
    if all_grids_exist:
        # Binary label: 1 if the scan is marked positive, else 0.
        label = 1 if 'pos' in metadata['labels'] else 0
        for i in range(3):
            for j in range(3):
                sample_metadata.append({
                    'scan_id': scan_id,
                    'grid_path': os.path.join(grid_dir, f'grid-{i}-{j}.jpg'),
                    'grid_i': i,
                    'grid_j': j,
                    'label': label,
                })
        return sample_metadata, None

    # Load the side-by-side image
    try:
        sbs_img = Image.open(sbs_path).convert('RGB')
    except Exception as e:
        return sample_metadata, f'sbs.jpg unreadable: {str(e)}'
    width, height = sbs_img.size

    # Check if image is too small to yield at least 1x1 crops per cell
    if width < 6 or height < 3:
        return sample_metadata, f'sbs.jpg too small: {width}x{height}'

    # Each half is width/2 wide; a 3x3 grid over a half gives width/6 per crop.
    crop_width = width // 6
    crop_height = height // 3
    if crop_width <= 0 or crop_height <= 0:
        return sample_metadata, f'invalid crop dimensions: {crop_width}x{crop_height}'

    # Hoisted out of the 3x3 loop: the jitter transform is stateless per call,
    # so one instance serves all nine crops (the original rebuilt it 9 times).
    color_jitter = transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=hue_jitter,
    )

    label = 1 if 'pos' in metadata['labels'] else 0

    # Generate all 3x3 grid combinations
    try:
        for i in range(3):
            for j in range(3):
                # Crop positions computed directly on the original image; the
                # right half starts 3 crop-widths in, which skips any leftover
                # middle pixels from the integer division above.
                left_x = i * crop_width
                right_x = (i + 3) * crop_width
                y = j * crop_height

                left_crop = sbs_img.crop((left_x, y, left_x + crop_width, y + crop_height))
                right_crop = sbs_img.crop((right_x, y, right_x + crop_width, y + crop_height))

                # Apply color jitter only to the left crop
                left_crop = color_jitter(left_crop)

                # Concatenate left and right crops horizontally
                grid_img = Image.new('RGB', (crop_width * 2, crop_height))
                grid_img.paste(left_crop, (0, 0))
                grid_img.paste(right_crop, (crop_width, 0))

                # Save grid image
                grid_path = os.path.join(grid_dir, f'grid-{i}-{j}.jpg')
                grid_img.save(grid_path, 'JPEG', quality=95)

                sample_metadata.append({
                    'scan_id': scan_id,
                    'grid_path': grid_path,
                    'grid_i': i,
                    'grid_j': j,
                    'label': label,
                })
    except Exception as e:
        return sample_metadata, f'error creating grids: {str(e)}'

    return sample_metadata, None
|
|
|
|
class GridDataset(Dataset):
    """Dataset of preprocessed grid-i-j.jpg crops with binary labels.

    On construction, every scan in ``scan_data`` is run through
    ``process_scan_grid`` in a multiprocessing pool; the resulting
    per-grid metadata entries are flattened into ``self.sample_metadata``
    and indexed by ``__getitem__``.
    """

    def __init__(self, scan_data, data_dir='data', transform=None, num_workers=None, hue_jitter=0.1):
        self.scan_data = scan_data
        self.data_dir = data_dir
        self.transform = transform
        self.hue_jitter = hue_jitter
        self.sample_metadata = []

        if num_workers is None:
            # Cap at 8 workers to avoid oversubscribing the machine.
            num_workers = min(mp.cpu_count(), 8)

        print(f"Preprocessing {len(scan_data)} scans to create grid files using {num_workers} workers...")

        # Fan the scans out over a process pool; each worker creates the
        # grid files (or reuses existing ones) and reports its metadata.
        worker = partial(process_scan_grid, data_dir=self.data_dir, hue_jitter=self.hue_jitter)
        with mp.Pool(processes=num_workers) as pool:
            outcomes = list(tqdm(
                pool.imap(worker, scan_data.items()),
                total=len(scan_data),
                desc="Creating grid files"
            ))

        # Flatten the per-scan results and tally failure reasons.
        failure_counts = defaultdict(int)
        for outcome in outcomes:
            if isinstance(outcome, tuple) and len(outcome) == 2:
                entries, failure = outcome
                self.sample_metadata.extend(entries)
                if failure is not None:
                    failure_counts[failure] += 1
            else:
                # Backward compatibility: old result format is a bare list
                # of metadata entries with no failure reason attached.
                self.sample_metadata.extend(outcome)

        print(f"Created {len(self.sample_metadata)} grid files")
        if failure_counts:
            print("\nStatistics on why grid files were not created:")
            for failure, count in sorted(failure_counts.items(), key=lambda item: -item[1]):
                print(f"  {failure}: {count} scans")

    def __len__(self):
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        # Load the pre-saved grid image for this sample.
        entry = self.sample_metadata[idx]
        image = Image.open(entry['grid_path']).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(entry['label'], dtype=torch.long)
|
|
|
|
def load_scan_data(data_dir):
    """Read metadata.json for every scan directory under <data_dir>/scans.

    Returns a dict mapping scan_id -> parsed metadata. Scans without a
    metadata.json are silently skipped; scans whose metadata.json fails to
    parse are skipped with a printed error.

    Raises:
        FileNotFoundError: if <data_dir>/scans does not exist.
    """
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")

    # Every subdirectory of data/scans is a scan_id.
    scan_ids = [
        entry for entry in os.listdir(scans_dir)
        if os.path.isdir(os.path.join(scans_dir, entry))
    ]

    print(f"Loading metadata for {len(scan_ids)} scans...")
    scan_data = {}
    for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            continue
        try:
            with open(metadata_path, 'r') as f:
                scan_data[scan_id] = json.load(f)
        except Exception as e:
            print(f"Error loading metadata for {scan_id}: {e}")

    return scan_data
|
|
|
|
def split_train_val(scan_data, avoid_overlap=True):
    """Split scans into train and validation dictionaries.

    A random 10% of scan_ids forms the initial validation set. When
    ``avoid_overlap`` is True, any training scan whose (code, labels)
    combination also occurs in the validation set is moved into validation,
    so the two sets never share a (code, labels) pair.

    Args:
        scan_data: dict of scan_id -> metadata.
        avoid_overlap: enable the (code, labels) overlap-avoidance pass.

    Returns:
        (train_data, val_data) dicts of scan_id -> metadata.
    """
    scan_ids = list(scan_data.keys())

    print("Splitting data into train and validation sets...")

    # Draw a random 10% of scan_ids as the initial validation set.
    val_size = int(len(scan_ids) * 0.1)
    initial_val_scan_ids = random.sample(scan_ids, val_size)

    print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")

    if avoid_overlap:
        # Every (code, sorted-labels) pair present in the initial validation set.
        val_code_labels = set()
        for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
            if scan_id in scan_data:
                code = scan_data[scan_id].get('code')
                labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
                if code:
                    val_code_labels.add((code, labels))

        print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")

        # Move every train scan whose (code, labels) pair matches a
        # validation pair into the validation set.
        additional_val_scan_ids = set()
        train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
        for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
            if scan_id in scan_data:
                code = scan_data[scan_id].get('code')
                labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
                if code and (code, labels) in val_code_labels:
                    additional_val_scan_ids.add(scan_id)

        all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
    else:
        # Simple split without overlap avoidance.
        all_val_scan_ids = set(initial_val_scan_ids)
    all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids

    train_data = {sid: scan_data[sid] for sid in all_train_scan_ids}
    val_data = {sid: scan_data[sid] for sid in all_val_scan_ids}

    print(f"Final split:")
    print(f"  Total scans: {len(scan_data)}")
    print(f"  Training scans: {len(train_data)}")
    print(f"  Validation scans: {len(val_data)}")

    return train_data, val_data
|
|
|
|
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata=None, num_epochs=10):
    """Train the model for specified number of epochs.

    Each epoch runs a training pass and a validation pass under CUDA mixed
    precision, tracking overall and per-class (pos/neg) accuracy. The
    plateau scheduler is stepped on validation accuracy, metrics are logged
    to wandb, and a checkpoint is saved whenever validation accuracy
    improves on the best seen so far.

    Returns:
        (model, val_acc, val_pos_acc, val_neg_acc) from the LAST epoch
        (not necessarily the best one).
    """
    best_val_acc = 0.0

    # Enable mixed precision training
    scaler = torch.amp.GradScaler('cuda')

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_pos_correct = 0
        train_pos_total = 0
        train_neg_correct = 0
        train_neg_total = 0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            # Use mixed precision training.
            # NOTE(review): criterion appears to return per-sample losses
            # (reduced with .mean() here) — confirm against FocalLoss in common.
            with torch.amp.autocast('cuda'):
                output = model(data)
                loss = criterion(output, target).mean()

            # Scale loss and backpropagate (standard GradScaler sequence)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.detach().item()
            pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(target.view_as(pred)).sum().item()
            train_total += target.size(0)

            # Track per-class accuracy
            for i in range(target.size(0)):
                if target[i] == 1:  # Positive class
                    train_pos_total += 1
                    if pred[i] == target[i]:
                        train_pos_correct += 1
                else:  # Negative class
                    train_neg_total += 1
                    if pred[i] == target[i]:
                        train_neg_correct += 1

            # Clear memory after each batch
            del data, target, output, loss, pred

            # Calculate per-class accuracies (max(., 1) avoids division by zero
            # when a class has not been seen yet this epoch)
            train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
            train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)

            train_pbar.set_postfix({
                'Loss': f'{train_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*train_correct/train_total:.2f}%',
                'Pos': f'{train_pos_acc:.2f}%',
                'Neg': f'{train_neg_acc:.2f}%'
            })

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_pos_correct = 0
        val_pos_total = 0
        val_neg_correct = 0
        val_neg_total = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch_idx, (data, target) in enumerate(val_pbar):
                data, target = data.to(device), target.to(device)

                # Use mixed precision for validation too
                with torch.amp.autocast('cuda'):
                    output = model(data)
                    loss = criterion(output, target).mean()

                val_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                val_correct += pred.eq(target.view_as(pred)).sum().item()
                val_total += target.size(0)

                # Track per-class accuracy
                for i in range(target.size(0)):
                    if target[i] == 1:  # Positive class
                        val_pos_total += 1
                        if pred[i] == target[i]:
                            val_pos_correct += 1
                    else:  # Negative class
                        val_neg_total += 1
                        if pred[i] == target[i]:
                            val_neg_correct += 1

                # Clear memory after each batch
                del data, target, output, loss, pred

                # Calculate per-class accuracies
                val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
                val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)

                val_pbar.set_postfix({
                    'Loss': f'{val_loss/(batch_idx+1):.4f}',
                    'Acc': f'{100.*val_correct/val_total:.2f}%',
                    'Pos': f'{val_pos_acc:.2f}%',
                    'Neg': f'{val_neg_acc:.2f}%'
                })

        # Calculate final epoch-level accuracies
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
        train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
        val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
        val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% (Pos: {train_pos_acc:.2f}%, Neg: {train_neg_acc:.2f}%)')
        print(f'  Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}% (Pos: {val_pos_acc:.2f}%, Neg: {val_neg_acc:.2f}%)')
        print(f'  Train samples - Pos: {train_pos_total}, Neg: {train_neg_total}')
        print(f'  Val samples - Pos: {val_pos_total}, Neg: {val_neg_total}')

        # Step the scheduler with validation accuracy (scheduler is created
        # with mode='max' in main, so higher accuracy is "better")
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'  Current learning rate: {current_lr:.6f}')

        # Log to wandb
        wandb.log({
            'epoch': epoch + 1,
            'train/loss': train_loss / len(train_loader),
            'train/acc': train_acc,
            'train/pos_acc': train_pos_acc,
            'train/neg_acc': train_neg_acc,
            'train/pos_samples': train_pos_total,
            'train/neg_samples': train_neg_total,
            'val/loss': val_loss / len(val_loader),
            'val/acc': val_acc,
            'val/pos_acc': val_pos_acc,
            'val/neg_acc': val_neg_acc,
            'val/pos_samples': val_pos_total,
            'val/neg_samples': val_neg_total,
            'learning_rate': current_lr,
        })

        # Save a checkpoint whenever validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            mode_suffix = "_quick" if args.quick else ""
            models_dir = os.path.expanduser('~/emblem/models')
            best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1:03d}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
            save_model(model, transform, best_model_path, train_metadata)
            print(f'  New best validation accuracy: {val_acc:.2f}%')
            print(f'  Best model saved: {best_model_path}')
            wandb.log({'best_val_acc': best_val_acc, 'best_epoch': epoch + 1})

    return model, val_acc, val_pos_acc, val_neg_acc
|
|
|
|
def main():
    """Parse CLI args, build train/val grid datasets, train the classifier,
    and save the final model.

    In finetune mode (--model given): overlap avoidance in the split is
    disabled, and scans belonging to a random sample of "old" codes (codes
    in the full dataset but absent from the finetune subset) are mixed back
    into the training set to mitigate catastrophic forgetting.
    """
    parser = argparse.ArgumentParser(description='Train ResNet18 model on grid data')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size (increased for better GPU utilization)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--num-workers', type=int, default=None, help='Number of workers for preprocessing (default: auto)')
    # Message fixed: the sampling fraction below is 0.005, i.e. 0.5%, not 1%.
    parser.add_argument('--quick', action='store_true', help='Quick mode: use only 0.5%% of scans for faster testing')
    parser.add_argument('--hue-jitter', type=float, default=0.1, help='Hue jitter parameter for ColorJitter (default: 0.1)')
    parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
    parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
    parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
    parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
    parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
    parser.add_argument('--wandb-project', type=str, default='themblem', help='W&B project name (default: themblem)')
    parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name (default: auto-generated)')
    args = parser.parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    print(f'Using hue jitter: {args.hue_jitter}')

    # Create models directory if it doesn't exist
    models_dir = os.path.expanduser('~/emblem/models')
    os.makedirs(models_dir, exist_ok=True)

    # Load full scan data (needed for finetune mode to find old codes)
    print("Loading scan data...")
    full_scan_data = load_scan_data(args.data_dir)

    if not full_scan_data:
        print("No scan data found!")
        return

    # Start with full scan data, then filter
    scan_data = full_scan_data.copy()

    # Filter by scan IDs if provided
    if args.scan_ids:
        ranges = parse_ranges(args.scan_ids)
        filtered_scan_data = {}
        for scan_id, metadata in scan_data.items():
            try:
                scan_id_int = int(scan_id)
            except ValueError:
                # Skip non-numeric scan IDs
                continue
            for val_range in ranges:
                if in_range(scan_id_int, val_range):
                    filtered_scan_data[scan_id] = metadata
                    break
        if not filtered_scan_data:
            print(f"Error: No scan IDs found in range '{args.scan_ids}'")
            return
        scan_data = filtered_scan_data
        print(f"Filtered to {len(scan_data)} scans matching range '{args.scan_ids}'")

    # Quick mode: subsample scans for faster testing
    if args.quick:
        original_count = len(scan_data)
        scan_ids = list(scan_data.keys())
        # Take 0.5% of scans, but at least 10 for meaningful training
        quick_count = max(10, int(len(scan_ids) * 0.005))
        selected_scan_ids = random.sample(scan_ids, quick_count)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in selected_scan_ids}
        print(f"Quick mode enabled: Using {len(scan_data)} scans out of {original_count} (0.5%)")

    # Split into train and validation
    print("Splitting data into train and validation...")
    # In finetune mode (when model is specified), don't avoid overlap between train and val
    avoid_overlap = args.model is None
    train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)

    # In finetune mode, add old codes to prevent forgetting
    old_codes_info = None
    if args.model:
        # Get distinct codes from current finetune dataset
        finetune_codes = set()
        for scan_id, metadata in scan_data.items():
            code = metadata.get('code')
            if code:
                finetune_codes.add(code)

        num_finetune_codes = len(finetune_codes)
        print(f"Finetune mode: Found {num_finetune_codes} distinct codes in finetune dataset")

        if num_finetune_codes > 0:
            # Find old codes (codes in full dataset but not in finetune dataset)
            old_codes = set()
            for scan_id, metadata in full_scan_data.items():
                code = metadata.get('code')
                if code and code not in finetune_codes:
                    old_codes.add(code)

            print(f"Found {len(old_codes)} distinct old codes in full dataset")

            if len(old_codes) > 0:
                # Select the same number of distinct random old codes
                num_old_codes_to_select = min(num_finetune_codes, len(old_codes))
                selected_old_codes = random.sample(list(old_codes), num_old_codes_to_select)
                print(f"Selected {num_old_codes_to_select} old codes to prevent forgetting: {selected_old_codes}")

                # Add all scans for selected old codes to training set
                old_scans_added = 0
                for scan_id, metadata in full_scan_data.items():
                    code = metadata.get('code')
                    if code and code in selected_old_codes:
                        # Only add if not already in train_data (avoid duplicates)
                        if scan_id not in train_data:
                            train_data[scan_id] = metadata
                            old_scans_added += 1

                print(f"Added {old_scans_added} scans from old codes to training set")
                print(f"Training set size: {len(train_data)} scans (including {old_scans_added} from old codes)")
                old_codes_info = {
                    'num_old_codes': num_old_codes_to_select,
                    'old_codes': selected_old_codes,
                    'old_scans_added': old_scans_added
                }
            else:
                print("No old codes found in full dataset")
        else:
            print("No codes found in finetune dataset, skipping old code selection")

    # Define transforms (ImageNet normalization to match the pretrained backbone)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    print("Creating training dataset...")
    train_dataset = GridDataset(train_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Training samples: {len(train_dataset)}")

    print("Creating validation dataset...")
    val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Validation samples: {len(val_dataset)}")

    # Validate datasets have samples
    if len(train_dataset) == 0:
        print("Error: Training dataset is empty. Cannot proceed with training.")
        print("This usually means all training scans are missing sbs.jpg files.")
        return

    if len(val_dataset) == 0:
        print("Error: Validation dataset is empty. Cannot proceed with training.")
        print("This usually means all validation scans are missing sbs.jpg files.")
        return

    # Create data loaders
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Create model
    print("Creating ResNet18 model...")
    models_dir = os.path.expanduser('~/emblem/models')
    if args.model:
        # Load specified model for finetuning
        model_path = os.path.expanduser(args.model)
        if not os.path.exists(model_path):
            print(f"Error: Model file not found: {model_path}")
            return
        print(f"Loading model from {model_path}...")
        model, _ = load_model(model_path)
    else:
        # Load default model
        default_model_path = os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
        print(f"Loading default model from {default_model_path}...")
        model, _ = load_model(default_model_path)

    # Unwrap compiled model if it exists (torch.compile can cause issues when loading)
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    # The loaded checkpoint already ends in a 2-class head, so the final
    # layer is not replaced here.
    model = model.to(device)
    print(f"Model loaded and moved to {device}")

    # Define loss function and optimizer
    print("Setting up loss function and optimizer...")
    # Use FocalLoss with 0.99:0.01 weights for positive/negative classes
    pos_weight = 0.99
    criterion = FocalLoss(0.25, weight=torch.Tensor([1.0 - pos_weight, pos_weight]).to(device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Create ReduceLROnPlateau scheduler (mode='max': stepped on val accuracy)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=args.scheduler_factor,
        patience=args.scheduler_patience,
        min_lr=args.scheduler_min_lr
    )

    print(f"Using Adam optimizer with lr={args.lr}")
    print(f"Using FocalLoss with weights: Negative={1.0 - pos_weight:.2f}, Positive={pos_weight:.2f}")
    print(f"Using ReduceLROnPlateau scheduler with factor={args.scheduler_factor}, patience={args.scheduler_patience}, min_lr={args.scheduler_min_lr}")

    # Collect training metadata (saved alongside the model checkpoints)
    train_scan_ids = list(train_data.keys())
    train_codes = set()
    for scan_id, metadata in train_data.items():
        code = metadata.get('code')
        if code:
            train_codes.add(code)

    train_metadata = {
        'train_scan_ids': train_scan_ids,
        'train_codes': list(train_codes),
        'hue_jitter': args.hue_jitter,
        'quick_mode': args.quick
    }

    print(f"Training metadata:")
    print(f"  Train scan IDs: {len(train_scan_ids)}")
    print(f"  Train codes: {len(train_codes)}")
    print(f"  Hue jitter: {args.hue_jitter}")
    print(f"  Quick mode: {args.quick}")

    # Initialize wandb.
    # SECURITY FIX: the API key was previously hard-coded in source. Read it
    # from the environment instead; with key=None, wandb.login() also falls
    # back to WANDB_API_KEY / netrc on its own. (The leaked key must be
    # revoked separately.)
    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    wandb_config = {
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        'hue_jitter': args.hue_jitter,
        'scheduler_patience': args.scheduler_patience,
        'scheduler_factor': args.scheduler_factor,
        'scheduler_min_lr': args.scheduler_min_lr,
        'quick_mode': args.quick,
        'train_scans': len(train_scan_ids),
        'val_scans': len(val_data),
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'model_path': args.model if args.model else 'default',
        'finetune_mode': args.model is not None,
    }
    if old_codes_info:
        wandb_config['old_codes_count'] = old_codes_info['num_old_codes']
        wandb_config['old_scans_added'] = old_codes_info['old_scans_added']
    wandb_name = args.wandb_name or f"train2-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    wandb.init(
        project=args.wandb_project,
        name=wandb_name,
        config=wandb_config,
    )
    print(f"W&B logging enabled: project={args.wandb_project}, run={wandb_name}")

    # Train the model
    print("Starting training...")
    model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs)

    # Save final model
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    mode_suffix = "_quick" if args.quick else ""
    models_dir = os.path.expanduser('~/emblem/models')
    final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs:03d}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
    save_model(model, transform, final_model_path, train_metadata)
    print(f"Training completed! Final model saved: {final_model_path}")

    # Finish wandb run
    wandb.finish()
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()