From c14deba2ca0382530c51efc9cd1ec0e3070ff731 Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Sat, 27 Dec 2025 10:33:50 +0000 Subject: [PATCH] refactor: update train2.py --- emblem5/ai/train2.py | 121 +++++++++++++++++++++++++++++++------------ 1 file changed, 89 insertions(+), 32 deletions(-) diff --git a/emblem5/ai/train2.py b/emblem5/ai/train2.py index c3e3eba..a69dcbe 100755 --- a/emblem5/ai/train2.py +++ b/emblem5/ai/train2.py @@ -230,8 +230,14 @@ def load_scan_data(data_dir): return scan_data -def split_train_val(scan_data): - """Split data into train and validation sets""" +def split_train_val(scan_data, avoid_overlap=True): + """Split data into train and validation sets + + Args: + scan_data: Dictionary of scan_id -> metadata + avoid_overlap: If True, move scans with matching (code, labels) from train to val to avoid overlap. + If False, use simple random split without overlap avoidance. + """ scan_ids = list(scan_data.keys()) print("Splitting data into train and validation sets...") @@ -242,32 +248,37 @@ def split_train_val(scan_data): print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation") - # Get all (code, labels) combinations that appear in the initial validation set - val_code_labels = set() - for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"): - if scan_id in scan_data: - code = scan_data[scan_id].get('code') - labels = tuple(sorted(scan_data[scan_id].get('labels', []))) - if code: - val_code_labels.add((code, labels)) - - print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set") - - # Find all scans in train that have matching (code, labels) combinations and move them to validation - additional_val_scan_ids = set() - train_scan_ids = set(scan_ids) - set(initial_val_scan_ids) - - for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"): - if scan_id in scan_data: - code = scan_data[scan_id].get('code') - labels = tuple(sorted(scan_data[scan_id].get('labels', []))) - # If (code, labels) combination matches, move to validation - if code and (code, labels) in val_code_labels: - additional_val_scan_ids.add(scan_id) - - # Combine validation sets - all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids - all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids + if avoid_overlap: + # Get all (code, labels) combinations that appear in the initial validation set + val_code_labels = set() + for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"): + if scan_id in scan_data: + code = scan_data[scan_id].get('code') + labels = tuple(sorted(scan_data[scan_id].get('labels', []))) + if code: + val_code_labels.add((code, labels)) + + print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set") + + # Find all scans in train that have matching (code, labels) combinations and move them to validation + additional_val_scan_ids = set() + train_scan_ids = set(scan_ids) - set(initial_val_scan_ids) + + for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"): + if scan_id in scan_data: + code = scan_data[scan_id].get('code') + labels = tuple(sorted(scan_data[scan_id].get('labels', []))) + # If (code, labels) combination matches, move to validation + if code and (code, labels) in val_code_labels: + additional_val_scan_ids.add(scan_id) + + # Combine validation sets + all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids + all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids + else: + # Simple split without overlap avoidance + all_val_scan_ids = set(initial_val_scan_ids) + all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids # Create train and validation data dictionaries train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids} @@ -438,6 +449,8 @@ def main(): parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)') parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)') parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)') + parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")') + parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning') args = parser.parse_args() # Set device @@ -457,6 +470,26 @@ def main(): print("No scan data found!") return + # Filter by scan IDs if provided + if args.scan_ids: + ranges = parse_ranges(args.scan_ids) + filtered_scan_data = {} + for scan_id, metadata in scan_data.items(): + try: + scan_id_int = int(scan_id) + for val_range in ranges: + if in_range(scan_id_int, val_range): + filtered_scan_data[scan_id] = metadata + break + except ValueError: + # Skip non-numeric scan IDs + continue + if not filtered_scan_data: + print(f"Error: No scan IDs found in range '{args.scan_ids}'") + return + scan_data = filtered_scan_data + print(f"Filtered to {len(scan_data)} scans matching range '{args.scan_ids}'") + # Quick mode: use only 1% of scans for faster testing if args.quick: original_count = len(scan_data) @@ -469,7 +502,9 @@ def main(): # Split into train and validation print("Splitting data into train and validation...") - train_data, val_data = split_train_val(scan_data) + # In finetune mode (when model is specified), don't avoid overlap between train and val + avoid_overlap = args.model is None + train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap) # Define transforms transform = transforms.Compose([ @@ -487,6 +522,17 @@ def main(): val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter) print(f"Validation samples: {len(val_dataset)}") + # Validate datasets have samples + if len(train_dataset) == 0: + print("Error: Training dataset is empty. Cannot proceed with training.") + print("This usually means all training scans are missing sbs.jpg files.") + return + + if len(val_dataset) == 0: + print("Error: Validation dataset is empty. Cannot proceed with training.") + print("This usually means all validation scans are missing sbs.jpg files.") + return + # Create data loaders print("Creating data loaders...") train_loader = DataLoader( @@ -511,16 +557,27 @@ def main(): # Create model print("Creating ResNet18 model...") - # model = torchvision.models.resnet18(pretrained=True) models_dir = os.path.expanduser('~/emblem/models') - model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')) + if args.model: + # Load specified model for finetuning + model_path = os.path.expanduser(args.model) + if not os.path.exists(model_path): + print(f"Error: Model file not found: {model_path}") + return + print(f"Loading model from {model_path}...") + model, _ = load_model(model_path) + else: + # Load default model + default_model_path = os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt') + print(f"Loading default model from {default_model_path}...") + model, _ = load_model(default_model_path) # Unwrap compiled model if it exists (torch.compile can cause issues when loading) if hasattr(model, '_orig_mod'): model = model._orig_mod # Modify final layer for binary classification # model.fc = nn.Linear(model.fc.in_features, 2) model = model.to(device) - print(f"Model created and moved to {device}") + print(f"Model loaded and moved to {device}") # Define loss function and optimizer print("Setting up loss function and optimizer...")