# themblem/emblem5/ai/train2.py
# snapshot: 2025-10-29 21:27:29 +00:00 — 545 lines, 22 KiB, Python
#!/usr/bin/env python3
'''
all children of data/scans are scan_ids
in each scan_id, there is a file called "metadata.json"
labels key has 'pos' or 'neg'
it also has 'code' key which is a string
A random 10% of scan_ids is used as the initial validation set
for all codes appearing in the validation set in all scans, we also use as validation set
the rest are training set
preprocess the sbs.jpg for both train and validation:
1. split left and right half
2. crop 3x3 of left
3. crop 3x3 of right
4. for each pair of crop, concat into grid-i-j.jpg, and apply some colorjitter
5. all the grid-i-j.jpg are used with the label
load a resnet18 model, and train it on the training set
train for the configured number of epochs (default 100), and print accuracy on validation set each epoch
save the model in the end.
'''
import os
import json
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
from collections import defaultdict
import multiprocessing as mp
from functools import partial
from datetime import datetime
from common import *
def process_scan_grid(scan_item, hue_jitter=0.1):
    """Create (or reuse) the 3x3 grid crops for one scan and return sample metadata.

    The scan's sbs.jpg holds two side-by-side halves; each half is cut into a
    3x3 grid and each (left, right) cell pair is concatenated into
    grids/grid-i-j.jpg. Existing grid files are reused as-is.

    Args:
        scan_item: (scan_id, metadata) pair; metadata carries a 'labels' list
            containing 'pos' or 'neg'.
        hue_jitter: hue range for the ColorJitter applied to the left crop only.

    Returns:
        List of dicts with keys scan_id, grid_path, grid_i, grid_j, label
        (1 if 'pos' in labels else 0) — one per cell; empty if sbs.jpg is missing.
    """
    scan_id, metadata = scan_item
    sbs_path = os.path.join('data/scans', scan_id, 'sbs.jpg')
    if not os.path.exists(sbs_path):
        return []

    grid_dir = os.path.join('data/scans', scan_id, 'grids')
    os.makedirs(grid_dir, exist_ok=True)

    cells = [(i, j) for i in range(3) for j in range(3)]
    grid_paths = {
        (i, j): os.path.join(grid_dir, f'grid-{i}-{j}.jpg') for i, j in cells
    }

    # Regenerate only when any cell is missing; otherwise reuse the cached crops.
    if not all(os.path.exists(path) for path in grid_paths.values()):
        sbs_img = Image.open(sbs_path).convert('RGB')
        width, height = sbs_img.size
        crop_width = width // 6  # (width / 2) / 3: three columns per half
        crop_height = height // 3
        # Hoisted out of the loop: one jitter transform serves all nine cells
        # (the original rebuilt it per cell).
        color_jitter = transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=hue_jitter
        )
        for i, j in cells:
            left_x = i * crop_width
            right_x = (i + 3) * crop_width  # offset past the left half
            y = j * crop_height
            left_crop = sbs_img.crop((left_x, y, left_x + crop_width, y + crop_height))
            right_crop = sbs_img.crop((right_x, y, right_x + crop_width, y + crop_height))
            # Jitter only the left crop so the pair differs photometrically.
            left_crop = color_jitter(left_crop)
            grid_img = Image.new('RGB', (crop_width * 2, crop_height))
            grid_img.paste(left_crop, (0, 0))
            grid_img.paste(right_crop, (crop_width, 0))
            grid_img.save(grid_paths[(i, j)], 'JPEG', quality=95)

    # Single metadata-building path (previously duplicated across both branches).
    label = 1 if 'pos' in metadata['labels'] else 0
    return [
        {
            'scan_id': scan_id,
            'grid_path': grid_paths[(i, j)],
            'grid_i': i,
            'grid_j': j,
            'label': label
        }
        for i, j in cells
    ]
class GridDataset(Dataset):
    """Dataset of pre-rendered grid-i-j.jpg crops with binary labels.

    Grid creation runs once, in parallel, at construction time;
    __getitem__ then only loads the cached JPEG and applies transforms.
    """

    def __init__(self, scan_data, transform=None, num_workers=None, hue_jitter=0.1):
        self.scan_data = scan_data
        self.transform = transform
        self.hue_jitter = hue_jitter
        self.sample_metadata = []
        if num_workers is None:
            # Cap at 8 so preprocessing does not oversubscribe the machine.
            num_workers = min(mp.cpu_count(), 8)
        print(f"Preprocessing {len(scan_data)} scans to create grid files using {num_workers} workers...")
        # Fan preprocessing out over a worker pool, binding the jitter amount.
        worker = partial(process_scan_grid, hue_jitter=self.hue_jitter)
        with mp.Pool(processes=num_workers) as pool:
            progress = tqdm(
                pool.imap(worker, scan_data.items()),
                total=len(scan_data),
                desc="Creating grid files"
            )
            for per_scan_samples in progress:
                self.sample_metadata.extend(per_scan_samples)
        print(f"Created {len(self.sample_metadata)} grid files")

    def __len__(self):
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        entry = self.sample_metadata[idx]
        image = Image.open(entry['grid_path']).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(entry['label'], dtype=torch.long)
def load_scan_data(data_dir):
    """Load metadata.json for every scan under <data_dir>/scans.

    Args:
        data_dir: root data directory containing a 'scans' subdirectory.

    Returns:
        Dict mapping scan_id -> parsed metadata. Scans without a
        metadata.json are skipped silently; unreadable or invalid files are
        reported and skipped (best-effort load).

    Raises:
        FileNotFoundError: if <data_dir>/scans does not exist.
    """
    scan_data = {}
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")
    scan_ids = [d for d in os.listdir(scans_dir)
                if os.path.isdir(os.path.join(scans_dir, d))]
    print(f"Loading metadata for {len(scan_ids)} scans...")
    for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            continue
        try:
            with open(metadata_path, 'r') as f:
                scan_data[scan_id] = json.load(f)
        except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
            # Narrowed from a blanket `except Exception`: only expected
            # I/O / parse failures are swallowed; genuine bugs now surface.
            print(f"Error loading metadata for {scan_id}: {e}")
    return scan_data
def split_train_val(scan_data):
    """Split scans into train/validation dicts.

    A random 10% of scan_ids seeds the validation set; every remaining scan
    that shares a (code, labels) combination with that seed set is also
    moved to validation, so no code/label pair leaks across the split.
    """
    scan_ids = list(scan_data.keys())
    print("Splitting data into train and validation sets...")
    val_size = int(len(scan_ids) * 0.1)
    initial_val_scan_ids = random.sample(scan_ids, val_size)
    print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")

    def code_labels(scan_id):
        # Canonical (code, sorted-labels-tuple) key for leakage detection.
        meta = scan_data[scan_id]
        return meta.get('code'), tuple(sorted(meta.get('labels', [])))

    val_code_labels = set()
    for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
        code, labels = code_labels(scan_id)
        if code:
            val_code_labels.add((code, labels))
    print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")

    additional_val_scan_ids = set()
    train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
    for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
        code, labels = code_labels(scan_id)
        # A matching combination pulls the scan over to validation.
        if code and (code, labels) in val_code_labels:
            additional_val_scan_ids.add(scan_id)

    all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
    val_data = {sid: scan_data[sid] for sid in all_val_scan_ids}
    train_data = {sid: scan_data[sid] for sid in scan_data.keys() - all_val_scan_ids}
    print(f"Final split:")
    print(f" Total scans: {len(scan_data)}")
    print(f" Training scans: {len(train_data)}")
    print(f" Validation scans: {len(val_data)}")
    return train_data, val_data
def _batch_class_counts(pred, target):
    """Per-batch per-class tallies: (pos_correct, pos_total, neg_correct, neg_total).

    `pred` and `target` are 1-D label tensors of equal length.
    """
    correct = pred.eq(target)
    pos = target == 1
    neg = ~pos
    return (
        (correct & pos).sum().item(),
        pos.sum().item(),
        (correct & neg).sum().item(),
        neg.sum().item(),
    )


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata=None, num_epochs=10):
    """Train `model` for `num_epochs`, validating after every epoch.

    Uses mixed precision when `device` is CUDA, steps a plateau scheduler on
    validation accuracy, and snapshots the best model (by overall validation
    accuracy) into models/.

    Args:
        model: network mapping images to 2-class logits.
        train_loader, val_loader: DataLoaders yielding (images, labels).
        criterion: per-sample loss (reduced here with .mean()).
        optimizer: optimizer over model.parameters().
        scheduler: ReduceLROnPlateau stepped with validation accuracy.
        device: torch.device to run on.
        transform: inference transform, stored alongside saved checkpoints.
        args: CLI namespace (reads args.quick for the filename suffix).
        train_metadata: optional dict persisted with each checkpoint.
        num_epochs: number of passes over train_loader.

    Returns:
        (model, last val accuracy, last val pos accuracy, last val neg accuracy).
    """
    best_val_acc = 0.0
    # AMP was hard-coded to 'cuda', which fails on CPU-only machines;
    # enable mixed precision only when the device actually is CUDA.
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    for epoch in range(num_epochs):
        # ---------------- Training phase ----------------
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_pos_correct = 0
        train_pos_total = 0
        train_neg_correct = 0
        train_neg_total = 0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                output = model(data)
                # criterion is per-sample (FocalLoss); reduce explicitly.
                loss = criterion(output, target).mean()
            # Scaled backward/step; these behave like plain backward/step
            # when the scaler is disabled.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.detach().item()
            pred = output.argmax(dim=1)
            train_correct += pred.eq(target).sum().item()
            train_total += target.size(0)
            # Vectorized per-class tallies (replaces a Python loop over the batch).
            pc, pt, nc, nt = _batch_class_counts(pred, target)
            train_pos_correct += pc
            train_pos_total += pt
            train_neg_correct += nc
            train_neg_total += nt
            # Drop references so batch tensors can be freed promptly.
            del data, target, output, loss, pred
            train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
            train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
            train_pbar.set_postfix({
                'Loss': f'{train_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*train_correct/train_total:.2f}%',
                'Pos': f'{train_pos_acc:.2f}%',
                'Neg': f'{train_neg_acc:.2f}%'
            })
        # ---------------- Validation phase ----------------
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_pos_correct = 0
        val_pos_total = 0
        val_neg_correct = 0
        val_neg_total = 0
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch_idx, (data, target) in enumerate(val_pbar):
                data, target = data.to(device), target.to(device)
                with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                    output = model(data)
                    loss = criterion(output, target).mean()
                val_loss += loss.item()
                pred = output.argmax(dim=1)
                val_correct += pred.eq(target).sum().item()
                val_total += target.size(0)
                pc, pt, nc, nt = _batch_class_counts(pred, target)
                val_pos_correct += pc
                val_pos_total += pt
                val_neg_correct += nc
                val_neg_total += nt
                del data, target, output, loss, pred
                val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
                val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)
                val_pbar.set_postfix({
                    'Loss': f'{val_loss/(batch_idx+1):.4f}',
                    'Acc': f'{100.*val_correct/val_total:.2f}%',
                    'Pos': f'{val_pos_acc:.2f}%',
                    'Neg': f'{val_neg_acc:.2f}%'
                })
        # Epoch summary; max(..., 1) guards empty loaders / absent classes.
        train_acc = 100. * train_correct / max(train_total, 1)
        val_acc = 100. * val_correct / max(val_total, 1)
        train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
        train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
        val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
        val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f' Train Loss: {train_loss/max(len(train_loader), 1):.4f}, Train Acc: {train_acc:.2f}% (Pos: {train_pos_acc:.2f}%, Neg: {train_neg_acc:.2f}%)')
        print(f' Val Loss: {val_loss/max(len(val_loader), 1):.4f}, Val Acc: {val_acc:.2f}% (Pos: {val_pos_acc:.2f}%, Neg: {val_neg_acc:.2f}%)')
        print(f' Train samples - Pos: {train_pos_total}, Neg: {train_neg_total}')
        print(f' Val samples - Pos: {val_pos_total}, Neg: {val_neg_total}')
        # Plateau scheduler tracks validation accuracy (mode='max').
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        print(f' Current learning rate: {current_lr:.6f}')
        # Snapshot whenever overall validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            mode_suffix = "_quick" if args.quick else ""
            best_model_path = f'models/best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
            save_model(model, transform, best_model_path, train_metadata)
            print(f' New best validation accuracy: {val_acc:.2f}%')
            print(f' Best model saved: {best_model_path}')
    return model, val_acc, val_pos_acc, val_neg_acc
def main():
    """CLI entry point: load scans, split, preprocess grids, train, save."""
    parser = argparse.ArgumentParser(description='Train ResNet18 model on grid data')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size (increased for better GPU utilization)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--num-workers', type=int, default=None, help='Number of workers for preprocessing (default: auto)')
    # Help text previously claimed 1%; the actual sampling fraction is 0.5%.
    parser.add_argument('--quick', action='store_true', help='Quick mode: use only 0.5%% of scans for faster testing')
    parser.add_argument('--hue-jitter', type=float, default=0.1, help='Hue jitter parameter for ColorJitter (default: 0.1)')
    parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
    parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
    parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    print(f'Using hue jitter: {args.hue_jitter}')
    os.makedirs('models', exist_ok=True)

    print("Loading scan data...")
    scan_data = load_scan_data(args.data_dir)
    if not scan_data:
        print("No scan data found!")
        return

    if args.quick:
        # Subsample for fast iteration: 0.5% of scans, but at least 10 so
        # training remains meaningful. (Old comments/messages wrongly said 1%;
        # the printed percentage now reflects the real fraction.)
        original_count = len(scan_data)
        scan_ids = list(scan_data.keys())
        quick_fraction = 0.005
        quick_count = max(10, int(len(scan_ids) * quick_fraction))
        selected_scan_ids = random.sample(scan_ids, quick_count)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in selected_scan_ids}
        print(f"Quick mode enabled: Using {len(scan_data)} scans out of {original_count} ({quick_fraction:.1%})")

    print("Splitting data into train and validation...")
    train_data, val_data = split_train_val(scan_data)

    # Inference-time transform; also persisted with every checkpoint.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Creating training dataset...")
    train_dataset = GridDataset(train_data, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Training samples: {len(train_dataset)}")
    print("Creating validation dataset...")
    val_dataset = GridDataset(val_data, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Validation samples: {len(val_dataset)}")

    # NOTE: DataLoader worker count is fixed at 8; --num-workers only
    # controls the grid-preprocessing pool above.
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    print("Creating ResNet18 model...")
    # Fine-tunes from a previous checkpoint rather than ImageNet weights;
    # to start fresh, swap back to torchvision.models.resnet18 and replace
    # model.fc with a 2-class linear head.
    model, _ = load_model('models/gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
    model = model.to(device)
    print(f"Model created and moved to {device}")

    print("Setting up loss function and optimizer...")
    # Class weighting inside FocalLoss: 0.99 on the positive class to
    # counter heavy negative-class imbalance.
    pos_weight = 0.99
    criterion = FocalLoss(0.25, weight=torch.Tensor([1.0 - pos_weight, pos_weight]).to(device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',  # stepped with validation accuracy (higher is better)
        factor=args.scheduler_factor,
        patience=args.scheduler_patience,
        min_lr=args.scheduler_min_lr
    )
    print(f"Using Adam optimizer with lr={args.lr}")
    print(f"Using FocalLoss with weights: Negative={1.0 - pos_weight:.2f}, Positive={pos_weight:.2f}")
    print(f"Using ReduceLROnPlateau scheduler with factor={args.scheduler_factor}, patience={args.scheduler_patience}, min_lr={args.scheduler_min_lr}")

    # Metadata persisted with every checkpoint for reproducibility.
    train_scan_ids = list(train_data.keys())
    train_codes = set()
    for scan_id, metadata in train_data.items():
        code = metadata.get('code')
        if code:
            train_codes.add(code)
    train_metadata = {
        'train_scan_ids': train_scan_ids,
        'train_codes': list(train_codes),
        'hue_jitter': args.hue_jitter,
        'quick_mode': args.quick
    }
    print(f"Training metadata:")
    print(f" Train scan IDs: {len(train_scan_ids)}")
    print(f" Train codes: {len(train_codes)}")
    print(f" Hue jitter: {args.hue_jitter}")
    print(f" Quick mode: {args.quick}")

    print("Starting training...")
    model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler,
        device, transform, args, train_metadata, args.epochs)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    mode_suffix = "_quick" if args.quick else ""
    final_model_path = f'models/final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
    save_model(model, transform, final_model_path, train_metadata)
    print(f"Training completed! Final model saved: {final_model_path}")


if __name__ == '__main__':
    main()