#!/usr/bin/env python3
"""
Script to check model accuracy on scans.

By default, randomly selects 1000 images and compares the model's predictions
with the pos/neg labels from each scan's metadata. Use --scan-ids to restrict
the scans considered and --seed to make the random sample reproducible.
"""
import sys
import os
import json
import argparse
import random
from PIL import Image
from tqdm import tqdm

# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from common import load_model, predict, parse_ranges, in_range

# Number of misclassified scans kept for the debug listing at the end.
MAX_REPORTED_ERRORS = 5


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', help='Path to model file')
    parser.add_argument('data_dir', help='Path to data directory')
    parser.add_argument('--scan-ids', type=str, default=None,
                        help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
    parser.add_argument('--sample-size', type=int, default=1000,
                        help='Number of scans to randomly sample (default: 1000)')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducible sampling (default: unseeded)')
    args = parser.parse_args()
    # random.sample() raises ValueError on a negative count, and 0 would
    # silently evaluate nothing; reject both with a usage error instead.
    if args.sample_size < 1:
        parser.error('--sample-size must be a positive integer')
    return args


def _filter_scan_ids(scan_ids, scan_id_ranges):
    """Keep only the numeric scan IDs that fall inside any of the given ranges.

    Args:
        scan_ids: Directory names found under the scans directory.
        scan_id_ranges: Range specification string passed to parse_ranges().

    Returns:
        The matching scan ID strings, in input order.
    """
    filter_ranges = parse_ranges(scan_id_ranges)
    filtered = []
    for scan_id_str in scan_ids:
        try:
            scan_id_int = int(scan_id_str)
        except ValueError:
            # Non-numeric directory names cannot match a numeric range.
            continue
        if any(in_range(scan_id_int, val_range) for val_range in filter_ranges):
            filtered.append(scan_id_str)
    return filtered


def _collect_candidates(scans_dir, scan_ids):
    """Build evaluation candidates from scans that have sbs.jpg and a pos/neg label.

    Args:
        scans_dir: Directory containing one sub-directory per scan.
        scan_ids: Names of the scan sub-directories to inspect.

    Returns:
        A tuple (candidates, n_malformed). candidates is a list of dicts with
        keys 'scan_id', 'label' (1 for pos, 0 for neg) and 'sbs_file'.
        n_malformed counts scans skipped because metadata.json could not be
        read or did not have the expected structure.
    """
    candidates = []
    n_malformed = 0
    for scan_id in scan_ids:
        scan_dir = os.path.join(scans_dir, scan_id)
        sbs_file = os.path.join(scan_dir, 'sbs.jpg')
        metadata_file = os.path.join(scan_dir, 'metadata.json')
        if not os.path.exists(sbs_file) or not os.path.exists(metadata_file):
            continue
        try:
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)
            labels = metadata.get('labels', [])
            is_pos = 'pos' in labels
            is_neg = 'neg' in labels
        except (OSError, ValueError, AttributeError, TypeError):
            # OSError: unreadable file. ValueError: invalid JSON / bad encoding.
            # AttributeError: top-level JSON is not an object.
            # TypeError: 'labels' does not support membership tests.
            n_malformed += 1
            continue
        if not is_pos and not is_neg:
            continue
        candidates.append({
            'scan_id': scan_id,
            # A scan labelled both 'pos' and 'neg' is counted as positive.
            'label': 1 if is_pos else 0,
            'sbs_file': sbs_file,
        })
    return candidates, n_malformed


def _evaluate(model, transforms, candidates):
    """Run the model on every candidate and tally per-class correctness.

    Returns:
        A tuple (tallies, errors, n_failed). tallies maps a true label
        (1 = pos, 0 = neg) to [n_correct, n_total]. errors holds up to
        MAX_REPORTED_ERRORS tuples (scan_id, 'pos'|'neg', neg_prob, pos_prob)
        for misclassified scans. n_failed counts scans that raised during
        loading or prediction and were therefore excluded from the tallies.
    """
    tallies = {1: [0, 0], 0: [0, 0]}
    errors = []
    n_failed = 0
    for scan_info in tqdm(candidates, desc="Evaluating"):
        scan_id = scan_info['scan_id']
        true_label = scan_info['label']
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(scan_info['sbs_file']) as img:
                image = img.convert('RGB')
            predicted_class, probabilities = predict(model, transforms, image, ncells=3)
            # probabilities is indexed as p[0] (neg) / p[1] (pos) per entry —
            # presumably one pair per cell; TODO confirm against common.predict.
            neg_sum = sum(p[0] for p in probabilities)
            pos_sum = sum(p[1] for p in probabilities)
            total = neg_sum + pos_sum
            neg_prob = neg_sum / total if total > 0 else 0
            pos_prob = pos_sum / total if total > 0 else 0
        except Exception as e:
            # Best-effort per scan: one bad image must not abort the whole run.
            # tqdm.write avoids corrupting the progress bar.
            tqdm.write(f"\nError processing {scan_id}: {e}")
            n_failed += 1
            continue

        tally = tallies[true_label]
        tally[1] += 1
        if predicted_class == true_label:
            tally[0] += 1
        elif len(errors) < MAX_REPORTED_ERRORS:
            # Store first few errors for debugging
            errors.append((scan_id, 'pos' if true_label == 1 else 'neg', neg_prob, pos_prob))
    return tallies, errors, n_failed


def _print_results(tallies, errors, n_failed):
    """Print per-class and overall accuracy plus sample misclassifications."""
    pos_correct, pos_total = tallies[1]
    neg_correct, neg_total = tallies[0]
    pos_acc = pos_correct / pos_total if pos_total > 0 else 0
    neg_acc = neg_correct / neg_total if neg_total > 0 else 0
    total_correct = pos_correct + neg_correct
    total_samples = pos_total + neg_total
    overall_acc = total_correct / total_samples if total_samples > 0 else 0

    print("\n" + "=" * 60)
    print("Accuracy Results:")
    print("=" * 60)
    print(f"Positive samples: {pos_correct}/{pos_total} correct ({pos_acc:.2%})")
    print(f"Negative samples: {neg_correct}/{neg_total} correct ({neg_acc:.2%})")
    print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
    print("=" * 60)

    if n_failed:
        print(f"\n{n_failed} scan(s) failed to process and were excluded from accuracy")

    if errors:
        print("\nSample misclassifications:")
        for scan_id, true_label, neg_prob, pos_prob in errors:
            print(f"  Scan {scan_id} (true: {true_label}): neg={neg_prob:.2%}, pos={pos_prob:.2%}")


def main():
    """Evaluate a model on labelled scans and print accuracy statistics."""
    args = _parse_args()
    model_path = args.model_path
    data_dir = args.data_dir

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)

    # Validate the data layout before paying for the model load.
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.isdir(scans_dir):
        print(f"Error: Scans directory not found: {scans_dir}")
        sys.exit(1)

    print(f"Loading model from {model_path}...")
    model, transforms = load_model(model_path)
    # Ensure model is not compiled (presumably unwrapping a torch.compile
    # wrapper, which exposes the original module as _orig_mod).
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    print("Model loaded successfully\n")

    print("Collecting scan IDs...")
    # Sorted so that, together with --seed, the sample is reproducible
    # (os.listdir order is arbitrary).
    all_scan_ids = sorted(
        d for d in os.listdir(scans_dir) if os.path.isdir(os.path.join(scans_dir, d))
    )

    # If SCAN_IDS is provided, filter by those ranges
    if args.scan_ids:
        all_scan_ids = _filter_scan_ids(all_scan_ids, args.scan_ids)
        print(f"Filtered to {len(all_scan_ids)} scans matching range(s): {args.scan_ids}")

    scan_candidates, n_malformed = _collect_candidates(scans_dir, all_scan_ids)
    print(f"Found {len(scan_candidates)} scans with valid metadata and sbs.jpg")
    if n_malformed:
        print(f"Skipped {n_malformed} scans with unreadable or malformed metadata.json")

    # Randomly sample if needed; a local RNG keeps seeding self-contained.
    if len(scan_candidates) > args.sample_size:
        rng = random.Random(args.seed)
        scan_candidates = rng.sample(scan_candidates, args.sample_size)
        print(f"Randomly sampled {args.sample_size} scans")

    print(f"Evaluating {len(scan_candidates)} scans...\n")
    tallies, errors, n_failed = _evaluate(model, transforms, scan_candidates)
    _print_results(tallies, errors, n_failed)


if __name__ == '__main__':
    main()