refactor: update check_accuracy.py
This commit is contained in:
parent
5e1fe5cb5a
commit
bb230733c5
@ -1,37 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Temporary script to check model accuracy on specific scan ranges
|
||||
Positive samples: 357193-358023
|
||||
Negative samples: 358024-358808
|
||||
Script to check model accuracy on scans
|
||||
By default, randomly selects 1000 images and compares with pos/neg labels from metadata
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import random
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Add current directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from common import load_model, predict
|
||||
from common import load_model, predict, parse_ranges, in_range
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print("Usage: python3 check_accuracy.py <model_path> <data_dir>")
|
||||
print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem")
|
||||
sys.exit(1)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model_path', help='Path to model file')
|
||||
parser.add_argument('data_dir', help='Path to data directory')
|
||||
parser.add_argument('--scan-ids', type=str, default=None, help='Filter scan IDs by range (e.g., "357193-358808" or "357193-358808,359000-359010")')
|
||||
parser.add_argument('--sample-size', type=int, default=1000, help='Number of scans to randomly sample (default: 1000)')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = sys.argv[1]
|
||||
data_dir = sys.argv[2]
|
||||
model_path = args.model_path
|
||||
data_dir = args.data_dir
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Error: Model file not found: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
# Define scan ranges
|
||||
pos_start, pos_end = 357193, 358023
|
||||
neg_start, neg_end = 358024, 358808
|
||||
|
||||
print(f"Loading model from {model_path}...")
|
||||
model, transforms = load_model(model_path)
|
||||
# Ensure model is not compiled
|
||||
@ -44,57 +43,80 @@ def main():
|
||||
print(f"Error: Scans directory not found: {scans_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
# Collect scan IDs in ranges
|
||||
pos_scans = []
|
||||
neg_scans = []
|
||||
# Collect scan IDs with metadata
|
||||
scan_candidates = []
|
||||
|
||||
print("Collecting scan IDs...")
|
||||
for scan_id in range(pos_start, pos_end + 1):
|
||||
scan_dir = os.path.join(scans_dir, str(scan_id))
|
||||
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
|
||||
if os.path.exists(sbs_file):
|
||||
pos_scans.append(str(scan_id))
|
||||
all_scan_ids = [d for d in os.listdir(scans_dir) if os.path.isdir(os.path.join(scans_dir, d))]
|
||||
|
||||
for scan_id in range(neg_start, neg_end + 1):
|
||||
scan_dir = os.path.join(scans_dir, str(scan_id))
|
||||
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
|
||||
if os.path.exists(sbs_file):
|
||||
neg_scans.append(str(scan_id))
|
||||
# If SCAN_IDS is provided, filter by those ranges
|
||||
if args.scan_ids:
|
||||
filter_ranges = parse_ranges(args.scan_ids)
|
||||
filtered_scan_ids = []
|
||||
for scan_id_str in all_scan_ids:
|
||||
try:
|
||||
scan_id_int = int(scan_id_str)
|
||||
for val_range in filter_ranges:
|
||||
if in_range(scan_id_int, val_range):
|
||||
filtered_scan_ids.append(scan_id_str)
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
all_scan_ids = filtered_scan_ids
|
||||
print(f"Filtered to {len(all_scan_ids)} scans matching range(s): {args.scan_ids}")
|
||||
|
||||
print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})")
|
||||
print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n")
|
||||
# Load metadata for all candidate scans
|
||||
for scan_id in all_scan_ids:
|
||||
scan_dir = os.path.join(scans_dir, scan_id)
|
||||
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
|
||||
metadata_file = os.path.join(scan_dir, 'metadata.json')
|
||||
|
||||
if not os.path.exists(sbs_file):
|
||||
continue
|
||||
|
||||
if not os.path.exists(metadata_file):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(metadata_file, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
labels = metadata.get('labels', [])
|
||||
if 'pos' not in labels and 'neg' not in labels:
|
||||
continue
|
||||
|
||||
scan_candidates.append({
|
||||
'scan_id': scan_id,
|
||||
'label': 1 if 'pos' in labels else 0,
|
||||
'sbs_file': sbs_file
|
||||
})
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
print(f"Found {len(scan_candidates)} scans with valid metadata and sbs.jpg")
|
||||
|
||||
# Randomly sample if needed
|
||||
if len(scan_candidates) > args.sample_size:
|
||||
scan_candidates = random.sample(scan_candidates, args.sample_size)
|
||||
print(f"Randomly sampled {args.sample_size} scans")
|
||||
|
||||
print(f"Evaluating {len(scan_candidates)} scans...\n")
|
||||
|
||||
# Run predictions
|
||||
pos_correct = 0
|
||||
pos_total = 0
|
||||
neg_correct = 0
|
||||
neg_total = 0
|
||||
errors = []
|
||||
|
||||
print("Evaluating positive samples...")
|
||||
for scan_id in tqdm(pos_scans, desc="Positive"):
|
||||
scan_dir = os.path.join(scans_dir, scan_id)
|
||||
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
|
||||
for scan_info in tqdm(scan_candidates, desc="Evaluating"):
|
||||
scan_id = scan_info['scan_id']
|
||||
true_label = scan_info['label']
|
||||
sbs_file = scan_info['sbs_file']
|
||||
|
||||
try:
|
||||
image = Image.open(sbs_file).convert('RGB')
|
||||
predicted_class, probabilities = predict(model, transforms, image)
|
||||
|
||||
pos_total += 1
|
||||
if predicted_class == 1:
|
||||
pos_correct += 1
|
||||
except Exception as e:
|
||||
print(f"\nError processing {scan_id}: {e}")
|
||||
continue
|
||||
|
||||
print("\nEvaluating negative samples...")
|
||||
neg_errors = []
|
||||
for scan_id in tqdm(neg_scans, desc="Negative"):
|
||||
scan_dir = os.path.join(scans_dir, scan_id)
|
||||
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
|
||||
|
||||
try:
|
||||
image = Image.open(sbs_file).convert('RGB')
|
||||
predicted_class, probabilities = predict(model, transforms, image)
|
||||
predicted_class, probabilities = predict(model, transforms, image, ncells=3)
|
||||
|
||||
neg_sum = sum([p[0] for p in probabilities])
|
||||
pos_sum = sum([p[1] for p in probabilities])
|
||||
@ -102,13 +124,22 @@ def main():
|
||||
neg_prob = neg_sum / total if total > 0 else 0
|
||||
pos_prob = pos_sum / total if total > 0 else 0
|
||||
|
||||
if true_label == 1:
|
||||
pos_total += 1
|
||||
if predicted_class == 1:
|
||||
pos_correct += 1
|
||||
else:
|
||||
# Store first few errors for debugging
|
||||
if len(errors) < 5:
|
||||
errors.append((scan_id, 'pos', neg_prob, pos_prob))
|
||||
else:
|
||||
neg_total += 1
|
||||
if predicted_class == 0:
|
||||
neg_correct += 1
|
||||
else:
|
||||
# Store first few errors for debugging
|
||||
if len(neg_errors) < 5:
|
||||
neg_errors.append((scan_id, neg_prob, pos_prob))
|
||||
if len(errors) < 5:
|
||||
errors.append((scan_id, 'neg', neg_prob, pos_prob))
|
||||
except Exception as e:
|
||||
print(f"\nError processing {scan_id}: {e}")
|
||||
continue
|
||||
@ -129,10 +160,10 @@ def main():
|
||||
print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
|
||||
print("="*60)
|
||||
|
||||
if neg_errors:
|
||||
print("\nSample negative misclassifications:")
|
||||
for scan_id, neg_prob, pos_prob in neg_errors:
|
||||
print(f" Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}")
|
||||
if errors:
|
||||
print("\nSample misclassifications:")
|
||||
for scan_id, true_label, neg_prob, pos_prob in errors:
|
||||
print(f" Scan {scan_id} (true: {true_label}): neg={neg_prob:.2%}, pos={pos_prob:.2%}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user