diff --git a/emblem5/ai/check_accuracy.py b/emblem5/ai/check_accuracy.py index 31d5410..3b83d9a 100755 --- a/emblem5/ai/check_accuracy.py +++ b/emblem5/ai/check_accuracy.py @@ -1,37 +1,36 @@ #!/usr/bin/env python3 """ -Temporary script to check model accuracy on specific scan ranges -Positive samples: 357193-358023 -Negative samples: 358024-358808 +Script to check model accuracy on scans +By default, randomly selects 1000 images and compares with pos/neg labels from metadata """ import sys import os import json +import argparse +import random from PIL import Image from tqdm import tqdm # Add current directory to path for imports sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -from common import load_model, predict +from common import load_model, predict, parse_ranges, in_range def main(): - if len(sys.argv) < 3: - print("Usage: python3 check_accuracy.py ") - print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem") - sys.exit(1) + parser = argparse.ArgumentParser() + parser.add_argument('model_path', help='Path to model file') + parser.add_argument('data_dir', help='Path to data directory') + parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")') + parser.add_argument('--sample-size', type=int, default=1000, help='Number of scans to randomly sample (default: 1000)') + args = parser.parse_args() - model_path = sys.argv[1] - data_dir = sys.argv[2] + model_path = args.model_path + data_dir = args.data_dir if not os.path.exists(model_path): print(f"Error: Model file not found: {model_path}") sys.exit(1) - # Define scan ranges - pos_start, pos_end = 357193, 358023 - neg_start, neg_end = 358024, 358808 - print(f"Loading model from {model_path}...") model, transforms = load_model(model_path) # Ensure model is not compiled @@ -44,57 +43,80 @@ def main(): print(f"Error: Scans directory not found: {scans_dir}") sys.exit(1) - # Collect scan IDs in ranges - pos_scans = [] - neg_scans = [] + # Collect scan IDs with metadata + scan_candidates = [] print("Collecting scan IDs...") - for scan_id in range(pos_start, pos_end + 1): - scan_dir = os.path.join(scans_dir, str(scan_id)) - sbs_file = os.path.join(scan_dir, 'sbs.jpg') - if os.path.exists(sbs_file): - pos_scans.append(str(scan_id)) + all_scan_ids = [d for d in os.listdir(scans_dir) if os.path.isdir(os.path.join(scans_dir, d))] - for scan_id in range(neg_start, neg_end + 1): - scan_dir = os.path.join(scans_dir, str(scan_id)) - sbs_file = os.path.join(scan_dir, 'sbs.jpg') - if os.path.exists(sbs_file): - neg_scans.append(str(scan_id)) + # If SCAN_IDS is provided, filter by those ranges + if args.scan_ids: + filter_ranges = parse_ranges(args.scan_ids) + filtered_scan_ids = [] + for scan_id_str in all_scan_ids: + try: + scan_id_int = int(scan_id_str) + for val_range in filter_ranges: + if in_range(scan_id_int, val_range): + filtered_scan_ids.append(scan_id_str) + break + except ValueError: + continue + all_scan_ids = filtered_scan_ids + print(f"Filtered to {len(all_scan_ids)} scans matching range(s): {args.scan_ids}") - print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})") - print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n") + # Load metadata for all candidate scans + for scan_id in all_scan_ids: + scan_dir = os.path.join(scans_dir, scan_id) + sbs_file = os.path.join(scan_dir, 'sbs.jpg') + metadata_file = os.path.join(scan_dir, 'metadata.json') + + if not os.path.exists(sbs_file): + continue + + if not os.path.exists(metadata_file): + continue + + try: + with open(metadata_file, 'r') as f: + metadata = json.load(f) + + labels = metadata.get('labels', []) + if 'pos' not in labels and 'neg' not in labels: + continue + + scan_candidates.append({ + 'scan_id': scan_id, + 'label': 1 if 'pos' in labels else 0, + 'sbs_file': sbs_file + }) + except Exception as e: + continue + + print(f"Found {len(scan_candidates)} scans with valid metadata and sbs.jpg") + + # Randomly sample if needed + if len(scan_candidates) > args.sample_size: + scan_candidates = random.sample(scan_candidates, args.sample_size) + print(f"Randomly sampled {args.sample_size} scans") + + print(f"Evaluating {len(scan_candidates)} scans...\n") # Run predictions pos_correct = 0 pos_total = 0 neg_correct = 0 neg_total = 0 + errors = [] - print("Evaluating positive samples...") - for scan_id in tqdm(pos_scans, desc="Positive"): - scan_dir = os.path.join(scans_dir, scan_id) - sbs_file = os.path.join(scan_dir, 'sbs.jpg') + for scan_info in tqdm(scan_candidates, desc="Evaluating"): + scan_id = scan_info['scan_id'] + true_label = scan_info['label'] + sbs_file = scan_info['sbs_file'] try: image = Image.open(sbs_file).convert('RGB') - predicted_class, probabilities = predict(model, transforms, image) - - pos_total += 1 - if predicted_class == 1: - pos_correct += 1 - except Exception as e: - print(f"\nError processing {scan_id}: {e}") - continue - - print("\nEvaluating negative samples...") - neg_errors = [] - for scan_id in tqdm(neg_scans, desc="Negative"): - scan_dir = os.path.join(scans_dir, scan_id) - sbs_file = os.path.join(scan_dir, 'sbs.jpg') - - try: - image = Image.open(sbs_file).convert('RGB') - predicted_class, probabilities = predict(model, transforms, image) + predicted_class, probabilities = predict(model, transforms, image, ncells=3) neg_sum = sum([p[0] for p in probabilities]) pos_sum = sum([p[1] for p in probabilities]) @@ -102,13 +124,22 @@ def main(): neg_prob = neg_sum / total if total > 0 else 0 pos_prob = pos_sum / total if total > 0 else 0 - neg_total += 1 - if predicted_class == 0: - neg_correct += 1 + if true_label == 1: + pos_total += 1 + if predicted_class == 1: + pos_correct += 1 + else: + # Store first few errors for debugging + if len(errors) < 5: + errors.append((scan_id, 'pos', neg_prob, pos_prob)) else: - # Store first few errors for debugging - if len(neg_errors) < 5: - neg_errors.append((scan_id, neg_prob, pos_prob)) + neg_total += 1 + if predicted_class == 0: + neg_correct += 1 + else: + # Store first few errors for debugging + if len(errors) < 5: + errors.append((scan_id, 'neg', neg_prob, pos_prob)) except Exception as e: print(f"\nError processing {scan_id}: {e}") continue @@ -129,10 +160,10 @@ def main(): print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})") print("="*60) - if neg_errors: - print("\nSample negative misclassifications:") - for scan_id, neg_prob, pos_prob in neg_errors: - print(f" Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}") + if errors: + print("\nSample misclassifications:") + for scan_id, true_label, neg_prob, pos_prob in errors: + print(f" Scan {scan_id} (true: {true_label}): neg={neg_prob:.2%}, pos={pos_prob:.2%}") if __name__ == '__main__': main()