Revert to code-label combination based train/val split
This commit is contained in:
parent
aa2d4041e3
commit
73c35a4968
@ -233,22 +233,45 @@ def load_scan_data(data_dir):
|
|||||||
def split_train_val(scan_data):
|
def split_train_val(scan_data):
|
||||||
"""Split data into train and validation sets"""
|
"""Split data into train and validation sets"""
|
||||||
scan_ids = list(scan_data.keys())
|
scan_ids = list(scan_data.keys())
|
||||||
random.shuffle(scan_ids)
|
|
||||||
|
|
||||||
print("Splitting data into train and validation sets...")
|
print("Splitting data into train and validation sets...")
|
||||||
|
|
||||||
# Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation
|
# Select 10% random scan_ids for initial validation set
|
||||||
train_size = int(len(scan_ids) * 0.2)
|
val_size = int(len(scan_ids) * 0.1)
|
||||||
val_size = int(len(scan_ids) * 0.4)
|
initial_val_scan_ids = random.sample(scan_ids, val_size)
|
||||||
|
|
||||||
train_scan_ids = scan_ids[:train_size]
|
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
|
||||||
val_scan_ids = scan_ids[train_size:train_size + val_size]
|
|
||||||
|
|
||||||
print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)")
|
# Get all (code, labels) combinations that appear in the initial validation set
|
||||||
|
val_code_labels = set()
|
||||||
|
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
|
||||||
|
if scan_id in scan_data:
|
||||||
|
code = scan_data[scan_id].get('code')
|
||||||
|
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
||||||
|
if code:
|
||||||
|
val_code_labels.add((code, labels))
|
||||||
|
|
||||||
|
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
|
||||||
|
|
||||||
|
# Find all scans in train that have matching (code, labels) combinations and move them to validation
|
||||||
|
additional_val_scan_ids = set()
|
||||||
|
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
|
||||||
|
|
||||||
|
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
|
||||||
|
if scan_id in scan_data:
|
||||||
|
code = scan_data[scan_id].get('code')
|
||||||
|
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
||||||
|
# If (code, labels) combination matches, move to validation
|
||||||
|
if code and (code, labels) in val_code_labels:
|
||||||
|
additional_val_scan_ids.add(scan_id)
|
||||||
|
|
||||||
|
# Combine validation sets
|
||||||
|
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
|
||||||
|
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
||||||
|
|
||||||
# Create train and validation data dictionaries
|
# Create train and validation data dictionaries
|
||||||
train_data = {scan_id: scan_data[scan_id] for scan_id in train_scan_ids}
|
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
|
||||||
val_data = {scan_id: scan_data[scan_id] for scan_id in val_scan_ids}
|
val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids}
|
||||||
|
|
||||||
print(f"Final split:")
|
print(f"Final split:")
|
||||||
print(f" Total scans: {len(scan_data)}")
|
print(f" Total scans: {len(scan_data)}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user