556 lines
22 KiB
Python
Executable File
556 lines
22 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
'''
|
|
all children of data/scans are scan_ids
|
|
in each scan_id, there is a file called "metadata.json"
|
|
labels key has 'pos' or 'neg'
|
|
it also has 'code' key which is a string
|
|
|
|
The largest 20% of scan_ids are used as the validation set.

Any code that appears in the validation set (across all scans) is also
assigned to the validation set.

The rest are the training set.
|
|
|
|
preprocess the sbs.jpg for both train and validation:
|
|
1. split left and right half
|
|
2. crop 3x3 of left
|
|
3. crop 3x3 of right
|
|
4. for each pair of crops, concatenate into grid-i-j.jpg, and apply some colorjitter
|
|
5. all the grid-i-j.jpg are used with the label
|
|
|
|
load a resnet18 model, and train it on the training set
|
|
train for 10 epochs, and print accuracy on validation set each epoch
|
|
|
|
save the model in the end.
|
|
'''
|
|
|
|
import os
|
|
import json
|
|
import random
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
from PIL import Image
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import argparse
|
|
from collections import defaultdict
|
|
import multiprocessing as mp
|
|
from functools import partial
|
|
from datetime import datetime
|
|
from common import *
|
|
|
|
def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
    """Create the 3x3 grid crop files for a single scan.

    Splits the scan's side-by-side image (sbs.jpg) into left and right
    halves, crops a 3x3 grid from each half, color-jitters the left crop,
    and saves each concatenated pair as grids/grid-i-j.jpg. If every grid
    file already exists, image processing is skipped and only metadata is
    returned.

    Args:
        scan_item: (scan_id, metadata) tuple; metadata must contain a
            'labels' list where membership of 'pos' marks the positive class.
        data_dir: root data directory containing the 'scans' subdirectory.
        hue_jitter: hue parameter for the ColorJitter applied to left crops.

    Returns:
        (sample_metadata, reason): reason is None on success, or a string
        describing the failure (sample_metadata may then be empty or
        partial).
    """
    scan_id, metadata = scan_item
    sample_metadata = []

    sbs_path = os.path.join(data_dir, 'scans', scan_id, 'sbs.jpg')
    if not os.path.exists(sbs_path):
        return sample_metadata, 'sbs.jpg missing'

    # Create grid directory if it doesn't exist
    grid_dir = os.path.join(data_dir, 'scans', scan_id, 'grids')
    os.makedirs(grid_dir, exist_ok=True)

    def _record(i, j, grid_path):
        # One metadata entry per saved grid image. Label is computed here so
        # a missing/odd 'labels' key raises at the same points the original
        # code would (inside the creation try-block it is caught; in the
        # already-exists branch it propagates).
        label = 1 if 'pos' in metadata['labels'] else 0
        return {
            'scan_id': scan_id,
            'grid_path': grid_path,
            'grid_i': i,
            'grid_j': j,
            'label': label,
        }

    # Expected output path for every (i, j) cell, in row-major (i, j) order.
    grid_paths = {(i, j): os.path.join(grid_dir, f'grid-{i}-{j}.jpg')
                  for i in range(3) for j in range(3)}

    # If all grid files already exist, just return metadata without
    # touching sbs.jpg again.
    if all(os.path.exists(p) for p in grid_paths.values()):
        for (i, j), grid_path in sorted(grid_paths.items()):
            sample_metadata.append(_record(i, j, grid_path))
        return sample_metadata, None

    # Load the side-by-side image
    try:
        sbs_img = Image.open(sbs_path).convert('RGB')
    except Exception as e:
        return sample_metadata, f'sbs.jpg unreadable: {str(e)}'
    width, height = sbs_img.size

    # Check if image is too small
    if width < 6 or height < 3:
        return sample_metadata, f'sbs.jpg too small: {width}x{height}'

    # Calculate crop dimensions
    crop_width = width // 6  # width/2 / 3
    crop_height = height // 3

    if crop_width <= 0 or crop_height <= 0:
        return sample_metadata, f'invalid crop dimensions: {crop_width}x{crop_height}'

    # Hoisted out of the loop: the transform object is loop-invariant (the
    # original rebuilt it for every cell). ColorJitter samples fresh random
    # jitter parameters on each __call__, so augmentation is unchanged.
    color_jitter = transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=hue_jitter
    )

    # Generate all 3x3 grid combinations
    try:
        for i in range(3):
            for j in range(3):
                # Calculate crop positions directly from original image
                left_x = i * crop_width
                right_x = (i + 3) * crop_width  # Skip middle section
                y = j * crop_height

                # Crop directly from original image
                left_crop = sbs_img.crop((left_x, y, left_x + crop_width, y + crop_height))
                right_crop = sbs_img.crop((right_x, y, right_x + crop_width, y + crop_height))

                # Apply color jitter only to left crop
                left_crop = color_jitter(left_crop)

                # Concatenate left and right crops horizontally
                grid_img = Image.new('RGB', (crop_width * 2, crop_height))
                grid_img.paste(left_crop, (0, 0))
                grid_img.paste(right_crop, (crop_width, 0))

                # Save grid image and store its metadata
                grid_path = grid_paths[(i, j)]
                grid_img.save(grid_path, 'JPEG', quality=95)
                sample_metadata.append(_record(i, j, grid_path))
    except Exception as e:
        return sample_metadata, f'error creating grids: {str(e)}'

    return sample_metadata, None
|
|
|
|
class GridDataset(Dataset):
    """PyTorch dataset over pre-generated grid-i-j.jpg crop pairs.

    Construction eagerly (re)creates the grid files for every scan in
    ``scan_data`` using a multiprocessing pool (see ``process_scan_grid``)
    and stores one metadata entry per grid image; ``__getitem__`` then only
    loads the already-saved JPEG, keeping per-sample work cheap.
    """

    def __init__(self, scan_data, data_dir='data', transform=None, num_workers=None, hue_jitter=0.1):
        # scan_data: {scan_id: metadata dict}; metadata must carry 'labels'.
        self.scan_data = scan_data
        self.data_dir = data_dir
        # Optional transform applied per sample in __getitem__.
        self.transform = transform
        # Flat list of per-grid sample dicts, filled by the pool results below.
        self.sample_metadata = []
        self.hue_jitter = hue_jitter

        if num_workers is None:
            num_workers = min(mp.cpu_count(), 8)  # Limit to 8 workers max

        print(f"Preprocessing {len(scan_data)} scans to create grid files using {num_workers} workers...")

        # Use multiprocessing to create grid files. imap preserves the input
        # order, so results align with scan_data.items().
        with mp.Pool(processes=num_workers) as pool:
            # Process all scans in parallel with hue_jitter parameter
            process_func = partial(process_scan_grid, data_dir=self.data_dir, hue_jitter=self.hue_jitter)
            results = list(tqdm(
                pool.imap(process_func, scan_data.items()),
                total=len(scan_data),
                desc="Creating grid files"
            ))

        # Collect all sample metadata and statistics on failures.
        stats = defaultdict(int)
        for result in results:
            if isinstance(result, tuple) and len(result) == 2:
                metadata_list, reason = result
                self.sample_metadata.extend(metadata_list)
                if reason is not None:
                    stats[reason] += 1
            else:
                # Backward compatibility - old format without reason
                self.sample_metadata.extend(result)

        print(f"Created {len(self.sample_metadata)} grid files")
        if stats:
            print("\nStatistics on why grid files were not created:")
            # Most frequent failure reasons first.
            for reason, count in sorted(stats.items(), key=lambda x: -x[1]):
                print(f"  {reason}: {count} scans")

    def __len__(self):
        # One item per saved grid image (9 per successfully processed scan).
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        # Get sample metadata
        metadata = self.sample_metadata[idx]

        # Load the pre-saved grid image directly
        grid_img = Image.open(metadata['grid_path']).convert('RGB')

        # Apply transforms
        if self.transform:
            grid_img = self.transform(grid_img)

        return grid_img, torch.tensor(metadata['label'], dtype=torch.long)
|
|
|
|
def load_scan_data(data_dir):
    """Read every scan's metadata.json and return {scan_id: metadata}."""
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")

    # Every directory under data/scans is treated as a scan id.
    scan_ids = []
    for entry in os.listdir(scans_dir):
        if os.path.isdir(os.path.join(scans_dir, entry)):
            scan_ids.append(entry)

    print(f"Loading metadata for {len(scan_ids)} scans...")
    loaded = {}
    for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            # Scans without metadata.json are silently skipped.
            continue
        try:
            with open(metadata_path, 'r') as f:
                loaded[scan_id] = json.load(f)
        except Exception as e:
            # Unreadable/corrupt metadata is reported but does not abort the load.
            print(f"Error loading metadata for {scan_id}: {e}")

    return loaded
|
|
|
|
def split_train_val(scan_data):
    """Randomly partition scans: 20% train, next 40% validation.

    NOTE(review): this is a pure random split (and the remaining 40% of
    scans are used by neither set) — it does not implement the size-based
    split described in the module docstring; confirm which is intended.
    """
    scan_ids = list(scan_data.keys())
    random.shuffle(scan_ids)

    print("Splitting data into train and validation sets...")

    # Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation
    n_train = int(len(scan_ids) * 0.2)
    n_val = int(len(scan_ids) * 0.4)

    train_scan_ids = scan_ids[:n_train]
    val_scan_ids = scan_ids[n_train:n_train + n_val]

    print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)")

    # Materialize the two partitions as id -> metadata dicts.
    train_data = {sid: scan_data[sid] for sid in train_scan_ids}
    val_data = {sid: scan_data[sid] for sid in val_scan_ids}

    print(f"Final split:")
    print(f"  Total scans: {len(scan_data)}")
    print(f"  Training scans: {len(train_data)}")
    print(f"  Validation scans: {len(val_data)}")

    return train_data, val_data
|
|
|
|
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata=None, num_epochs=10):
    """Train the model for specified number of epochs.

    Runs mixed-precision training, evaluates on the validation loader each
    epoch (overall and per-class accuracy), steps the plateau scheduler on
    validation accuracy, and checkpoints whenever validation accuracy
    improves.

    Args:
        model: the network to train (already on ``device``).
        train_loader / val_loader: DataLoaders yielding (image, label) batches.
        criterion: per-sample loss; its output is ``.mean()``-reduced here.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: ReduceLROnPlateau stepped with validation accuracy (mode='max').
        device: torch device for data/model.
        transform: preprocessing transform, persisted alongside checkpoints.
        args: parsed CLI args (``args.quick`` selects the checkpoint suffix).
        train_metadata: optional dict stored with saved checkpoints.
        num_epochs: number of training epochs.

    Returns:
        (model, val_acc, val_pos_acc, val_neg_acc) from the final epoch.
    """
    best_val_acc = 0.0

    # Enable mixed precision training.
    # NOTE(review): GradScaler/autocast are hard-coded to 'cuda'; this path
    # presumably assumes a GPU even though ``device`` may be CPU — confirm.
    scaler = torch.amp.GradScaler('cuda')

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_pos_correct = 0
        train_pos_total = 0
        train_neg_correct = 0
        train_neg_total = 0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch_idx, (data, target) in enumerate(train_pbar):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            # Use mixed precision training (forward + loss under autocast).
            with torch.amp.autocast('cuda'):
                output = model(data)
                loss = criterion(output, target).mean()

            # Scale loss and backpropagate (standard GradScaler sequence).
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.detach().item()
            pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(target.view_as(pred)).sum().item()
            train_total += target.size(0)

            # Track per-class accuracy (label 1 = positive, else negative).
            for i in range(target.size(0)):
                if target[i] == 1:  # Positive class
                    train_pos_total += 1
                    if pred[i] == target[i]:
                        train_pos_correct += 1
                else:  # Negative class
                    train_neg_total += 1
                    if pred[i] == target[i]:
                        train_neg_correct += 1

            # Clear memory after each batch (frees GPU tensors promptly).
            del data, target, output, loss, pred

            # Running per-class accuracies; max(...,1) guards divide-by-zero
            # before the first sample of a class has been seen.
            train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
            train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)

            train_pbar.set_postfix({
                'Loss': f'{train_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*train_correct/train_total:.2f}%',
                'Pos': f'{train_pos_acc:.2f}%',
                'Neg': f'{train_neg_acc:.2f}%'
            })

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_pos_correct = 0
        val_pos_total = 0
        val_neg_correct = 0
        val_neg_total = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch_idx, (data, target) in enumerate(val_pbar):
                data, target = data.to(device), target.to(device)

                # Use mixed precision for validation too
                with torch.amp.autocast('cuda'):
                    output = model(data)
                    loss = criterion(output, target).mean()

                val_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                val_correct += pred.eq(target.view_as(pred)).sum().item()
                val_total += target.size(0)

                # Track per-class accuracy
                for i in range(target.size(0)):
                    if target[i] == 1:  # Positive class
                        val_pos_total += 1
                        if pred[i] == target[i]:
                            val_pos_correct += 1
                    else:  # Negative class
                        val_neg_total += 1
                        if pred[i] == target[i]:
                            val_neg_correct += 1

                # Clear memory after each batch
                del data, target, output, loss, pred

                # Running per-class accuracies for the progress bar.
                val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
                val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)

                val_pbar.set_postfix({
                    'Loss': f'{val_loss/(batch_idx+1):.4f}',
                    'Acc': f'{100.*val_correct/val_total:.2f}%',
                    'Pos': f'{val_pos_acc:.2f}%',
                    'Neg': f'{val_neg_acc:.2f}%'
                })

        # ---- Epoch summary ----
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        train_pos_acc = 100. * train_pos_correct / max(train_pos_total, 1)
        train_neg_acc = 100. * train_neg_correct / max(train_neg_total, 1)
        val_pos_acc = 100. * val_pos_correct / max(val_pos_total, 1)
        val_neg_acc = 100. * val_neg_correct / max(val_neg_total, 1)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% (Pos: {train_pos_acc:.2f}%, Neg: {train_neg_acc:.2f}%)')
        print(f'  Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}% (Pos: {val_pos_acc:.2f}%, Neg: {val_neg_acc:.2f}%)')
        print(f'  Train samples - Pos: {train_pos_total}, Neg: {train_neg_total}')
        print(f'  Val samples - Pos: {val_pos_total}, Neg: {val_neg_total}')

        # Step the scheduler with validation accuracy (scheduler mode='max').
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'  Current learning rate: {current_lr:.6f}')

        # Save best model whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            mode_suffix = "_quick" if args.quick else ""
            models_dir = os.path.expanduser('~/emblem/models')
            best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
            save_model(model, transform, best_model_path, train_metadata)
            print(f'  New best validation accuracy: {val_acc:.2f}%')
            print(f'  Best model saved: {best_model_path}')

    # Returns metrics from the final epoch (not necessarily the best one).
    return model, val_acc, val_pos_acc, val_neg_acc
|
|
|
|
def main():
    """CLI entry point: build grid datasets, fine-tune the model, save it.

    Loads per-scan metadata, splits scans into train/validation, preprocesses
    sbs.jpg images into grid crops, then trains with FocalLoss + Adam +
    ReduceLROnPlateau, saving the final checkpoint under ~/emblem/models.
    """
    parser = argparse.ArgumentParser(description='Train ResNet18 model on grid data')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--batch-size', type=int, default=64, help='Batch size (increased for better GPU utilization)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--num-workers', type=int, default=None, help='Number of workers for preprocessing (default: auto)')
    # Fixed: help/print text previously claimed 1%, but the sampling fraction
    # below is 0.005 (0.5%); the messages now match the actual behavior.
    parser.add_argument('--quick', action='store_true', help='Quick mode: use only 0.5%% of scans for faster testing')
    parser.add_argument('--hue-jitter', type=float, default=0.1, help='Hue jitter parameter for ColorJitter (default: 0.1)')
    parser.add_argument('--scheduler-patience', type=int, default=3, help='Patience for ReduceLROnPlateau scheduler (default: 3)')
    parser.add_argument('--scheduler-factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau scheduler (default: 0.1)')
    parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
    args = parser.parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    print(f'Using hue jitter: {args.hue_jitter}')

    # Create models directory if it doesn't exist
    models_dir = os.path.expanduser('~/emblem/models')
    os.makedirs(models_dir, exist_ok=True)

    # Load scan data
    print("Loading scan data...")
    scan_data = load_scan_data(args.data_dir)

    if not scan_data:
        print("No scan data found!")
        return

    # Quick mode: subsample scans for a fast smoke-test run.
    if args.quick:
        original_count = len(scan_data)
        scan_ids = list(scan_data.keys())
        # Take 0.5% of scans, but at least 10 scans for meaningful training
        quick_count = max(10, int(len(scan_ids) * 0.005))
        selected_scan_ids = random.sample(scan_ids, quick_count)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in selected_scan_ids}
        print(f"Quick mode enabled: Using {len(scan_data)} scans out of {original_count} (0.5%)")

    # Split into train and validation
    print("Splitting data into train and validation...")
    train_data, val_data = split_train_val(scan_data)

    # Define transforms (ImageNet normalization to match the pretrained backbone).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets (construction also generates the grid crop files).
    print("Creating training dataset...")
    train_dataset = GridDataset(train_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Training samples: {len(train_dataset)}")

    print("Creating validation dataset...")
    val_dataset = GridDataset(val_data, data_dir=args.data_dir, transform=transform, num_workers=args.num_workers, hue_jitter=args.hue_jitter)
    print(f"Validation samples: {len(val_dataset)}")

    # Create data loaders
    print("Creating data loaders...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Create model: resume from an existing checkpoint rather than a fresh
    # torchvision resnet18 (the commented alternative below).
    # NOTE(review): checkpoint filename is hard-coded — confirm it exists on
    # the target machine before running.
    print("Creating ResNet18 model...")
    # model = torchvision.models.resnet18(pretrained=True)
    models_dir = os.path.expanduser('~/emblem/models')
    model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt'))
    # Unwrap compiled model if it exists (torch.compile can cause issues when loading)
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    # Modify final layer for binary classification (only needed when starting
    # from the torchvision backbone; the loaded checkpoint already has it).
    # model.fc = nn.Linear(model.fc.in_features, 2)
    model = model.to(device)
    print(f"Model created and moved to {device}")

    # Define loss function and optimizer
    print("Setting up loss function and optimizer...")
    # Use FocalLoss with 0.99:0.01 weights for positive/negative classes
    pos_weight = 0.99
    criterion = FocalLoss(0.25, weight=torch.Tensor([1.0 - pos_weight, pos_weight]).to(device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Create ReduceLROnPlateau scheduler (mode='max': stepped on val accuracy).
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=args.scheduler_factor,
        patience=args.scheduler_patience,
        min_lr=args.scheduler_min_lr
    )

    print(f"Using Adam optimizer with lr={args.lr}")
    print(f"Using FocalLoss with weights: Negative={1.0 - pos_weight:.2f}, Positive={pos_weight:.2f}")
    print(f"Using ReduceLROnPlateau scheduler with factor={args.scheduler_factor}, patience={args.scheduler_patience}, min_lr={args.scheduler_min_lr}")

    # Collect training metadata (stored inside saved checkpoints).
    train_scan_ids = list(train_data.keys())
    train_codes = set()
    for scan_id, metadata in train_data.items():
        code = metadata.get('code')
        if code:
            train_codes.add(code)

    train_metadata = {
        'train_scan_ids': train_scan_ids,
        'train_codes': list(train_codes),
        'hue_jitter': args.hue_jitter,
        'quick_mode': args.quick
    }

    print(f"Training metadata:")
    print(f"  Train scan IDs: {len(train_scan_ids)}")
    print(f"  Train codes: {len(train_codes)}")
    print(f"  Hue jitter: {args.hue_jitter}")
    print(f"  Quick mode: {args.quick}")

    # Train the model
    print("Starting training...")
    model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs)

    # Save final model (train_model separately saves the best-epoch checkpoint).
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    mode_suffix = "_quick" if args.quick else ""
    models_dir = os.path.expanduser('~/emblem/models')
    final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
    save_model(model, transform, final_model_path, train_metadata)
    print(f"Training completed! Final model saved: {final_model_path}")
|
|
|
|
# Script entry point: parse CLI args and run the full training pipeline.
if __name__ == '__main__':
    main()