#!/usr/bin/env python3
'''
Load the best model and run inference on all scan IDs.
For each scan, predict on all 3x3 grid images and use voting to determine
the final scan label. Calculate accuracy against the original scan labels.
'''
import os
import json
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
import random
from collections import defaultdict

from common import *


class GridInferenceDataset(Dataset):
    """Dataset over every 3x3 grid tile of each complete scan.

    Each item is ``(image_tensor, label, scan_id)``. Scans missing any of
    the nine ``grid-i-j.jpg`` tiles are skipped entirely, so per-scan
    voting downstream always sees exactly 9 predictions per scan.

    Args:
        scan_data: Mapping of scan_id -> metadata dict; ``metadata['labels']``
            is expected to contain ``'pos'`` for positive scans.
        transform: Optional torchvision transform applied to each image.
        data_dir: Root data directory containing the ``scans/`` folder.
            Defaults to ``'data'`` for backward compatibility.
    """

    def __init__(self, scan_data, transform=None, data_dir='data'):
        self.scan_data = scan_data
        self.transform = transform
        self.sample_metadata = []

        print(f"Loading grid files for {len(scan_data)} scans...")

        # Collect all grid files for each scan
        for scan_id, metadata in tqdm(scan_data.items(), desc="Loading grid metadata"):
            grid_dir = os.path.join(data_dir, 'scans', scan_id, 'grids')
            if not os.path.exists(grid_dir):
                continue

            # Check for all 9 grid files of the 3x3 tiling
            grid_files = []
            for i in range(3):
                for j in range(3):
                    grid_filename = f'grid-{i}-{j}.jpg'
                    grid_path = os.path.join(grid_dir, grid_filename)
                    if os.path.exists(grid_path):
                        grid_files.append({
                            'scan_id': scan_id,
                            'grid_path': grid_path,
                            'grid_i': i,
                            'grid_j': j,
                            # Binary scan-level label shared by all 9 tiles
                            'label': 1 if 'pos' in metadata['labels'] else 0
                        })

            # Only include scans that have all 9 grid files, so that the
            # majority vote is always over exactly 9 ballots (no ties).
            if len(grid_files) == 9:
                self.sample_metadata.extend(grid_files)
            else:
                print(f"Warning: Scan {scan_id} has {len(grid_files)}/9 grid files, skipping")

        print(f"Loaded {len(self.sample_metadata)} grid files from {len(self.sample_metadata)//9} complete scans")

    def __len__(self):
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        """Return (image, label tensor, scan_id) for one grid tile."""
        metadata = self.sample_metadata[idx]

        # Load the grid image
        grid_img = Image.open(metadata['grid_path']).convert('RGB')

        # Apply transforms
        if self.transform:
            grid_img = self.transform(grid_img)

        return grid_img, torch.tensor(metadata['label'], dtype=torch.long), metadata['scan_id']


def load_scan_data(data_dir):
    """Load all scan metadata and organize by scan_id.

    Args:
        data_dir: Root data directory containing a ``scans/`` subfolder,
            with one ``metadata.json`` per scan directory.

    Returns:
        dict mapping scan_id -> parsed metadata. Scans whose metadata
        fails to load are reported and skipped.

    Raises:
        FileNotFoundError: if ``data_dir/scans`` does not exist.
    """
    scan_data = {}
    scans_dir = os.path.join(data_dir, 'scans')

    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")

    scan_ids = [d for d in os.listdir(scans_dir)
                if os.path.isdir(os.path.join(scans_dir, d))]

    print(f"Loading metadata for {len(scan_ids)} scans...")

    for scan_id in tqdm(scan_ids, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r') as f:
                    metadata = json.load(f)
                scan_data[scan_id] = metadata
            except Exception as e:
                # Best-effort: a single corrupt metadata.json should not
                # abort the whole verification run.
                print(f"Error loading metadata for {scan_id}: {e}")

    return scan_data


def run_inference(model, data_loader, device):
    """Run inference on all data and collect predictions by scan_id.

    Args:
        model: Trained classifier producing 2-class logits.
        data_loader: DataLoader yielding (images, labels, scan_ids) batches.
        device: torch.device to run inference on.

    Returns:
        Tuple ``(scan_predictions, scan_labels)`` where scan_predictions
        maps scan_id -> list of {'prediction', 'probability'} dicts (one
        per grid tile) and scan_labels maps scan_id -> true label.
    """
    model.eval()

    # Per-scan prediction accumulators
    scan_predictions = defaultdict(list)
    scan_labels = {}

    # Running per-tile accuracy counters (shown on the progress bar)
    total_predictions = 0
    correct_predictions = 0
    pos_correct = 0
    pos_total = 0
    neg_correct = 0
    neg_total = 0

    print("Running inference...")
    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Inference")
        for batch_idx, (data, target, scan_ids) in enumerate(pbar):
            data = data.to(device)

            # Get predictions
            output = model(data)
            probabilities = torch.softmax(output, dim=1)
            predictions = torch.argmax(output, dim=1)

            # Move target to same device as predictions for comparison
            target_device = target.to(device)

            # Calculate batch accuracy
            batch_correct = (predictions == target_device).sum().item()
            correct_predictions += batch_correct
            total_predictions += len(target)

            # Track per-class accuracy
            for i in range(len(target)):
                if target[i] == 1:  # Positive class
                    pos_total += 1
                    if predictions[i] == target_device[i]:
                        pos_correct += 1
                else:  # Negative class
                    neg_total += 1
                    if predictions[i] == target_device[i]:
                        neg_correct += 1

            # Store predictions by scan_id
            for i in range(len(scan_ids)):
                scan_id = scan_ids[i]
                pred = predictions[i].item()
                prob = probabilities[i][1].item()  # Probability of positive class

                scan_predictions[scan_id].append({
                    'prediction': pred,
                    'probability': prob
                })

                # Store the true label (same for all grids in a scan)
                if scan_id not in scan_labels:
                    scan_labels[scan_id] = target[i].item()

            # Calculate running accuracies
            overall_acc = 100. * correct_predictions / total_predictions if total_predictions > 0 else 0
            pos_acc = 100. * pos_correct / pos_total if pos_total > 0 else 0
            neg_acc = 100. * neg_correct / neg_total if neg_total > 0 else 0

            # Update progress bar with accuracy info
            pbar.set_postfix({
                'Overall': f'{overall_acc:.2f}%',
                'Pos': f'{pos_acc:.2f}%',
                'Neg': f'{neg_acc:.2f}%',
                'Pos/Total': f'{pos_correct}/{pos_total}',
                'Neg/Total': f'{neg_correct}/{neg_total}'
            })

            # Clear memory
            del data, output, probabilities, predictions

    return scan_predictions, scan_labels


def calculate_voting_accuracy(scan_predictions, scan_labels):
    """Calculate accuracy using a per-scan majority-vote mechanism.

    For each scan, the 9 per-tile predictions are tallied and the scan is
    labeled positive when positive votes strictly outnumber negative votes
    (ties — impossible with 9 complete tiles — would resolve to negative).
    Results are printed and mirrored to ``data/verify2.log``.

    Args:
        scan_predictions: scan_id -> list of per-tile prediction dicts.
        scan_labels: scan_id -> true label (0 or 1).

    Returns:
        Tuple ``(overall_accuracy, pos_accuracy, neg_accuracy)`` in percent.
    """
    correct_predictions = 0
    total_scans = 0
    pos_correct = 0
    pos_total = 0
    neg_correct = 0
    neg_total = 0

    # Create log file. A single handle is held open for the whole report
    # instead of reopening the file in append mode once per scan.
    log_file = 'data/verify2.log'
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    with open(log_file, 'w') as f:
        f.write("Voting Results:\n")
        f.write("-" * 80 + "\n")
        f.write(f"{'Scan ID':<15} {'True Label':<12} {'Vote Result':<12} {'Pos Votes':<10} {'Neg Votes':<10} {'Correct':<8}\n")
        f.write("-" * 80 + "\n")

        print("\nVoting Results:")
        print("-" * 80)
        print(f"{'Scan ID':<15} {'True Label':<12} {'Vote Result':<12} {'Pos Votes':<10} {'Neg Votes':<10} {'Correct':<8}")
        print("-" * 80)

        for scan_id, predictions in scan_predictions.items():
            if scan_id not in scan_labels:
                continue

            true_label = scan_labels[scan_id]
            total_scans += 1

            # Count positive and negative votes
            pos_votes = sum(1 for p in predictions if p['prediction'] == 1)
            neg_votes = sum(1 for p in predictions if p['prediction'] == 0)

            # Determine final prediction by majority vote
            final_prediction = 1 if pos_votes > neg_votes else 0

            # Check if prediction is correct
            is_correct = final_prediction == true_label
            if is_correct:
                correct_predictions += 1

            # Track per-class accuracy
            if true_label == 1:  # Positive class
                pos_total += 1
                if is_correct:
                    pos_correct += 1
            else:  # Negative class
                neg_total += 1
                if is_correct:
                    neg_correct += 1

            # Print result and mirror it to the log file
            status = "✓" if is_correct else "✗"
            result_line = f"{scan_id:<15} {true_label:<12} {final_prediction:<12} {pos_votes:<10} {neg_votes:<10} {status:<8}"
            print(result_line)
            f.write(result_line + "\n")

        # Calculate accuracies
        overall_accuracy = 100. * correct_predictions / total_scans if total_scans > 0 else 0
        pos_accuracy = 100. * pos_correct / pos_total if pos_total > 0 else 0
        neg_accuracy = 100. * neg_correct / neg_total if neg_total > 0 else 0

        # Write summary to log file
        f.write("-" * 80 + "\n")
        f.write(f"Overall Accuracy: {overall_accuracy:.2f}% ({correct_predictions}/{total_scans})\n")
        f.write(f"Positive Accuracy: {pos_accuracy:.2f}% ({pos_correct}/{pos_total})\n")
        f.write(f"Negative Accuracy: {neg_accuracy:.2f}% ({neg_correct}/{neg_total})\n")

    print("-" * 80)
    print(f"Overall Accuracy: {overall_accuracy:.2f}% ({correct_predictions}/{total_scans})")
    print(f"Positive Accuracy: {pos_accuracy:.2f}% ({pos_correct}/{pos_total})")
    print(f"Negative Accuracy: {neg_accuracy:.2f}% ({neg_correct}/{neg_total})")

    return overall_accuracy, pos_accuracy, neg_accuracy


def main():
    """Entry point: load data and model, run inference, report voting accuracy."""
    parser = argparse.ArgumentParser(description='Verify model accuracy using voting mechanism')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--model',
                        default='models/final_model_ep10_pos98.54_neg78.52_20250706_143910.pt',
                        help='Path to the model file')
    parser.add_argument('--batch-size', type=int, default=128, help='Batch size for inference')
    parser.add_argument('--sample', type=float, default=1.0,
                        help='Fraction of scans to sample (0.0-1.0, default: 1.0 for all scans)')
    args = parser.parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load scan data
    print("Loading scan data...")
    scan_data = load_scan_data(args.data_dir)

    if not scan_data:
        print("No scan data found!")
        return

    # Sample scans if requested
    if args.sample < 1.0:
        scan_ids = list(scan_data.keys())
        num_to_sample = max(1, int(len(scan_ids) * args.sample))
        sampled_scan_ids = random.sample(scan_ids, num_to_sample)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in sampled_scan_ids}
        print(f"Sampled {len(sampled_scan_ids)} scans out of {len(scan_ids)} total scans ({args.sample*100:.1f}%)")
    else:
        print(f"Using all {len(scan_data)} scans")

    # Define transforms (ImageNet normalization)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Create dataset, honoring --data-dir for the grids as well
    print("Creating inference dataset...")
    dataset = GridInferenceDataset(scan_data, transform=transform, data_dir=args.data_dir)

    if len(dataset) == 0:
        print("No valid grid files found!")
        return

    # Create data loader. shuffle=False: this is pure inference, and a
    # deterministic order makes the streamed log reproducible (the voting
    # result is order-independent either way).
    print("Creating data loader...")
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    print(f"Total batches: {len(data_loader)}")

    # Load model (load_model comes from common)
    print(f"Loading model from {args.model}...")
    try:
        model, _ = load_model(args.model)
        model = model.to(device)
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Run inference
    scan_predictions, scan_labels = run_inference(model, data_loader, device)

    # Calculate voting accuracy
    overall_acc, pos_acc, neg_acc = calculate_voting_accuracy(scan_predictions, scan_labels)

    print(f"\nFinal Results:")
    print(f"Overall Accuracy: {overall_acc:.2f}%")
    print(f"Positive Accuracy: {pos_acc:.2f}%")
    print(f"Negative Accuracy: {neg_acc:.2f}%")

    # Write final results to log file
    with open('data/verify2.log', 'a') as f:
        f.write(f"\nFinal Results:\n")
        f.write(f"Overall Accuracy: {overall_acc:.2f}%\n")
        f.write(f"Positive Accuracy: {pos_acc:.2f}%\n")
        f.write(f"Negative Accuracy: {neg_acc:.2f}%\n")

    print(f"\nResults have been written to data/verify2.log")


if __name__ == '__main__':
    main()