140 lines
4.7 KiB
Python
Executable File
140 lines
4.7 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Temporary script to check model accuracy on specific scan ranges
|
|
Positive samples: 357193-358023
|
|
Negative samples: 358024-358808
|
|
"""
|
|
import sys
|
|
import os
|
|
import json
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
# Add current directory to path for imports
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from common import load_model, predict
|
|
|
|
def _collect_scan_ids(scans_dir, start, end):
    """Return scan IDs in ``[start, end]`` (inclusive) that have an ``sbs.jpg``.

    IDs are returned as strings in ascending numeric order, because they are
    used directly as directory names under ``scans_dir``.
    """
    return [
        str(scan_id)
        for scan_id in range(start, end + 1)
        if os.path.isfile(os.path.join(scans_dir, str(scan_id), 'sbs.jpg'))
    ]


def _class_probabilities(probabilities):
    """Collapse ``predict``'s probabilities into a normalised ``(neg, pos)`` pair.

    Each item of ``probabilities`` is indexed at 0 (negative class) and 1
    (positive class). The per-class scores are summed and then normalised.
    NOTE(review): this presumably handles one row per view/crop returned by
    ``common.predict``; confirm against that function. Returns ``(0, 0)``
    when the summed scores are not positive.
    """
    neg_sum = sum(p[0] for p in probabilities)
    pos_sum = sum(p[1] for p in probabilities)
    total = neg_sum + pos_sum
    if total > 0:
        return neg_sum / total, pos_sum / total
    return 0, 0


def _evaluate_scans(model, transforms, scans_dir, scan_ids, expected_class,
                    desc, max_examples=5):
    """Predict every scan's ``sbs.jpg`` and score it against ``expected_class``.

    Returns ``(correct, total, failed, examples)``:
      * ``total`` counts only scans processed without an exception;
      * ``failed`` counts scans whose loading/prediction raised (they are
        reported and skipped, not counted in ``total``);
      * ``examples`` holds up to ``max_examples`` misclassifications as
        ``(scan_id, neg_prob, pos_prob)`` tuples.
    """
    correct = 0
    total = 0
    failed = 0
    examples = []

    for scan_id in tqdm(scan_ids, desc=desc):
        sbs_file = os.path.join(scans_dir, scan_id, 'sbs.jpg')
        try:
            # Context manager guarantees the file handle is released.
            with Image.open(sbs_file) as img:
                image = img.convert('RGB')
            predicted_class, probabilities = predict(model, transforms, image)
            # Only misclassified samples need the probability breakdown.
            if predicted_class != expected_class and len(examples) < max_examples:
                neg_prob, pos_prob = _class_probabilities(probabilities)
                examples.append((scan_id, neg_prob, pos_prob))
        except Exception as e:  # best-effort evaluation: report and skip this scan
            failed += 1
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(f"Error processing {scan_id}: {e}")
            continue

        total += 1
        if predicted_class == expected_class:
            correct += 1

    return correct, total, failed, examples


def main():
    """Evaluate a model on the hard-coded positive and negative scan ranges.

    Usage: ``check_accuracy.py <model_path> <data_dir>``. Images are read
    from ``<data_dir>/scans/<scan_id>/sbs.jpg``. The script prints
    per-class and overall accuracy, plus a few sample misclassifications.
    It exits with status 1 on missing arguments or missing inputs.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 check_accuracy.py <model_path> <data_dir>", file=sys.stderr)
        print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem", file=sys.stderr)
        sys.exit(1)

    model_path = sys.argv[1]
    data_dir = sys.argv[2]

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    # Validate the data layout before the (expensive) model load so that a
    # bad data_dir fails fast.
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.isdir(scans_dir):
        print(f"Error: Scans directory not found: {scans_dir}", file=sys.stderr)
        sys.exit(1)

    # Define scan ranges (inclusive)
    pos_start, pos_end = 357193, 358023
    neg_start, neg_end = 358024, 358808

    print(f"Loading model from {model_path}...")
    model, transforms = load_model(model_path)
    # Ensure model is not compiled: unwrap the original module if wrapped.
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    print("Model loaded successfully\n")

    print("Collecting scan IDs...")
    pos_scans = _collect_scan_ids(scans_dir, pos_start, pos_end)
    neg_scans = _collect_scan_ids(scans_dir, neg_start, neg_end)
    print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})")
    print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n")

    print("Evaluating positive samples...")
    pos_correct, pos_total, pos_failed, pos_errors = _evaluate_scans(
        model, transforms, scans_dir, pos_scans, expected_class=1, desc="Positive")

    print("\nEvaluating negative samples...")
    neg_correct, neg_total, neg_failed, neg_errors = _evaluate_scans(
        model, transforms, scans_dir, neg_scans, expected_class=0, desc="Negative")

    # Calculate accuracies (0 when a class has no successfully processed samples)
    pos_acc = pos_correct / pos_total if pos_total > 0 else 0
    neg_acc = neg_correct / neg_total if neg_total > 0 else 0
    total_correct = pos_correct + neg_correct
    total_samples = pos_total + neg_total
    overall_acc = total_correct / total_samples if total_samples > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("Accuracy Results:")
    print("=" * 60)
    print(f"Positive samples: {pos_correct}/{pos_total} correct ({pos_acc:.2%})")
    print(f"Negative samples: {neg_correct}/{neg_total} correct ({neg_acc:.2%})")
    print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
    if pos_failed or neg_failed:
        # Failed scans are excluded from the accuracies above; surface them so
        # a partially unreadable dataset does not silently skew the result.
        print(f"Skipped due to errors: {pos_failed} positive, {neg_failed} negative")
    print("=" * 60)

    for label, examples in (("positive", pos_errors), ("negative", neg_errors)):
        if examples:
            print(f"\nSample {label} misclassifications:")
            for scan_id, neg_prob, pos_prob in examples:
                print(f" Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}")
|
|
|
|
if __name__ == "__main__":
    # Run the accuracy check only when executed as a script, not on import.
    main()
|
|
|