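refactor: update train2.py (add --scan-ids range filter, --model option for finetuning, and optional train/val overlap avoidance in split_train_val)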
This commit is contained in:
parent
bb230733c5
commit
c14deba2ca
@ -230,8 +230,14 @@ def load_scan_data(data_dir):
|
|||||||
|
|
||||||
return scan_data
|
return scan_data
|
||||||
|
|
||||||
def split_train_val(scan_data):
|
def split_train_val(scan_data, avoid_overlap=True):
|
||||||
"""Split data into train and validation sets"""
|
"""Split data into train and validation sets
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scan_data: Dictionary of scan_id -> metadata
|
||||||
|
avoid_overlap: If True, move scans with matching (code, labels) from train to val to avoid overlap.
|
||||||
|
If False, use simple random split without overlap avoidance.
|
||||||
|
"""
|
||||||
scan_ids = list(scan_data.keys())
|
scan_ids = list(scan_data.keys())
|
||||||
|
|
||||||
print("Splitting data into train and validation sets...")
|
print("Splitting data into train and validation sets...")
|
||||||
@ -242,32 +248,37 @@ def split_train_val(scan_data):
|
|||||||
|
|
||||||
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
|
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
|
||||||
|
|
||||||
# Get all (code, labels) combinations that appear in the initial validation set
|
if avoid_overlap:
|
||||||
val_code_labels = set()
|
# Get all (code, labels) combinations that appear in the initial validation set
|
||||||
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
|
val_code_labels = set()
|
||||||
if scan_id in scan_data:
|
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
|
||||||
code = scan_data[scan_id].get('code')
|
if scan_id in scan_data:
|
||||||
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
code = scan_data[scan_id].get('code')
|
||||||
if code:
|
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
||||||
val_code_labels.add((code, labels))
|
if code:
|
||||||
|
val_code_labels.add((code, labels))
|
||||||
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
|
|
||||||
|
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
|
||||||
# Find all scans in train that have matching (code, labels) combinations and move them to validation
|
|
||||||
additional_val_scan_ids = set()
|
# Find all scans in train that have matching (code, labels) combinations and move them to validation
|
||||||
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
|
additional_val_scan_ids = set()
|
||||||
|
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
|
||||||
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
|
|
||||||
if scan_id in scan_data:
|
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
|
||||||
code = scan_data[scan_id].get('code')
|
if scan_id in scan_data:
|
||||||
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
code = scan_data[scan_id].get('code')
|
||||||
# If (code, labels) combination matches, move to validation
|
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
||||||
if code and (code, labels) in val_code_labels:
|
# If (code, labels) combination matches, move to validation
|
||||||
additional_val_scan_ids.add(scan_id)
|
if code and (code, labels) in val_code_labels:
|
||||||
|
additional_val_scan_ids.add(scan_id)
|
||||||
# Combine validation sets
|
|
||||||
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
|
# Combine validation sets
|
||||||
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
|
||||||
|
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
||||||
|
else:
|
||||||
|
# Simple split without overlap avoidance
|
||||||
|
all_val_scan_ids = set(initial_val_scan_ids)
|
||||||
|
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
||||||
|
|
||||||
# Create train and validation data dictionaries
|
# Create train and validation data dictionaries
|
||||||
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
|
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
|
||||||
@ -438,6 +449,8 @@ def main():
|
|||||||
parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
|
parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
|
||||||
parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
|
parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
|
||||||
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
|
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
|
||||||
|
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
|
||||||
|
parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Set device
|
# Set device
|
||||||
@ -457,6 +470,26 @@ def main():
|
|||||||
print("No scan data found!")
|
print("No scan data found!")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Filter by scan IDs if provided
|
||||||
|
if args.scan_ids:
|
||||||
|
ranges = parse_ranges(args.scan_ids)
|
||||||
|
filtered_scan_data = {}
|
||||||
|
for scan_id, metadata in scan_data.items():
|
||||||
|
try:
|
||||||
|
scan_id_int = int(scan_id)
|
||||||
|
for val_range in ranges:
|
||||||
|
if in_range(scan_id_int, val_range):
|
||||||
|
filtered_scan_data[scan_id] = metadata
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
# Skip non-numeric scan IDs
|
||||||
|
continue
|
||||||
|
if not filtered_scan_data:
|
||||||
|
print(f"Error: No scan IDs found in range '{args.scan_ids}'")
|
||||||
|
return
|
||||||
|
scan_data = filtered_scan_data
|
||||||
|
print(f"Filtered to {len(scan_data)} scans matching range '{args.scan_ids}'")
|
||||||
|
|
||||||
# Quick mode: use only 1% of scans for faster testing
|
# Quick mode: use only 1% of scans for faster testing
|
||||||
if args.quick:
|
if args.quick:
|
||||||
original_count = len(scan_data)
|
original_count = len(scan_data)
|
||||||
@ -469,7 +502,9 @@ def main():
|
|||||||
|
|
||||||
# Split into train and validation
|
# Split into train and validation
|
||||||
print("Splitting data into train and validation...")
|
print("Splitting data into train and validation...")
|
||||||
train_data, val_data = split_train_val(scan_data)
|
# In finetune mode (when model is specified), don't avoid overlap between train and val
|
||||||
|
avoid_overlap = args.model is None
|
||||||
|
train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
|
||||||
|
|
||||||
# Define transforms
|
# Define transforms
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
@ -487,6 +522,17 @@ def main():
|
|||||||
val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
|
val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
|
||||||
print(f"Validation samples: {len(val_dataset)}")
|
print(f"Validation samples: {len(val_dataset)}")
|
||||||
|
|
||||||
|
# Validate datasets have samples
|
||||||
|
if len(train_dataset) == 0:
|
||||||
|
print("Error: Training dataset is empty. Cannot proceed with training.")
|
||||||
|
print("This usually means all training scans are missing sbs.jpg files.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(val_dataset) == 0:
|
||||||
|
print("Error: Validation dataset is empty. Cannot proceed with training.")
|
||||||
|
print("This usually means all validation scans are missing sbs.jpg files.")
|
||||||
|
return
|
||||||
|
|
||||||
# Create data loaders
|
# Create data loaders
|
||||||
print("Creating data loaders...")
|
print("Creating data loaders...")
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
@ -511,16 +557,27 @@ def main():
|
|||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
print("Creating ResNet18 model...")
|
print("Creating ResNet18 model...")
|
||||||
# model = torchvision.models.resnet18(pretrained=True)
|
|
||||||
models_dir = os.path.expanduser('~/emblem/models')
|
models_dir = os.path.expanduser('~/emblem/models')
|
||||||
model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt'))
|
if args.model:
|
||||||
|
# Load specified model for finetuning
|
||||||
|
model_path = os.path.expanduser(args.model)
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"Error: Model file not found: {model_path}")
|
||||||
|
return
|
||||||
|
print(f"Loading model from {model_path}...")
|
||||||
|
model, _ = load_model(model_path)
|
||||||
|
else:
|
||||||
|
# Load default model
|
||||||
|
default_model_path = os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
|
||||||
|
print(f"Loading default model from {default_model_path}...")
|
||||||
|
model, _ = load_model(default_model_path)
|
||||||
# Unwrap compiled model if it exists (torch.compile can cause issues when loading)
|
# Unwrap compiled model if it exists (torch.compile can cause issues when loading)
|
||||||
if hasattr(model, '_orig_mod'):
|
if hasattr(model, '_orig_mod'):
|
||||||
model = model._orig_mod
|
model = model._orig_mod
|
||||||
# Modify final layer for binary classification
|
# Modify final layer for binary classification
|
||||||
# model.fc = nn.Linear(model.fc.in_features, 2)
|
# model.fc = nn.Linear(model.fc.in_features, 2)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
print(f"Model created and moved to {device}")
|
print(f"Model loaded and moved to {device}")
|
||||||
|
|
||||||
# Define loss function and optimizer
|
# Define loss function and optimizer
|
||||||
print("Setting up loss function and optimizer...")
|
print("Setting up loss function and optimizer...")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user