349 lines
13 KiB
Python
Executable File
349 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
'''
|
|
Load the best model and run inference on all scan IDs.
|
|
For each scan, predict on all 3x3 grid images and use voting to determine the final scan label.
|
|
Calculate accuracy against the original scan labels.
|
|
'''
|
|
|
|
import os
|
|
import json
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from PIL import Image
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import argparse
|
|
import random
|
|
from collections import defaultdict
|
|
from common import *
|
|
|
|
class GridInferenceDataset(Dataset):
    """Dataset of individual 3x3 grid images for scan-level voting inference.

    Each complete scan contributes exactly 9 samples (``grid-{i}-{j}.jpg``
    for i, j in 0..2). Scans missing any of the 9 grid files are skipped
    entirely so that majority voting downstream always sees all 9 votes.
    """

    def __init__(self, scan_data, transform=None, data_dir='data'):
        """
        Args:
            scan_data: Mapping of scan_id -> metadata dict. A scan is
                labeled positive (1) when 'pos' appears in
                metadata['labels'], negative (0) otherwise.
            transform: Optional torchvision transform applied per image.
            data_dir: Root directory containing the 'scans' subfolder.
                Defaults to 'data' (previously hard-coded), so existing
                callers are unaffected.
        """
        self.scan_data = scan_data
        self.transform = transform
        self.sample_metadata = []

        print(f"Loading grid files for {len(scan_data)} scans...")

        # Collect all grid files for each scan
        for scan_id, metadata in tqdm(scan_data.items(), desc="Loading grid metadata"):
            grid_dir = os.path.join(data_dir, 'scans', scan_id, 'grids')
            if not os.path.exists(grid_dir):
                continue

            # The scan-level label is identical for all 9 grids; compute once
            # instead of once per grid file.
            label = 1 if 'pos' in metadata['labels'] else 0

            # Check for all 9 grid files
            grid_files = []
            for i in range(3):
                for j in range(3):
                    grid_path = os.path.join(grid_dir, f'grid-{i}-{j}.jpg')
                    if os.path.exists(grid_path):
                        grid_files.append({
                            'scan_id': scan_id,
                            'grid_path': grid_path,
                            'grid_i': i,
                            'grid_j': j,
                            'label': label
                        })

            # Only include scans that have all 9 grid files, so every scan
            # gets a full complement of votes.
            if len(grid_files) == 9:
                self.sample_metadata.extend(grid_files)
            else:
                print(f"Warning: Scan {scan_id} has {len(grid_files)}/9 grid files, skipping")

        print(f"Loaded {len(self.sample_metadata)} grid files from {len(self.sample_metadata)//9} complete scans")

    def __len__(self):
        """Number of individual grid images (9 per complete scan)."""
        return len(self.sample_metadata)

    def __getitem__(self, idx):
        """Return (image, label_tensor, scan_id) for sample ``idx``.

        The scan_id string is carried along so that predictions can be
        regrouped per scan for voting after inference.
        """
        metadata = self.sample_metadata[idx]

        # Load the grid image; convert to RGB in case of grayscale JPEGs.
        grid_img = Image.open(metadata['grid_path']).convert('RGB')

        # Apply transforms
        if self.transform:
            grid_img = self.transform(grid_img)

        return grid_img, torch.tensor(metadata['label'], dtype=torch.long), metadata['scan_id']
|
|
|
|
def load_scan_data(data_dir):
    """Load all scan metadata and organize by scan_id"""
    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        raise FileNotFoundError(f"Scans directory not found: {scans_dir}")

    # Every subdirectory of scans/ is treated as one scan.
    candidates = [entry for entry in os.listdir(scans_dir)
                  if os.path.isdir(os.path.join(scans_dir, entry))]

    print(f"Loading metadata for {len(candidates)} scans...")
    scan_data = {}
    for scan_id in tqdm(candidates, desc="Loading scan metadata"):
        metadata_path = os.path.join(scans_dir, scan_id, 'metadata.json')
        if not os.path.exists(metadata_path):
            # Scans without metadata.json are silently skipped.
            continue
        try:
            with open(metadata_path, 'r') as fh:
                scan_data[scan_id] = json.load(fh)
        except Exception as e:
            # Best-effort: report the bad file and keep loading the rest.
            print(f"Error loading metadata for {scan_id}: {e}")

    return scan_data
|
|
|
|
def run_inference(model, data_loader, device):
    """Run inference on all data and collect predictions by scan_id.

    Args:
        model: Trained classifier producing 2-class logits.
        data_loader: Yields (images, labels, scan_ids) batches from
            GridInferenceDataset; scan_ids is a sequence of strings
            aligned with the batch dimension.
        device: torch.device the forward passes run on.

    Returns:
        Tuple (scan_predictions, scan_labels):
            scan_predictions: scan_id -> list of
                {'prediction': 0/1, 'probability': P(positive)} dicts,
                one entry per grid image of that scan.
            scan_labels: scan_id -> true label, taken from the first grid
                seen for that scan (all grids of a scan share one label).
    """
    model.eval()

    # Dictionary to store predictions for each scan
    scan_predictions = defaultdict(list)
    scan_labels = {}

    # Track running accuracy (grid-level, i.e. before any voting)
    total_predictions = 0
    correct_predictions = 0
    pos_correct = 0
    pos_total = 0
    neg_correct = 0
    neg_total = 0

    print("Running inference...")
    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Inference")
        for batch_idx, (data, target, scan_ids) in enumerate(pbar):
            data = data.to(device)

            # Get predictions
            output = model(data)
            probabilities = torch.softmax(output, dim=1)
            predictions = torch.argmax(output, dim=1)

            # Move target to same device as predictions for comparison
            target_device = target.to(device)

            # Calculate batch accuracy
            batch_correct = (predictions == target_device).sum().item()
            correct_predictions += batch_correct
            total_predictions += len(target)

            # Track per-class accuracy
            for i in range(len(target)):
                if target[i] == 1: # Positive class
                    pos_total += 1
                    if predictions[i] == target_device[i]:
                        pos_correct += 1
                else: # Negative class
                    neg_total += 1
                    if predictions[i] == target_device[i]:
                        neg_correct += 1

            # Store predictions by scan_id
            for i in range(len(scan_ids)):
                scan_id = scan_ids[i]
                pred = predictions[i].item()
                prob = probabilities[i][1].item() # Probability of positive class

                scan_predictions[scan_id].append({
                    'prediction': pred,
                    'probability': prob
                })

                # Store the true label (should be the same for all grids in a scan)
                if scan_id not in scan_labels:
                    scan_labels[scan_id] = target[i].item()

            # Calculate running accuracies (guarded against division by zero
            # on the first batches when a class has not appeared yet)
            overall_acc = 100. * correct_predictions / total_predictions if total_predictions > 0 else 0
            pos_acc = 100. * pos_correct / pos_total if pos_total > 0 else 0
            neg_acc = 100. * neg_correct / neg_total if neg_total > 0 else 0

            # Update progress bar with accuracy info
            pbar.set_postfix({
                'Overall': f'{overall_acc:.2f}%',
                'Pos': f'{pos_acc:.2f}%',
                'Neg': f'{neg_acc:.2f}%',
                'Pos/Total': f'{pos_correct}/{pos_total}',
                'Neg/Total': f'{neg_correct}/{neg_total}'
            })

            # Clear memory: drop batch tensors promptly so GPU memory is
            # released between batches
            del data, output, probabilities, predictions

    return scan_predictions, scan_labels
|
|
|
|
def calculate_voting_accuracy(scan_predictions, scan_labels):
    """Calculate accuracy using voting mechanism.

    Each scan's final prediction is the majority vote over its grid-level
    predictions (with 9 grids a tie is impossible; a hypothetical tie
    counts as negative). Results are printed to stdout and written to
    data/verify2.log.

    Args:
        scan_predictions: scan_id -> list of {'prediction': 0/1, ...} dicts.
        scan_labels: scan_id -> true label (0/1). Scans missing here are
            skipped.

    Returns:
        Tuple (overall_accuracy, pos_accuracy, neg_accuracy) in percent;
        a class with no scans reports 0.
    """
    correct_predictions = 0
    total_scans = 0
    pos_correct = 0
    pos_total = 0
    neg_correct = 0
    neg_total = 0

    # Create log file. Hold ONE handle open for the whole report — the
    # previous version reopened the file in append mode once per scan plus
    # once for the summary, which is wasteful and error-prone.
    log_file = 'data/verify2.log'
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    # Shared column layout for both stdout and the log file.
    header = f"{'Scan ID':<15} {'True Label':<12} {'Vote Result':<12} {'Pos Votes':<10} {'Neg Votes':<10} {'Correct':<8}"

    with open(log_file, 'w') as f:
        f.write("Voting Results:\n")
        f.write("-" * 80 + "\n")
        f.write(header + "\n")
        f.write("-" * 80 + "\n")

        print("\nVoting Results:")
        print("-" * 80)
        print(header)
        print("-" * 80)

        for scan_id, predictions in scan_predictions.items():
            # Skip scans whose true label was never recorded.
            if scan_id not in scan_labels:
                continue

            true_label = scan_labels[scan_id]
            total_scans += 1

            # Count positive and negative votes
            pos_votes = sum(1 for p in predictions if p['prediction'] == 1)
            neg_votes = sum(1 for p in predictions if p['prediction'] == 0)

            # Determine final prediction by majority vote
            final_prediction = 1 if pos_votes > neg_votes else 0

            # Check if prediction is correct
            is_correct = final_prediction == true_label
            if is_correct:
                correct_predictions += 1

            # Track per-class accuracy
            if true_label == 1:  # Positive class
                pos_total += 1
                if is_correct:
                    pos_correct += 1
            else:  # Negative class
                neg_total += 1
                if is_correct:
                    neg_correct += 1

            # Emit the per-scan row to both stdout and the log file.
            status = "✓" if is_correct else "✗"
            result_line = f"{scan_id:<15} {true_label:<12} {final_prediction:<12} {pos_votes:<10} {neg_votes:<10} {status:<8}"
            print(result_line)
            f.write(result_line + "\n")

        # Calculate accuracies (guard against empty classes / zero scans)
        overall_accuracy = 100. * correct_predictions / total_scans if total_scans > 0 else 0
        pos_accuracy = 100. * pos_correct / pos_total if pos_total > 0 else 0
        neg_accuracy = 100. * neg_correct / neg_total if neg_total > 0 else 0

        # Write summary to log file
        f.write("-" * 80 + "\n")
        f.write(f"Overall Accuracy: {overall_accuracy:.2f}% ({correct_predictions}/{total_scans})\n")
        f.write(f"Positive Accuracy: {pos_accuracy:.2f}% ({pos_correct}/{pos_total})\n")
        f.write(f"Negative Accuracy: {neg_accuracy:.2f}% ({neg_correct}/{neg_total})\n")

    print("-" * 80)
    print(f"Overall Accuracy: {overall_accuracy:.2f}% ({correct_predictions}/{total_scans})")
    print(f"Positive Accuracy: {pos_accuracy:.2f}% ({pos_correct}/{pos_total})")
    print(f"Negative Accuracy: {neg_accuracy:.2f}% ({neg_correct}/{neg_total})")

    return overall_accuracy, pos_accuracy, neg_accuracy
|
|
|
|
def main():
    """Entry point: parse CLI args, load scan data, run grid-level
    inference with the saved model, and report scan-level voting accuracy.
    """
    parser = argparse.ArgumentParser(description='Verify model accuracy using voting mechanism')
    parser.add_argument('--data-dir', default='data', help='Data directory')
    parser.add_argument('--model', default='models/final_model_ep10_pos98.54_neg78.52_20250706_143910.pt', help='Path to the model file')
    parser.add_argument('--batch-size', type=int, default=128, help='Batch size for inference')
    parser.add_argument('--sample', type=float, default=1.0, help='Fraction of scans to sample (0.0-1.0, default: 1.0 for all scans)')
    args = parser.parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load scan data
    print("Loading scan data...")
    scan_data = load_scan_data(args.data_dir)

    if not scan_data:
        print("No scan data found!")
        return

    # Sample scans if requested (random subset for quick verification runs;
    # note: random seed is not fixed, so samples differ between runs)
    if args.sample < 1.0:
        scan_ids = list(scan_data.keys())
        num_to_sample = max(1, int(len(scan_ids) * args.sample))
        sampled_scan_ids = random.sample(scan_ids, num_to_sample)
        scan_data = {scan_id: scan_data[scan_id] for scan_id in sampled_scan_ids}
        print(f"Sampled {len(sampled_scan_ids)} scans out of {len(scan_ids)} total scans ({args.sample*100:.1f}%)")
    else:
        print(f"Using all {len(scan_data)} scans")

    # Define transforms (ImageNet normalization stats — presumably matching
    # what the model was trained with; confirm against the training script)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create dataset
    print("Creating inference dataset...")
    dataset = GridInferenceDataset(scan_data, transform=transform)

    if len(dataset) == 0:
        print("No valid grid files found!")
        return

    # Create data loader. Shuffling does not affect the final voting result
    # because predictions are regrouped by scan_id afterwards.
    print("Creating data loader...")
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=False,
        persistent_workers=True,
        prefetch_factor=2
    )
    print(f"Total batches: {len(data_loader)}")

    # Load model
    # NOTE(review): load_model comes from `common` via star import and is
    # unpacked as (model, extra) — confirm its contract in common.py.
    print(f"Loading model from {args.model}...")
    try:
        model, _ = load_model(args.model)
        model = model.to(device)
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Run inference
    scan_predictions, scan_labels = run_inference(model, data_loader, device)

    # Calculate voting accuracy
    overall_acc, pos_acc, neg_acc = calculate_voting_accuracy(scan_predictions, scan_labels)

    print(f"\nFinal Results:")
    print(f"Overall Accuracy: {overall_acc:.2f}%")
    print(f"Positive Accuracy: {pos_acc:.2f}%")
    print(f"Negative Accuracy: {neg_acc:.2f}%")

    # Write final results to log file (appends to the log that
    # calculate_voting_accuracy already wrote)
    with open('data/verify2.log', 'a') as f:
        f.write(f"\nFinal Results:\n")
        f.write(f"Overall Accuracy: {overall_acc:.2f}%\n")
        f.write(f"Positive Accuracy: {pos_acc:.2f}%\n")
        f.write(f"Negative Accuracy: {neg_acc:.2f}%\n")

    print(f"\nResults have been written to data/verify2.log")
|
|
|
|
# Script entry point: run the full verification pipeline when executed directly.
if __name__ == '__main__':
    main()