#!/usr/bin/env python3
"""
Temporary script to check model accuracy on specific scan ranges
Positive samples: 357193-358023
Negative samples: 358024-358808

Usage: python3 check_accuracy.py <model_path> <data_dir>

Each scan is read from <data_dir>/scans/<scan_id>/sbs.jpg. Scan IDs whose
image is missing are skipped; the found/expected counts make that visible.
"""

import sys
import os
import json

from PIL import Image
from tqdm import tqdm

# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import load_model, predict

# Inclusive scan-ID ranges with a known ground-truth label.
POS_RANGE = (357193, 358023)  # expected predicted_class == 1
NEG_RANGE = (358024, 358808)  # expected predicted_class == 0

# How many misclassified scans per class to keep for the debug listing.
MAX_REPORTED_ERRORS = 5


def _collect_scans(scans_dir, start, end):
    """Return the scan IDs (as strings) in [start, end] that have an sbs.jpg."""
    return [
        str(scan_id)
        for scan_id in range(start, end + 1)
        if os.path.exists(os.path.join(scans_dir, str(scan_id), 'sbs.jpg'))
    ]


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is 0."""
    return numerator / denominator if denominator > 0 else 0


def _class_shares(probabilities):
    """Return (neg_share, pos_share) of the summed per-class probabilities.

    Assumes ``probabilities`` is an iterable of (neg, pos) pairs as returned
    by ``common.predict`` -- TODO confirm. Returns (0, 0) when both sums are 0.
    """
    neg_sum = sum(p[0] for p in probabilities)
    pos_sum = sum(p[1] for p in probabilities)
    total = neg_sum + pos_sum
    return _ratio(neg_sum, total), _ratio(pos_sum, total)


def _evaluate(model, transforms, scans_dir, scan_ids, expected_class, desc):
    """Run the model over ``scan_ids`` and score against ``expected_class``.

    Returns ``(correct, total, errors)``:
      * ``total`` counts only scans that were loaded and predicted successfully;
      * ``errors`` holds up to MAX_REPORTED_ERRORS tuples
        ``(scan_id, neg_share, pos_share)`` for misclassified scans.
    A scan that fails to load or predict is reported and skipped, so one bad
    file does not abort the whole run.
    """
    correct = 0
    total = 0
    errors = []
    for scan_id in tqdm(scan_ids, desc=desc):
        sbs_file = os.path.join(scans_dir, scan_id, 'sbs.jpg')
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(sbs_file) as img:
                image = img.convert('RGB')
            predicted_class, probabilities = predict(model, transforms, image)
            # Only summarise probabilities for the few errors we report.
            if predicted_class != expected_class and len(errors) < MAX_REPORTED_ERRORS:
                neg_share, pos_share = _class_shares(probabilities)
                errors.append((scan_id, neg_share, pos_share))
        except Exception as e:
            # Best effort per scan. tqdm.write prints above the progress bar
            # instead of corrupting it.
            tqdm.write(f"Error processing {scan_id}: {e}", file=sys.stderr)
            continue
        total += 1
        if predicted_class == expected_class:
            correct += 1
    return correct, total, errors


def _print_errors(title, errors):
    """Print the sampled misclassifications, if there are any."""
    if not errors:
        return
    print(f"\n{title}")
    for scan_id, neg_prob, pos_prob in errors:
        print(f"  Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}")


def main():
    """Evaluate a model checkpoint on the fixed positive/negative scan ranges."""
    if len(sys.argv) < 3:
        print("Usage: python3 check_accuracy.py <model_path> <data_dir>", file=sys.stderr)
        print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem", file=sys.stderr)
        sys.exit(1)

    model_path = sys.argv[1]
    data_dir = sys.argv[2]

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    # Validate the data layout before the (potentially slow) model load.
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        print(f"Error: Scans directory not found: {scans_dir}", file=sys.stderr)
        sys.exit(1)

    pos_start, pos_end = POS_RANGE
    neg_start, neg_end = NEG_RANGE

    print(f"Loading model from {model_path}...")
    model, transforms = load_model(model_path)
    # Ensure model is not compiled (unwrap the object exposing _orig_mod,
    # presumably a torch.compile wrapper).
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    print("Model loaded successfully\n")

    print("Collecting scan IDs...")
    pos_scans = _collect_scans(scans_dir, pos_start, pos_end)
    neg_scans = _collect_scans(scans_dir, neg_start, neg_end)
    print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})")
    print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n")

    print("Evaluating positive samples...")
    pos_correct, pos_total, pos_errors = _evaluate(
        model, transforms, scans_dir, pos_scans, expected_class=1, desc="Positive")

    print("\nEvaluating negative samples...")
    neg_correct, neg_total, neg_errors = _evaluate(
        model, transforms, scans_dir, neg_scans, expected_class=0, desc="Negative")

    pos_acc = _ratio(pos_correct, pos_total)
    neg_acc = _ratio(neg_correct, neg_total)
    total_correct = pos_correct + neg_correct
    total_samples = pos_total + neg_total
    overall_acc = _ratio(total_correct, total_samples)

    print("\n" + "=" * 60)
    print("Accuracy Results:")
    print("=" * 60)
    print(f"Positive samples: {pos_correct}/{pos_total} correct ({pos_acc:.2%})")
    print(f"Negative samples: {neg_correct}/{neg_total} correct ({neg_acc:.2%})")
    print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
    print("=" * 60)

    _print_errors("Sample positive misclassifications:", pos_errors)
    _print_errors("Sample negative misclassifications:", neg_errors)


if __name__ == '__main__':
    main()