refactor: update train2.py
This commit is contained in:
parent
bb230733c5
commit
c14deba2ca
@ -230,8 +230,14 @@ def load_scan_data(data_dir):
|
||||
|
||||
return scan_data
|
||||
|
||||
def split_train_val(scan_data):
|
||||
"""Split data into train and validation sets"""
|
||||
def split_train_val(scan_data, avoid_overlap=True):
|
||||
"""Split data into train and validation sets
|
||||
|
||||
Args:
|
||||
scan_data: Dictionary of scan_id -> metadata
|
||||
avoid_overlap: If True, move scans with matching (code, labels) from train to val to avoid overlap.
|
||||
If False, use simple random split without overlap avoidance.
|
||||
"""
|
||||
scan_ids = list(scan_data.keys())
|
||||
|
||||
print("Splitting data into train and validation sets...")
|
||||
@ -242,6 +248,7 @@ def split_train_val(scan_data):
|
||||
|
||||
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
|
||||
|
||||
if avoid_overlap:
|
||||
# Get all (code, labels) combinations that appear in the initial validation set
|
||||
val_code_labels = set()
|
||||
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
|
||||
@ -268,6 +275,10 @@ def split_train_val(scan_data):
|
||||
# Combine validation sets
|
||||
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
|
||||
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
||||
else:
|
||||
# Simple split without overlap avoidance
|
||||
all_val_scan_ids = set(initial_val_scan_ids)
|
||||
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
||||
|
||||
# Create train and validation data dictionaries
|
||||
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
|
||||
@ -438,6 +449,8 @@ def main():
|
||||
parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
|
||||
parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
|
||||
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
|
||||
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
|
||||
parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set device
|
||||
@ -457,6 +470,26 @@ def main():
|
||||
print("No scan data found!")
|
||||
return
|
||||
|
||||
# Filter by scan IDs if provided
|
||||
if args.scan_ids:
|
||||
ranges = parse_ranges(args.scan_ids)
|
||||
filtered_scan_data = {}
|
||||
for scan_id, metadata in scan_data.items():
|
||||
try:
|
||||
scan_id_int = int(scan_id)
|
||||
for val_range in ranges:
|
||||
if in_range(scan_id_int, val_range):
|
||||
filtered_scan_data[scan_id] = metadata
|
||||
break
|
||||
except ValueError:
|
||||
# Skip non-numeric scan IDs
|
||||
continue
|
||||
if not filtered_scan_data:
|
||||
print(f"Error: No scan IDs found in range '{args.scan_ids}'")
|
||||
return
|
||||
scan_data = filtered_scan_data
|
||||
print(f"Filtered to {len(scan_data)} scans matching range '{args.scan_ids}'")
|
||||
|
||||
# Quick mode: use only 1% of scans for faster testing
|
||||
if args.quick:
|
||||
original_count = len(scan_data)
|
||||
@ -469,7 +502,9 @@ def main():
|
||||
|
||||
# Split into train and validation
|
||||
print("Splitting data into train and validation...")
|
||||
train_data, val_data = split_train_val(scan_data)
|
||||
# In finetune mode (when model is specified), don't avoid overlap between train and val
|
||||
avoid_overlap = args.model is None
|
||||
train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
|
||||
|
||||
# Define transforms
|
||||
transform = transforms.Compose([
|
||||
@ -487,6 +522,17 @@ def main():
|
||||
val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
|
||||
print(f"Validation samples: {len(val_dataset)}")
|
||||
|
||||
# Validate datasets have samples
|
||||
if len(train_dataset) == 0:
|
||||
print("Error: Training dataset is empty. Cannot proceed with training.")
|
||||
print("This usually means all training scans are missing sbs.jpg files.")
|
||||
return
|
||||
|
||||
if len(val_dataset) == 0:
|
||||
print("Error: Validation dataset is empty. Cannot proceed with training.")
|
||||
print("This usually means all validation scans are missing sbs.jpg files.")
|
||||
return
|
||||
|
||||
# Create data loaders
|
||||
print("Creating data loaders...")
|
||||
train_loader = DataLoader(
|
||||
@ -511,16 +557,27 @@ def main():
|
||||
|
||||
# Create model
|
||||
print("Creating ResNet18 model...")
|
||||
# model = torchvision.models.resnet18(pretrained=True)
|
||||
models_dir = os.path.expanduser('~/emblem/models')
|
||||
model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt'))
|
||||
if args.model:
|
||||
# Load specified model for finetuning
|
||||
model_path = os.path.expanduser(args.model)
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model file not found: {model_path}")
|
||||
return
|
||||
print(f"Loading model from {model_path}...")
|
||||
model, _ = load_model(model_path)
|
||||
else:
|
||||
# Load default model
|
||||
default_model_path = os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
|
||||
print(f"Loading default model from {default_model_path}...")
|
||||
model, _ = load_model(default_model_path)
|
||||
# Unwrap compiled model if it exists (torch.compile can cause issues when loading)
|
||||
if hasattr(model, '_orig_mod'):
|
||||
model = model._orig_mod
|
||||
# Modify final layer for binary classification
|
||||
# model.fc = nn.Linear(model.fc.in_features, 2)
|
||||
model = model.to(device)
|
||||
print(f"Model created and moved to {device}")
|
||||
print(f"Model loaded and moved to {device}")
|
||||
|
||||
# Define loss function and optimizer
|
||||
print("Setting up loss function and optimizer...")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user