refactor: update train2.py

This commit is contained in:
Fam Zheng 2025-12-27 10:33:50 +00:00
parent bb230733c5
commit c14deba2ca

View File

@ -230,8 +230,14 @@ def load_scan_data(data_dir):
return scan_data return scan_data
def split_train_val(scan_data): def split_train_val(scan_data, avoid_overlap=True):
"""Split data into train and validation sets""" """Split data into train and validation sets
Args:
scan_data: Dictionary of scan_id -> metadata
avoid_overlap: If True, move scans with matching (code, labels) from train to val to avoid overlap.
If False, use simple random split without overlap avoidance.
"""
scan_ids = list(scan_data.keys()) scan_ids = list(scan_data.keys())
print("Splitting data into train and validation sets...") print("Splitting data into train and validation sets...")
@ -242,32 +248,37 @@ def split_train_val(scan_data):
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation") print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
# Get all (code, labels) combinations that appear in the initial validation set if avoid_overlap:
val_code_labels = set() # Get all (code, labels) combinations that appear in the initial validation set
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"): val_code_labels = set()
if scan_id in scan_data: for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
code = scan_data[scan_id].get('code') if scan_id in scan_data:
labels = tuple(sorted(scan_data[scan_id].get('labels', []))) code = scan_data[scan_id].get('code')
if code: labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
val_code_labels.add((code, labels)) if code:
val_code_labels.add((code, labels))
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
# Find all scans in train that have matching (code, labels) combinations and move them to validation
additional_val_scan_ids = set() # Find all scans in train that have matching (code, labels) combinations and move them to validation
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids) additional_val_scan_ids = set()
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
if scan_id in scan_data: for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
code = scan_data[scan_id].get('code') if scan_id in scan_data:
labels = tuple(sorted(scan_data[scan_id].get('labels', []))) code = scan_data[scan_id].get('code')
# If (code, labels) combination matches, move to validation labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
if code and (code, labels) in val_code_labels: # If (code, labels) combination matches, move to validation
additional_val_scan_ids.add(scan_id) if code and (code, labels) in val_code_labels:
additional_val_scan_ids.add(scan_id)
# Combine validation sets
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids # Combine validation sets
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
else:
# Simple split without overlap avoidance
all_val_scan_ids = set(initial_val_scan_ids)
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
# Create train and validation data dictionaries # Create train and validation data dictionaries
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids} train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
@ -438,6 +449,8 @@ def main():
parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)') parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)') parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)') parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
args = parser.parse_args() args = parser.parse_args()
# Set device # Set device
@ -457,6 +470,26 @@ def main():
print("No scan data found!") print("No scan data found!")
return return
# Filter by scan IDs if provided
if args.scan_ids:
ranges = parse_ranges(args.scan_ids)
filtered_scan_data = {}
for scan_id, metadata in scan_data.items():
try:
scan_id_int = int(scan_id)
for val_range in ranges:
if in_range(scan_id_int, val_range):
filtered_scan_data[scan_id] = metadata
break
except ValueError:
# Skip non-numeric scan IDs
continue
if not filtered_scan_data:
print(f"Error: No scan IDs found in range '{args.scan_ids}'")
return
scan_data = filtered_scan_data
print(f"Filtered to {len(scan_data)} scans matching range '{args.scan_ids}'")
# Quick mode: use only 1% of scans for faster testing # Quick mode: use only 1% of scans for faster testing
if args.quick: if args.quick:
original_count = len(scan_data) original_count = len(scan_data)
@ -469,7 +502,9 @@ def main():
# Split into train and validation # Split into train and validation
print("Splitting data into train and validation...") print("Splitting data into train and validation...")
train_data, val_data = split_train_val(scan_data) # In finetune mode (when model is specified), don't avoid overlap between train and val
avoid_overlap = args.model is None
train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
# Define transforms # Define transforms
transform = transforms.Compose([ transform = transforms.Compose([
@ -487,6 +522,17 @@ def main():
val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter) val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
print(f"Validation samples: {len(val_dataset)}") print(f"Validation samples: {len(val_dataset)}")
# Validate datasets have samples
if len(train_dataset) == 0:
print("Error: Training dataset is empty. Cannot proceed with training.")
print("This usually means all training scans are missing sbs.jpg files.")
return
if len(val_dataset) == 0:
print("Error: Validation dataset is empty. Cannot proceed with training.")
print("This usually means all validation scans are missing sbs.jpg files.")
return
# Create data loaders # Create data loaders
print("Creating data loaders...") print("Creating data loaders...")
train_loader = DataLoader( train_loader = DataLoader(
@ -511,16 +557,27 @@ def main():
# Create model # Create model
print("Creating ResNet18 model...") print("Creating ResNet18 model...")
# model = torchvision.models.resnet18(pretrained=True)
models_dir = os.path.expanduser('~/emblem/models') models_dir = os.path.expanduser('~/emblem/models')
model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')) if args.model:
# Load specified model for finetuning
model_path = os.path.expanduser(args.model)
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
return
print(f"Loading model from {model_path}...")
model, _ = load_model(model_path)
else:
# Load default model
default_model_path = os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
print(f"Loading default model from {default_model_path}...")
model, _ = load_model(default_model_path)
# Unwrap compiled model if it exists (torch.compile can cause issues when loading) # Unwrap compiled model if it exists (torch.compile can cause issues when loading)
if hasattr(model, '_orig_mod'): if hasattr(model, '_orig_mod'):
model = model._orig_mod model = model._orig_mod
# Modify final layer for binary classification # Modify final layer for binary classification
# model.fc = nn.Linear(model.fc.in_features, 2) # model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device) model = model.to(device)
print(f"Model created and moved to {device}") print(f"Model loaded and moved to {device}")
# Define loss function and optimizer # Define loss function and optimizer
print("Setting up loss function and optimizer...") print("Setting up loss function and optimizer...")