Revert to code-label combination based train/val split

This commit is contained in:
Fam Zheng 2025-12-27 08:52:58 +00:00
parent aa2d4041e3
commit 73c35a4968

View File

@ -233,22 +233,45 @@ def load_scan_data(data_dir):
def split_train_val(scan_data): def split_train_val(scan_data):
"""Split data into train and validation sets""" """Split data into train and validation sets"""
scan_ids = list(scan_data.keys()) scan_ids = list(scan_data.keys())
random.shuffle(scan_ids)
print("Splitting data into train and validation sets...") print("Splitting data into train and validation sets...")
# Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation # Select 10% random scan_ids for initial validation set
train_size = int(len(scan_ids) * 0.2) val_size = int(len(scan_ids) * 0.1)
val_size = int(len(scan_ids) * 0.4) initial_val_scan_ids = random.sample(scan_ids, val_size)
train_scan_ids = scan_ids[:train_size] print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
val_scan_ids = scan_ids[train_size:train_size + val_size]
print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)") # Get all (code, labels) combinations that appear in the initial validation set
val_code_labels = set()
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
if scan_id in scan_data:
code = scan_data[scan_id].get('code')
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
if code:
val_code_labels.add((code, labels))
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
# Find all scans in train that have matching (code, labels) combinations and move them to validation
additional_val_scan_ids = set()
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
if scan_id in scan_data:
code = scan_data[scan_id].get('code')
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
# If (code, labels) combination matches, move to validation
if code and (code, labels) in val_code_labels:
additional_val_scan_ids.add(scan_id)
# Combine validation sets
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
# Create train and validation data dictionaries # Create train and validation data dictionaries
train_data = {scan_id: scan_data[scan_id] for scan_id in train_scan_ids} train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
val_data = {scan_id: scan_data[scan_id] for scan_id in val_scan_ids} val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids}
print(f"Final split:") print(f"Final split:")
print(f" Total scans: {len(scan_data)}") print(f" Total scans: {len(scan_data)}")