diff --git a/emblem5/ai/train2.py b/emblem5/ai/train2.py index 227d096..c3e3eba 100755 --- a/emblem5/ai/train2.py +++ b/emblem5/ai/train2.py @@ -233,22 +233,45 @@ def load_scan_data(data_dir): def split_train_val(scan_data): """Split data into train and validation sets""" scan_ids = list(scan_data.keys()) - random.shuffle(scan_ids) print("Splitting data into train and validation sets...") - # Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation - train_size = int(len(scan_ids) * 0.2) - val_size = int(len(scan_ids) * 0.4) + # Select 10% random scan_ids for initial validation set + val_size = int(len(scan_ids) * 0.1) + initial_val_scan_ids = random.sample(scan_ids, val_size) - train_scan_ids = scan_ids[:train_size] - val_scan_ids = scan_ids[train_size:train_size + val_size] + print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation") - print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)") + # Get all (code, labels) combinations that appear in the initial validation set + val_code_labels = set() + for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"): + if scan_id in scan_data: + code = scan_data[scan_id].get('code') + labels = tuple(sorted(scan_data[scan_id].get('labels', []))) + if code: + val_code_labels.add((code, labels)) + + print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set") + + # Find all scans in train that have matching (code, labels) combinations and move them to validation + additional_val_scan_ids = set() + train_scan_ids = set(scan_ids) - set(initial_val_scan_ids) + + for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"): + if scan_id in scan_data: + code = scan_data[scan_id].get('code') + labels = tuple(sorted(scan_data[scan_id].get('labels', []))) + # If (code, labels) combination matches, move to validation + if code and (code, labels) in val_code_labels: + additional_val_scan_ids.add(scan_id) + + # Combine validation sets + all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids + all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids # Create train and validation data dictionaries - train_data = {scan_id: scan_data[scan_id] for scan_id in train_scan_ids} - val_data = {scan_id: scan_data[scan_id] for scan_id in val_scan_ids} + train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids} + val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids} print(f"Final split:") print(f" Total scans: {len(scan_data)}")