Add wandb logging, finetune old code selection, and improve file naming
- Add required wandb integration with hardcoded API key
- Log training metrics (loss, accuracy, per-class metrics) to wandb
- In finetune mode, automatically add old codes to training set to prevent forgetting
- Zero-pad epoch numbers in model filenames (ep001, ep002) for better sorting
- Change upload.py to display sizes and speeds in KB instead of bytes
- Update Makefile finetune target to use 100 epochs instead of 30
- Set default wandb project to 'euphon/themblem'
This commit is contained in:
parent
90043c2541
commit
32380a8082
2
Makefile
2
Makefile
@ -123,7 +123,7 @@ finetune: FORCE
|
|||||||
echo "Error: Model file not found: $$MODEL_PATH"; \
|
echo "Error: Model file not found: $$MODEL_PATH"; \
|
||||||
exit 1; \
|
exit 1; \
|
||||||
fi; \
|
fi; \
|
||||||
cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --scan-ids "$(SCAN_IDS)" --model "$$MODEL_PATH" --epochs 30
|
cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --scan-ids "$(SCAN_IDS)" --model "$$MODEL_PATH" --epochs 100
|
||||||
|
|
||||||
browser: FORCE
|
browser: FORCE
|
||||||
cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR)
|
cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR)
|
||||||
|
|||||||
@ -42,6 +42,7 @@ import multiprocessing as mp
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from common import *
|
from common import *
|
||||||
|
import wandb
|
||||||
|
|
||||||
def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
|
def process_scan_grid(scan_item, data_dir='data', hue_jitter=0.1):
|
||||||
"""Process a single scan to create grid files and metadata
|
"""Process a single scan to create grid files and metadata
|
||||||
@ -424,16 +425,35 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler
|
|||||||
current_lr = optimizer.param_groups[0]['lr']
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
print(f' Current learning rate: {current_lr:.6f}')
|
print(f' Current learning rate: {current_lr:.6f}')
|
||||||
|
|
||||||
|
# Log to wandb
|
||||||
|
wandb.log({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'train/loss': train_loss / len(train_loader),
|
||||||
|
'train/acc': train_acc,
|
||||||
|
'train/pos_acc': train_pos_acc,
|
||||||
|
'train/neg_acc': train_neg_acc,
|
||||||
|
'train/pos_samples': train_pos_total,
|
||||||
|
'train/neg_samples': train_neg_total,
|
||||||
|
'val/loss': val_loss / len(val_loader),
|
||||||
|
'val/acc': val_acc,
|
||||||
|
'val/pos_acc': val_pos_acc,
|
||||||
|
'val/neg_acc': val_neg_acc,
|
||||||
|
'val/pos_samples': val_pos_total,
|
||||||
|
'val/neg_samples': val_neg_total,
|
||||||
|
'learning_rate': current_lr,
|
||||||
|
})
|
||||||
|
|
||||||
# Save best model
|
# Save best model
|
||||||
if val_acc > best_val_acc:
|
if val_acc > best_val_acc:
|
||||||
best_val_acc = val_acc
|
best_val_acc = val_acc
|
||||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
mode_suffix = "_quick" if args.quick else ""
|
mode_suffix = "_quick" if args.quick else ""
|
||||||
models_dir = os.path.expanduser('~/emblem/models')
|
models_dir = os.path.expanduser('~/emblem/models')
|
||||||
best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
|
best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1:03d}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
|
||||||
save_model(model, transform, best_model_path, train_metadata)
|
save_model(model, transform, best_model_path, train_metadata)
|
||||||
print(f' New best validation accuracy: {val_acc:.2f}%')
|
print(f' New best validation accuracy: {val_acc:.2f}%')
|
||||||
print(f' Best model saved: {best_model_path}')
|
print(f' Best model saved: {best_model_path}')
|
||||||
|
wandb.log({'best_val_acc': best_val_acc, 'best_epoch': epoch + 1})
|
||||||
|
|
||||||
return model, val_acc, val_pos_acc, val_neg_acc
|
return model, val_acc, val_pos_acc, val_neg_acc
|
||||||
|
|
||||||
@ -451,6 +471,8 @@ def main():
|
|||||||
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
|
parser.add_argument('--scheduler-min-lr', type=float, default=1e-6, help='Minimum learning rate for ReduceLROnPlateau scheduler (default: 1e-6)')
|
||||||
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
|
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
|
||||||
parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
|
parser.add_argument('--model', type=str, default=None, help='Path to model file to load for finetuning')
|
||||||
|
parser.add_argument('--wandb-project', type=str, default='euphon/themblem', help='W&B project name (default: euphon/themblem)')
|
||||||
|
parser.add_argument('--wandb-name', type=str, default=None, help='W&B run name (default: auto-generated)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Set device
|
# Set device
|
||||||
@ -462,14 +484,17 @@ def main():
|
|||||||
models_dir = os.path.expanduser('~/emblem/models')
|
models_dir = os.path.expanduser('~/emblem/models')
|
||||||
os.makedirs(models_dir, exist_ok=True)
|
os.makedirs(models_dir, exist_ok=True)
|
||||||
|
|
||||||
# Load scan data
|
# Load full scan data (needed for finetune mode to find old codes)
|
||||||
print("Loading scan data...")
|
print("Loading scan data...")
|
||||||
scan_data = load_scan_data(args.data_dir)
|
full_scan_data = load_scan_data(args.data_dir)
|
||||||
|
|
||||||
if not scan_data:
|
if not full_scan_data:
|
||||||
print("No scan data found!")
|
print("No scan data found!")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Start with full scan data, then filter
|
||||||
|
scan_data = full_scan_data.copy()
|
||||||
|
|
||||||
# Filter by scan IDs if provided
|
# Filter by scan IDs if provided
|
||||||
if args.scan_ids:
|
if args.scan_ids:
|
||||||
ranges = parse_ranges(args.scan_ids)
|
ranges = parse_ranges(args.scan_ids)
|
||||||
@ -506,6 +531,57 @@ def main():
|
|||||||
avoid_overlap = args.model is None
|
avoid_overlap = args.model is None
|
||||||
train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
|
train_data, val_data = split_train_val(scan_data, avoid_overlap=avoid_overlap)
|
||||||
|
|
||||||
|
# In finetune mode, add old codes to prevent forgetting
|
||||||
|
old_codes_info = None
|
||||||
|
if args.model:
|
||||||
|
# Get distinct codes from current finetune dataset
|
||||||
|
finetune_codes = set()
|
||||||
|
for scan_id, metadata in scan_data.items():
|
||||||
|
code = metadata.get('code')
|
||||||
|
if code:
|
||||||
|
finetune_codes.add(code)
|
||||||
|
|
||||||
|
num_finetune_codes = len(finetune_codes)
|
||||||
|
print(f"Finetune mode: Found {num_finetune_codes} distinct codes in finetune dataset")
|
||||||
|
|
||||||
|
if num_finetune_codes > 0:
|
||||||
|
# Find old codes (codes in full dataset but not in finetune dataset)
|
||||||
|
old_codes = set()
|
||||||
|
for scan_id, metadata in full_scan_data.items():
|
||||||
|
code = metadata.get('code')
|
||||||
|
if code and code not in finetune_codes:
|
||||||
|
old_codes.add(code)
|
||||||
|
|
||||||
|
print(f"Found {len(old_codes)} distinct old codes in full dataset")
|
||||||
|
|
||||||
|
if len(old_codes) > 0:
|
||||||
|
# Select the same number of distinct random old codes
|
||||||
|
num_old_codes_to_select = min(num_finetune_codes, len(old_codes))
|
||||||
|
selected_old_codes = random.sample(list(old_codes), num_old_codes_to_select)
|
||||||
|
print(f"Selected {num_old_codes_to_select} old codes to prevent forgetting: {selected_old_codes}")
|
||||||
|
|
||||||
|
# Add all scans for selected old codes to training set
|
||||||
|
old_scans_added = 0
|
||||||
|
for scan_id, metadata in full_scan_data.items():
|
||||||
|
code = metadata.get('code')
|
||||||
|
if code and code in selected_old_codes:
|
||||||
|
# Only add if not already in train_data (avoid duplicates)
|
||||||
|
if scan_id not in train_data:
|
||||||
|
train_data[scan_id] = metadata
|
||||||
|
old_scans_added += 1
|
||||||
|
|
||||||
|
print(f"Added {old_scans_added} scans from old codes to training set")
|
||||||
|
print(f"Training set size: {len(train_data)} scans (including {old_scans_added} from old codes)")
|
||||||
|
old_codes_info = {
|
||||||
|
'num_old_codes': num_old_codes_to_select,
|
||||||
|
'old_codes': selected_old_codes,
|
||||||
|
'old_scans_added': old_scans_added
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print("No old codes found in full dataset")
|
||||||
|
else:
|
||||||
|
print("No codes found in finetune dataset, skipping old code selection")
|
||||||
|
|
||||||
# Define transforms
|
# Define transforms
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.Resize((224, 224)),
|
transforms.Resize((224, 224)),
|
||||||
@ -620,6 +696,35 @@ def main():
|
|||||||
print(f" Hue jitter: {args.hue_jitter}")
|
print(f" Hue jitter: {args.hue_jitter}")
|
||||||
print(f" Quick mode: {args.quick}")
|
print(f" Quick mode: {args.quick}")
|
||||||
|
|
||||||
|
# Initialize wandb
|
||||||
|
wandb.login(key='ec22e6ed1ed9891779d57da600f889294f83e41d')
|
||||||
|
wandb_config = {
|
||||||
|
'batch_size': args.batch_size,
|
||||||
|
'lr': args.lr,
|
||||||
|
'epochs': args.epochs,
|
||||||
|
'hue_jitter': args.hue_jitter,
|
||||||
|
'scheduler_patience': args.scheduler_patience,
|
||||||
|
'scheduler_factor': args.scheduler_factor,
|
||||||
|
'scheduler_min_lr': args.scheduler_min_lr,
|
||||||
|
'quick_mode': args.quick,
|
||||||
|
'train_scans': len(train_scan_ids),
|
||||||
|
'val_scans': len(val_data),
|
||||||
|
'train_samples': len(train_dataset),
|
||||||
|
'val_samples': len(val_dataset),
|
||||||
|
'model_path': args.model if args.model else 'default',
|
||||||
|
'finetune_mode': args.model is not None,
|
||||||
|
}
|
||||||
|
if old_codes_info:
|
||||||
|
wandb_config['old_codes_count'] = old_codes_info['num_old_codes']
|
||||||
|
wandb_config['old_scans_added'] = old_codes_info['old_scans_added']
|
||||||
|
wandb_name = args.wandb_name or f"train2-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
wandb.init(
|
||||||
|
project=args.wandb_project,
|
||||||
|
name=wandb_name,
|
||||||
|
config=wandb_config,
|
||||||
|
)
|
||||||
|
print(f"W&B logging enabled: project={args.wandb_project}, run={wandb_name}")
|
||||||
|
|
||||||
# Train the model
|
# Train the model
|
||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs)
|
model, final_val_acc, final_val_pos_acc, final_val_neg_acc = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, transform, args, train_metadata, args.epochs)
|
||||||
@ -628,9 +733,12 @@ def main():
|
|||||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
mode_suffix = "_quick" if args.quick else ""
|
mode_suffix = "_quick" if args.quick else ""
|
||||||
models_dir = os.path.expanduser('~/emblem/models')
|
models_dir = os.path.expanduser('~/emblem/models')
|
||||||
final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
|
final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs:03d}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
|
||||||
save_model(model, transform, final_model_path, train_metadata)
|
save_model(model, transform, final_model_path, train_metadata)
|
||||||
print(f"Training completed! Final model saved: {final_model_path}")
|
print(f"Training completed! Final model saved: {final_model_path}")
|
||||||
|
|
||||||
|
# Finish wandb run
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
@ -22,14 +22,17 @@ class ProgressCallback:
|
|||||||
self.consumed = consumed
|
self.consumed = consumed
|
||||||
elapsed = time.time() - self.start_time
|
elapsed = time.time() - self.start_time
|
||||||
if elapsed > 0:
|
if elapsed > 0:
|
||||||
bps = consumed / elapsed
|
kbps = (consumed / elapsed) / 1024 # Convert to KB/s
|
||||||
percent = 100 * consumed / total if total else 0
|
percent = 100 * consumed / total if total else 0
|
||||||
print(f"\r{self.filename}: {percent:.1f}% ({consumed}/{total} bytes, {bps:.0f} B/s)", end='', flush=True)
|
consumed_kb = consumed / 1024
|
||||||
|
total_kb = total / 1024
|
||||||
|
print(f"\r{self.filename}: {percent:.1f}% ({consumed_kb:.0f}/{total_kb:.0f} KB, {kbps:.0f} KB/s)", end='', flush=True)
|
||||||
|
|
||||||
for x in sys.argv[1:]:
|
for x in sys.argv[1:]:
|
||||||
file_size = os.path.getsize(x)
|
file_size = os.path.getsize(x)
|
||||||
|
file_size_kb = file_size / 1024
|
||||||
filename = os.path.basename(x)
|
filename = os.path.basename(x)
|
||||||
print(f"Uploading {filename} ({file_size} bytes)...")
|
print(f"Uploading {filename} ({file_size_kb:.0f} KB)...")
|
||||||
bucket.put_object_from_file(filename, x, progress_callback=ProgressCallback(filename, file_size))
|
bucket.put_object_from_file(filename, x, progress_callback=ProgressCallback(filename, file_size))
|
||||||
print() # New line after progress completes
|
print() # New line after progress completes
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user