themblem/emblem5/ai/verify2.py
Fam Zheng a1a982fb06 Revert "Increase image resolution to 512x512 in training and verification"
This reverts commit 3715bafcb4988d54e5cf0d601ae0eb896c7fe205.
2026-02-28 09:14:53 +00:00

349 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
'''
Load the best model and run inference on all scan IDs.
For each scan, predict on all 3x3 grid images and use voting to determine the final scan label.
Calculate accuracy against the original scan labels.
'''
import os
import json
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
import random
from collections import defaultdict
from common import *
class GridInferenceDataset(Dataset):
def __init__(self, scan_data, transform=None):
self.scan_data = scan_data
self.transform = transform
self.sample_metadata = []
print(f"Loading grid files for {len(scan_data)} scans...")
# Collect all grid files for each scan
for scan_id, metadata in tqdm(scan_data.items(), desc="Loading grid metadata"):
grid_dir = os.path.join('data/scans', scan_id, 'grids')
if not os.path.exists(grid_dir):
continue
# Check for all 9 grid files
grid_files = []
for i in range(3):
for j in range(3):
grid_filename = f'grid-{i}-{j}.jpg'
grid_path = os.path.join(grid_dir, grid_filename)
if os.path.exists(grid_path):
grid_files.append({
'scan_id': scan_id,
'grid_path': grid_path,
'grid_i': i,
'grid_j': j,
'label': 1 if 'pos' in metadata['labels'] else 0
})
# Only include scans that have all 9 grid files
if len(grid_files) == 9:
self.sample_metadata.extend(grid_files)
else:
print(f"Warning: Scan {scan_id} has {len(grid_files)}/9 grid files, skipping")
print(f"Loaded {len(self.sample_metadata)} grid files from {len(self.sample_metadata)//9} complete scans")
def __len__(self):
return len(self.sample_metadata)
def __getitem__(self, idx):
# Get sample metadata
metadata = self.sample_metadata[idx]
# Load the grid image
grid_img = Image.open(metadata['grid_path']).convert('RGB')
# Apply transforms
if self.transform:
grid_img = self.transform(grid_img)
return grid_img, torch.tensor(metadata['label'], dtype=torch.long), metadata['scan_id']
def load_scan_data(data_dir):
"""Load all scan metadata and organize by scan_id"""
scan_data = {}
scans_dir = os.path.join(data_dir, 'scans')
if not os.path.exists(scans_dir):
raise FileNotFoundError(f"Scans directory not found: {scans_dir}")
scan_ids = [d for d in os.listdir(scans_dir)
if os.path.isdir(os.path.join(scans_dir, d))]
print(f"Loading metadata for {len(scan_ids)} scans...")
for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r') as f:
metadata = json.load(f)
scan_data[scan_id] = metadata
except Exception as e:
print(f"Error loading metadata for {scan_id}: {e}")
return scan_data
def run_inference(model, data_loader, device):
"""Run inference on all data and collect predictions by scan_id"""
model.eval()
# Dictionary to store predictions for each scan
scan_predictions = defaultdict(list)
scan_labels = {}
# Track running accuracy
total_predictions = 0
correct_predictions = 0
pos_correct = 0
pos_total = 0
neg_correct = 0
neg_total = 0
print("Running inference...")
with torch.no_grad():
pbar = tqdm(data_loader, desc="Inference")
for batch_idx, (data, target, scan_ids) in enumerate(pbar):
data = data.to(device)
# Get predictions
output = model(data)
probabilities = torch.softmax(output, dim=1)
predictions = torch.argmax(output, dim=1)
# Move target to same device as predictions for comparison
target_device = target.to(device)
# Calculate batch accuracy
batch_correct = (predictions == target_device).sum().item()
correct_predictions += batch_correct
total_predictions += len(target)
# Track per-class accuracy
for i in range(len(target)):
if target[i] == 1: # Positive class
pos_total += 1
if predictions[i] == target_device[i]:
pos_correct += 1
else: # Negative class
neg_total += 1
if predictions[i] == target_device[i]:
neg_correct += 1
# Store predictions by scan_id
for i in range(len(scan_ids)):
scan_id = scan_ids[i]
pred = predictions[i].item()
prob = probabilities[i][1].item() # Probability of positive class
scan_predictions[scan_id].append({
'prediction': pred,
'probability': prob
})
# Store the true label (should be the same for all grids in a scan)
if scan_id not in scan_labels:
scan_labels[scan_id] = target[i].item()
# Calculate running accuracies
overall_acc = 100. * correct_predictions / total_predictions if total_predictions > 0 else 0
pos_acc = 100. * pos_correct / pos_total if pos_total > 0 else 0
neg_acc = 100. * neg_correct / neg_total if neg_total > 0 else 0
# Update progress bar with accuracy info
pbar.set_postfix({
'Overall': f'{overall_acc:.2f}%',
'Pos': f'{pos_acc:.2f}%',
'Neg': f'{neg_acc:.2f}%',
'Pos/Total': f'{pos_correct}/{pos_total}',
'Neg/Total': f'{neg_correct}/{neg_total}'
})
# Clear memory
del data, output, probabilities, predictions
return scan_predictions, scan_labels
def calculate_voting_accuracy(scan_predictions, scan_labels):
"""Calculate accuracy using voting mechanism"""
correct_predictions = 0
total_scans = 0
pos_correct = 0
pos_total = 0
neg_correct = 0
neg_total = 0
# Create log file
log_file = 'data/verify2.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
with open(log_file, 'w') as f:
f.write("Voting Results:\n")
f.write("-" * 80 + "\n")
f.write(f"{'Scan ID':<15} {'True Label':<12} {'Vote Result':<12} {'Pos Votes':<10} {'Neg Votes':<10} {'Correct':<8}\n")
f.write("-" * 80 + "\n")
print("\nVoting Results:")
print("-" * 80)
print(f"{'Scan ID':<15} {'True Label':<12} {'Vote Result':<12} {'Pos Votes':<10} {'Neg Votes':<10} {'Correct':<8}")
print("-" * 80)
for scan_id, predictions in scan_predictions.items():
if scan_id not in scan_labels:
continue
true_label = scan_labels[scan_id]
total_scans += 1
# Count positive and negative votes
pos_votes = sum(1 for p in predictions if p['prediction'] == 1)
neg_votes = sum(1 for p in predictions if p['prediction'] == 0)
# Determine final prediction by majority vote
final_prediction = 1 if pos_votes > neg_votes else 0
# Check if prediction is correct
is_correct = final_prediction == true_label
if is_correct:
correct_predictions += 1
# Track per-class accuracy
if true_label == 1: # Positive class
pos_total += 1
if is_correct:
pos_correct += 1
else: # Negative class
neg_total += 1
if is_correct:
neg_correct += 1
# Print result
status = "" if is_correct else ""
result_line = f"{scan_id:<15} {true_label:<12} {final_prediction:<12} {pos_votes:<10} {neg_votes:<10} {status:<8}"
print(result_line)
# Write to log file
with open(log_file, 'a') as f:
f.write(result_line + "\n")
# Calculate accuracies
overall_accuracy = 100. * correct_predictions / total_scans if total_scans > 0 else 0
pos_accuracy = 100. * pos_correct / pos_total if pos_total > 0 else 0
neg_accuracy = 100. * neg_correct / neg_total if neg_total > 0 else 0
# Write summary to log file
with open(log_file, 'a') as f:
f.write("-" * 80 + "\n")
f.write(f"Overall Accuracy: {overall_accuracy:.2f}% ({correct_predictions}/{total_scans})\n")
f.write(f"Positive Accuracy: {pos_accuracy:.2f}% ({pos_correct}/{pos_total})\n")
f.write(f"Negative Accuracy: {neg_accuracy:.2f}% ({neg_correct}/{neg_total})\n")
print("-" * 80)
print(f"Overall Accuracy: {overall_accuracy:.2f}% ({correct_predictions}/{total_scans})")
print(f"Positive Accuracy: {pos_accuracy:.2f}% ({pos_correct}/{pos_total})")
print(f"Negative Accuracy: {neg_accuracy:.2f}% ({neg_correct}/{neg_total})")
return overall_accuracy, pos_accuracy, neg_accuracy
def main():
parser = argparse.ArgumentParser(description='Verify model accuracy using voting mechanism')
parser.add_argument('--data-dir', default='data', help='Data directory')
parser.add_argument('--model', default='models/final_model_ep10_pos98.54_neg78.52_20250706_143910.pt', help='Path to the model file')
parser.add_argument('--batch-size', type=int, default=128, help='Batch size for inference')
parser.add_argument('--sample', type=float, default=1.0, help='Fraction of scans to sample (0.0-1.0, default: 1.0 for all scans)')
args = parser.parse_args()
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
# Load scan data
print("Loading scan data...")
scan_data = load_scan_data(args.data_dir)
if not scan_data:
print("No scan data found!")
return
# Sample scans if requested
if args.sample < 1.0:
scan_ids = list(scan_data.keys())
num_to_sample = max(1, int(len(scan_ids) * args.sample))
sampled_scan_ids = random.sample(scan_ids, num_to_sample)
scan_data = {scan_id: scan_data[scan_id] for scan_id in sampled_scan_ids}
print(f"Sampled {len(sampled_scan_ids)} scans out of {len(scan_ids)} total scans ({args.sample*100:.1f}%)")
else:
print(f"Using all {len(scan_data)} scans")
# Define transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Create dataset
print("Creating inference dataset...")
dataset = GridInferenceDataset(scan_data, transform=transform)
if len(dataset) == 0:
print("No valid grid files found!")
return
# Create data loader
print("Creating data loader...")
data_loader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=8,
pin_memory=False,
persistent_workers=True,
prefetch_factor=2
)
print(f"Total batches: {len(data_loader)}")
# Load model
print(f"Loading model from {args.model}...")
try:
model, _ = load_model(args.model)
model = model.to(device)
print("Model loaded successfully")
except Exception as e:
print(f"Error loading model: {e}")
return
# Run inference
scan_predictions, scan_labels = run_inference(model, data_loader, device)
# Calculate voting accuracy
overall_acc, pos_acc, neg_acc = calculate_voting_accuracy(scan_predictions, scan_labels)
print(f"\nFinal Results:")
print(f"Overall Accuracy: {overall_acc:.2f}%")
print(f"Positive Accuracy: {pos_acc:.2f}%")
print(f"Negative Accuracy: {neg_acc:.2f}%")
# Write final results to log file
with open('data/verify2.log', 'a') as f:
f.write(f"\nFinal Results:\n")
f.write(f"Overall Accuracy: {overall_acc:.2f}%\n")
f.write(f"Positive Accuracy: {pos_acc:.2f}%\n")
f.write(f"Negative Accuracy: {neg_acc:.2f}%\n")
print(f"\nResults have been written to data/verify2.log")
if __name__ == '__main__':
main()