themblem/emblem5/ai/check_accuracy.py
2025-12-27 10:33:49 +00:00

171 lines
6.0 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Check a model's accuracy on labelled scans.
By default, randomly samples up to 1000 scans that have an sbs.jpg image and a
'pos' or 'neg' label in metadata.json, then reports per-class and overall accuracy.
"""
import sys
import os
import json
import argparse
import random
from PIL import Image
from tqdm import tqdm
# Add this script's own directory to sys.path so the sibling `common` module can be imported
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import load_model, predict, parse_ranges, in_range
def _filter_scan_ids(scan_ids, ranges_spec):
    """Return the entries of *scan_ids* whose integer value lies in *ranges_spec*.

    *ranges_spec* is the raw ``--scan-ids`` string. It is parsed once with
    ``parse_ranges``, and each ID is tested with ``in_range``. IDs that are not
    integers cannot match a numeric range, so they are dropped.
    """
    filter_ranges = parse_ranges(ranges_spec)
    kept = []
    for scan_id in scan_ids:
        try:
            scan_id_int = int(scan_id)
        except ValueError:
            continue
        if any(in_range(scan_id_int, val_range) for val_range in filter_ranges):
            kept.append(scan_id)
    return kept


def _collect_candidates(scans_dir, scan_ids):
    """Build evaluation records for scans that have an sbs.jpg and a pos/neg label.

    Returns ``(candidates, unreadable)``:

    - *candidates* is a list of ``{'scan_id', 'label', 'sbs_file'}`` dicts, where
      label is 1 for 'pos' and 0 for 'neg'.
    - *unreadable* counts scans whose metadata.json exists but could not be read
      or parsed.
    """
    candidates = []
    unreadable = 0
    for scan_id in scan_ids:
        scan_dir = os.path.join(scans_dir, scan_id)
        sbs_file = os.path.join(scan_dir, 'sbs.jpg')
        metadata_file = os.path.join(scan_dir, 'metadata.json')
        if not (os.path.exists(sbs_file) and os.path.exists(metadata_file)):
            continue
        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            unreadable += 1
            continue
        labels = metadata.get('labels', []) if isinstance(metadata, dict) else []
        if isinstance(labels, str):
            labels = [labels]  # tolerate a single label stored as a bare string
        elif not isinstance(labels, (list, tuple, set)):
            continue
        if 'pos' not in labels and 'neg' not in labels:
            continue
        # A scan tagged with both 'pos' and 'neg' is treated as positive.
        candidates.append({
            'scan_id': scan_id,
            'label': 1 if 'pos' in labels else 0,
            'sbs_file': sbs_file,
        })
    return candidates, unreadable


def _evaluate(model, transforms, scan_candidates, max_errors=5):
    """Predict every candidate scan and tally correctness per true label.

    Returns ``(stats, errors)``:

    - *stats* maps ``pos_correct``, ``pos_total``, ``neg_correct``,
      ``neg_total`` and ``failed`` to counts.
    - *errors* holds up to *max_errors* misclassifications as
      ``(scan_id, 'pos'|'neg', neg_prob, pos_prob)`` tuples.

    Scans that raise while loading or predicting are reported, counted in
    ``failed``, and excluded from the accuracy totals.
    """
    stats = dict.fromkeys(
        ('pos_correct', 'pos_total', 'neg_correct', 'neg_total', 'failed'), 0)
    errors = []
    for scan_info in tqdm(scan_candidates, desc="Evaluating"):
        scan_id = scan_info['scan_id']
        true_label = scan_info['label']
        try:
            # The context manager closes the image file handle promptly.
            with Image.open(scan_info['sbs_file']) as img:
                image = img.convert('RGB')
            predicted_class, probabilities = predict(model, transforms, image, ncells=3)
            # Sum index 0 (neg) and index 1 (pos) over all entries of
            # `probabilities`, then normalise the pair for reporting only.
            neg_sum = sum(p[0] for p in probabilities)
            pos_sum = sum(p[1] for p in probabilities)
            total = neg_sum + pos_sum
            neg_prob = neg_sum / total if total > 0 else 0
            pos_prob = pos_sum / total if total > 0 else 0
        except Exception as e:  # one bad scan must not abort the whole run
            stats['failed'] += 1
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"Error processing {scan_id}: {e}", file=sys.stderr)
            continue
        key = 'pos' if true_label == 1 else 'neg'
        stats[f'{key}_total'] += 1
        if predicted_class == true_label:
            stats[f'{key}_correct'] += 1
        elif len(errors) < max_errors:
            # Keep only the first few misclassifications for debugging.
            errors.append((scan_id, key, neg_prob, pos_prob))
    return stats, errors


def _print_results(stats, errors):
    """Print per-class and overall accuracy, failure count, and sample misclassifications."""
    def ratio(num, den):
        return num / den if den > 0 else 0

    pos_correct, pos_total = stats['pos_correct'], stats['pos_total']
    neg_correct, neg_total = stats['neg_correct'], stats['neg_total']
    total_correct = pos_correct + neg_correct
    total_samples = pos_total + neg_total
    print("\n" + "=" * 60)
    print("Accuracy Results:")
    print("=" * 60)
    print(f"Positive samples: {pos_correct}/{pos_total} correct ({ratio(pos_correct, pos_total):.2%})")
    print(f"Negative samples: {neg_correct}/{neg_total} correct ({ratio(neg_correct, neg_total):.2%})")
    print(f"Overall accuracy: {total_correct}/{total_samples} correct ({ratio(total_correct, total_samples):.2%})")
    if stats['failed']:
        print(f"Failed to process: {stats['failed']} scans (excluded from the totals above)")
    print("=" * 60)
    if errors:
        print("\nSample misclassifications:")
        for scan_id, true_label, neg_prob, pos_prob in errors:
            print(f" Scan {scan_id} (true: {true_label}): neg={neg_prob:.2%}, pos={pos_prob:.2%}")


def main():
    """Evaluate a model on labelled scans and print per-class and overall accuracy.

    Command-line arguments:
        model_path: model file passed to ``load_model``.
        data_dir: directory with a ``scans/<scan_id>/`` subdirectory per scan.
            Each scan needs ``sbs.jpg`` and a ``metadata.json`` with ``labels``.
        --scan-ids: optional comma-separated ID ranges that restrict evaluation.
        --sample-size: maximum number of scans to evaluate, chosen at random.
            0 or less evaluates every candidate.
        --seed: optional RNG seed that makes the random sample reproducible.

    Exits with status 1 if any of these holds:

    - the model file is missing;
    - the scans directory is missing;
    - no evaluable scans are found.

    All three are checked before the model is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', help='Path to model file')
    parser.add_argument('data_dir', help='Path to data directory')
    parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
    parser.add_argument('--sample-size', type=int, default=1000, help='Number of scans to randomly sample (default: 1000; 0 or less evaluates all)')
    parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducible sampling (default: unseeded)')
    args = parser.parse_args()
    model_path = args.model_path
    data_dir = args.data_dir

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        sys.exit(1)
    # Validate the data before the (potentially slow) model load.
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.isdir(scans_dir):
        print(f"Error: Scans directory not found: {scans_dir}", file=sys.stderr)
        sys.exit(1)

    print("Collecting scan IDs...")
    # Sort so that, with --seed, the sample does not depend on os.listdir order.
    all_scan_ids = sorted(
        d for d in os.listdir(scans_dir) if os.path.isdir(os.path.join(scans_dir, d)))
    if args.scan_ids:
        all_scan_ids = _filter_scan_ids(all_scan_ids, args.scan_ids)
        print(f"Filtered to {len(all_scan_ids)} scans matching range(s): {args.scan_ids}")

    scan_candidates, unreadable = _collect_candidates(scans_dir, all_scan_ids)
    print(f"Found {len(scan_candidates)} scans with valid metadata and sbs.jpg")
    if unreadable:
        print(f"Skipped {unreadable} scans with unreadable metadata.json", file=sys.stderr)
    if not scan_candidates:
        print("Error: no scans to evaluate", file=sys.stderr)
        sys.exit(1)

    if 0 < args.sample_size < len(scan_candidates):
        # Use a private RNG so seeding does not affect global `random` state.
        scan_candidates = random.Random(args.seed).sample(scan_candidates, args.sample_size)
        print(f"Randomly sampled {args.sample_size} scans")

    print(f"Loading model from {model_path}...")
    model, transforms = load_model(model_path)
    # Use the uncompiled module if this is a torch.compile wrapper (it exposes _orig_mod).
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    print("Model loaded successfully\n")

    print(f"Evaluating {len(scan_candidates)} scans...\n")
    stats, errors = _evaluate(model, transforms, scan_candidates)
    _print_results(stats, errors)
if __name__ == '__main__':
    # Only run the evaluation when executed as a script, not when imported.
    main()