799 lines
28 KiB
Python
799 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
import sys
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.amp import autocast, GradScaler
|
|
from torch.utils.data import DataLoader, Dataset, ConcatDataset
|
|
import torchvision
|
|
from PIL import Image, ImageFilter
|
|
import os
|
|
from datetime import datetime
|
|
from collections import defaultdict
|
|
import argparse
|
|
from kornia.losses.focal import FocalLoss
|
|
from kornia.augmentation import ColorJiggle
|
|
import cv2
|
|
import numpy as np
|
|
import random
|
|
import re
|
|
import shutil
|
|
import json
|
|
import tempfile
|
|
import time
|
|
import base64
|
|
import subprocess
|
|
from loguru import logger
|
|
from collections import defaultdict
|
|
from multiprocessing import Pool, cpu_count, set_start_method
|
|
from tqdm import tqdm
|
|
import importlib
|
|
|
|
concurrency = max(1, cpu_count() - 2)
|
|
|
|
def info(msg):
    """Emit *msg* at INFO level through the shared loguru logger."""
    logger.info(msg)
|
|
|
|
def debug(msg):
    """Emit *msg* at DEBUG level through the shared loguru logger."""
    logger.debug(msg)
|
|
|
|
# Runtime device selection: prefer the GPU whenever CUDA is available.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

# Default checkpoint paths, relative to the working directory.
# Filenames encode epoch and per-class validation accuracy at save time.
default_model = 'best_model_ep82_pos98.94_neg96.13_20250720_222102.pt'
clarity_model = 'models/clarity-ep15-pos88.14-neg92.23-20250518_164155.pt'

# 'high' allows reduced-precision (TF32) float32 matmuls for speed on
# supporting GPUs, trading a small amount of precision.
torch.set_float32_matmul_precision('high')
|
|
|
|
def batch_generator(labels, batch_size):
    """Yield consecutive slices of *labels*, each at most *batch_size* long.

    The final slice may be shorter; an empty input yields nothing.
    """
    start = 0
    total = len(labels)
    while start < total:
        yield labels[start:start + batch_size]
        start += batch_size
|
|
|
|
def make_side_by_side_img(left, right):
    """Resize both PIL images to their common minimum size and paste them
    next to each other (left image first) on a new RGB canvas."""
    common_w = min(left.width, right.width)
    common_h = min(left.height, right.height)
    canvas = Image.new('RGB', (common_w * 2, common_h))
    canvas.paste(left.resize((common_w, common_h)), (0, 0))
    canvas.paste(right.resize((common_w, common_h)), (common_w, 0))
    return canvas
|
|
|
|
def warp_with_margin_ratio(orig, edge, corners, margin_ratio):
    """Perspective-warp the quadrilateral *corners* of *orig* into an
    edge x edge square, leaving a border of edge*margin_ratio on each side.

    *corners* are expected in clockwise order starting at the top-left,
    matching the destination point order below.
    """
    near = edge * margin_ratio
    far = edge * (1 - margin_ratio)
    src = np.float32(corners)
    dst = np.float32([[near, near], [far, near], [far, far], [near, far]])
    transform = cv2.getPerspectiveTransform(src, dst)
    warped = cv2.warpPerspective(np.array(orig), transform, (edge, edge))
    return Image.fromarray(warped)
|
|
|
|
def find_min_margin_ratio(img, corners):
    """Return the smallest border distance of the first four corner points,
    expressed as a fraction of the image width/height.

    Distances considered per point: left, right (relative to width) and
    top, bottom (relative to height).
    """
    best = None
    for idx in range(4):
        x, y = corners[idx][0], corners[idx][1]
        candidate = min(
            x / img.width,
            (img.width - x) / img.width,
            y / img.height,
            (img.height - y) / img.height,
        )
        if best is None or candidate < best:
            best = candidate
    return best
|
|
|
|
def make_side_by_side_img_with_margins(frame_img, std_img):
    """Warp the QR region of each image into a same-size square with a shared
    margin, then paste them side by side (frame left, std right).

    Returns None when a QR code cannot be located in either image.
    """
    _, std_corners = find_qr(std_img)
    _, frame_corners = find_qr(frame_img)
    if std_corners is None or frame_corners is None:
        return None

    # Use a reasonable size - use the larger dimension or at least 512
    edge = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512)
    # Use the minimum margin ratio to ensure both QR areas are captured
    # correctly; cap it at 5% so the QR still dominates the crop.
    margin = min(
        find_min_margin_ratio(std_img, std_corners),
        find_min_margin_ratio(frame_img, frame_corners),
        0.05,
    )
    std_square = warp_with_margin_ratio(std_img, edge, std_corners, margin)
    frame_square = warp_with_margin_ratio(frame_img, edge, frame_corners, margin)

    # Horizontal layout; each warped image is edge x edge.
    combined = Image.new('RGB', (edge * 2, edge))
    combined.paste(frame_square, (0, 0))       # frame on left
    combined.paste(std_square, (edge, 0))      # std on right
    return combined
|
|
|
|
def get_top_margin(img, margin_ratio):
    """Return the top strip of *img* with height img.height * margin_ratio.

    Pasting the full image into the shorter canvas clips everything below
    the strip.
    """
    strip = Image.new('RGB', (img.width, int(img.height * margin_ratio)))
    strip.paste(img, (0, 0))
    return strip
|
|
|
|
def make_top_margin_stacked_img(frame_img, std_img):
    """Stack the top-margin strips of the warped std and frame QR squares,
    std strip above frame strip.

    Returns None when QR detection fails for either image. The margin ratio
    is measured on the std image only and applied to both warps.
    """
    _, std_corners = find_qr(std_img)
    _, frame_corners = find_qr(frame_img)
    if std_corners is None or frame_corners is None:
        return None

    edge = min(std_img.width, std_img.height)
    margin = find_min_margin_ratio(std_img, std_corners)
    std_square = warp_with_margin_ratio(std_img, edge, std_corners, margin)
    frame_square = warp_with_margin_ratio(frame_img, edge, frame_corners, margin)
    std_strip = get_top_margin(std_square, margin)
    frame_strip = get_top_margin(frame_square, margin)

    stacked_height = int(edge * margin * 2)
    stacked = Image.new('RGB', (edge, stacked_height))
    stacked.paste(std_strip, (0, 0))
    stacked.paste(frame_strip, (0, stacked_height // 2))
    return stacked
|
|
|
|
def make_stripe_img(left, right, nstripes):
    """Interleave *nstripes* vertical stripes from *left* and *right* into a
    single image: each output pair holds the matching stripe of each source,
    resized to the common minimum dimensions."""
    common_w = min(left.width, right.width)
    common_h = min(left.height, right.height)
    out = Image.new('RGB', (common_w * 2, common_h))
    src_left_w = left.width // nstripes
    src_right_w = right.width // nstripes
    out_stripe_w = common_w // nstripes
    for idx in range(nstripes):
        band_l = left.crop((idx * src_left_w, 0, (idx + 1) * src_left_w, left.height))
        band_l = band_l.resize((out_stripe_w, common_h))
        band_r = right.crop((idx * src_right_w, 0, (idx + 1) * src_right_w, right.height))
        band_r = band_r.resize((out_stripe_w, common_h))
        out.paste(band_l, (idx * out_stripe_w * 2, 0))
        out.paste(band_r, (idx * out_stripe_w * 2 + out_stripe_w, 0))
    return out
|
|
|
|
def predict_multi(model, transforms, images, ncells=1):
    """Classify each image by averaging softmax scores over its grid crops.

    Each image is cut into ncells x ncells side-by-side sub-crops, all crops
    for all images are run through *model* in one batch, and each image's
    crop probabilities are summed per class to pick the predicted class.

    Returns a list of (predicted_class, probs) per input image, where probs
    is that image's list of per-crop [neg, pos] softmax rows.

    Bug fix: the per-image slice width was ncells*ncells*2, but the model
    produces exactly ncells*ncells rows per image (one per crop), so for
    batched calls (len(images) > 1) the slices misaligned and later images
    received empty probability lists. Single-image behavior is unchanged.
    """
    rows_per_image = ncells * ncells
    results = []
    with torch.no_grad():
        crops = []
        for image in images:
            for xcoord in range(ncells):
                for ycoord in range(ncells):
                    crop = crop_side_by_side(image, ncells, xcoord, ycoord)
                    crops.append(transforms(crop).to(device))
        logits = model(torch.stack(crops, dim=0))
        all_probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().tolist()
        for i in range(len(images)):
            probs = all_probs[i * rows_per_image:(i + 1) * rows_per_image]
            neg_sum = sum(p[0] for p in probs)
            pos_sum = sum(p[1] for p in probs)
            results.append((1 if pos_sum > neg_sum else 0, probs))
    return results
|
|
|
|
def predict(model, transforms, image, ncells=1):
    """Classify a single image; returns (predicted_class, probs)."""
    return predict_multi(model, transforms, [image], ncells)[0]
|
|
|
|
# Module-level WeChat QR detector (opencv-contrib), reused by every find_qr call.
qr_detector = cv2.wechat_qrcode_WeChatQRCode()
|
|
|
|
def find_qr(img, scale=1.0):
    """Detect and decode a QR code in a PIL image.

    Detection runs on a copy resized by *scale* (downscaling sometimes helps
    the detector). On failure, retries recursively at 2/3 of the current
    scale until the scale drops to 0.05, then gives up.

    Returns (payload, corners) with corner coordinates mapped back to the
    original image resolution, or (None, None) when nothing is found.
    """
    if scale < 1.0:
        w, h = img.size
        working = img.resize((int(w * scale), int(h * scale)))
    else:
        working = img
    cv_img = cv2.cvtColor(np.array(working), cv2.COLOR_RGB2BGR)
    payloads, corner_sets = qr_detector.detectAndDecode(cv_img)
    if not payloads:
        if scale > 0.05:
            return find_qr(img, scale * 2 / 3)
        return None, None
    points = np.array(corner_sets[0], dtype=np.float32)
    points /= scale  # map back into original-resolution coordinates
    return payloads[0], points
|
|
|
|
def extract_qr(img):
    """Locate the QR code in *img* and return (payload, rectified_crop).

    The detected quadrilateral is perspective-warped into a square whose
    side is the shorter dimension of the quad's bounding box.

    Raises Exception when no QR code is found.
    """
    payload, raw_corners = find_qr(img)
    if not payload:
        raise Exception('No QR code found')
    pts = np.array(raw_corners, dtype=np.float32)

    cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    # Bounding box of the detected quad; the output square uses its shorter side.
    xs = pts[:, 0]
    ys = pts[:, 1]
    side = int(min(max(xs) - min(xs), max(ys) - min(ys)))
    # Destination rectangle corners, clockwise from top-left.
    target = np.array([
        [0, 0],
        [side, 0],
        [side, side],
        [0, side],
    ], dtype=np.float32)

    transform = cv2.getPerspectiveTransform(pts, target)
    rectified = cv2.warpPerspective(cv_img, transform, (side, side))
    crop = Image.fromarray(cv2.cvtColor(rectified, cv2.COLOR_BGR2RGB))
    return payload, crop
|
|
|
|
def load_model(model_path):
    """Load a checkpoint produced by save_model().

    Returns (model, transforms) with the model in eval mode on the global
    device. weights_only=False is required because the checkpoint pickles
    the full model object, not just tensors.
    """
    ckpt = torch.load(model_path, map_location=device, weights_only=False)
    model = ckpt['model']                      # full model structure
    model.load_state_dict(ckpt['model_state_dict'])  # then the weights
    model.eval()
    model = model.to(device)
    return model, ckpt['transforms']
|
|
|
|
def save_model(model, transforms, save_path, metadata=None):
    """Serialize model structure, weights, and preprocessing transforms to
    *save_path*; *metadata* is stored under the 'metadata' key when given."""
    payload = {
        'model': model,                        # complete model structure
        'model_state_dict': model.state_dict(),
        'transforms': transforms,
    }
    if metadata:
        payload['metadata'] = metadata
    torch.save(payload, save_path)
|
|
|
|
def make_model(model_name):
    """Build a (model, transforms) pair for the named architecture.

    Raises KeyError for an unknown *model_name*.
    """
    factories = {
        'resnet': make_resnet,
        'resnet18': make_resnet18,
        'resnet101': make_resnet101,
        'regnet': make_regnet,
        'convnext': make_convnext,
        'efficientnet': make_efficientnet,
        'densenet': make_densenet,
        'mobilenet': make_mobilenet,
    }
    factory = factories[model_name]
    return factory()
|
|
|
|
def make_mobilenet():
    """MobileNetV3-small pretrained on ImageNet, with a 2-class MLP head."""
    backbone = torchvision.models.mobilenet_v3_small(
        weights=torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
    )
    backbone.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(576, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return backbone, make_generic_transforms()
|
|
|
|
def make_densenet():
    """DenseNet-121 pretrained on ImageNet, with a 2-class MLP head.

    Bug fix: this function referenced `models.*`, but the module imports
    `torchvision` (never `from torchvision import models`), so calling it
    raised NameError. Now uses `torchvision.models.*` like the sibling
    factories (make_mobilenet, make_efficientnet, ...).
    """
    weights = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
    model = torchvision.models.densenet121(weights=weights)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1024, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2)
    )
    return model, make_generic_transforms()
|
|
|
|
def make_efficientnet():
    """EfficientNet-B3 pretrained on ImageNet, with a 2-class MLP head."""
    backbone = torchvision.models.efficientnet_b3(
        weights=torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1,
    )
    backbone.classifier = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1536, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return backbone, make_generic_transforms()
|
|
|
|
def make_convnext():
    """ConvNeXt-Base pretrained on ImageNet, with a 2-class MLP head."""
    backbone = torchvision.models.convnext_base(
        weights=torchvision.models.ConvNeXt_Base_Weights.IMAGENET1K_V1,
    )
    backbone.classifier = nn.Sequential(
        nn.Flatten(1),  # flatten first: [B, 1024, 1, 1] -> [B, 1024]
        nn.LayerNorm(1024, eps=1e-6),
        nn.Dropout(p=0.4),
        nn.Linear(1024, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return backbone, make_generic_transforms()
|
|
|
|
def make_resnet101():
    """ResNet-101 pretrained on ImageNet, with a 2-class MLP head."""
    backbone = torchvision.models.resnet101(
        weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V1,
    )
    backbone.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(2048, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return backbone, make_generic_transforms()
|
|
|
|
def make_resnet18():
    """ResNet-18 pretrained on ImageNet, with a 2-class MLP head."""
    backbone = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1,
    )
    backbone.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(128, 2),
    )
    return backbone, make_generic_transforms()
|
|
|
|
def make_resnet():
    """ResNet-50 pretrained on ImageNet, with a 2-class MLP head."""
    backbone = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1,
    )
    backbone.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(2048, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return backbone, make_generic_transforms()
|
|
|
|
def make_regnet():
    """RegNetY-3.2GF pretrained on ImageNet, with a 2-class MLP head.

    Bug fix: this function referenced `models.*`, but the module imports
    `torchvision` (never `from torchvision import models`), so calling it
    raised NameError. Now uses `torchvision.models.*` like the sibling
    factories.
    """
    weights = torchvision.models.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1
    model = torchvision.models.regnet_y_3_2gf(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1512, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.1),
        nn.Linear(128, 2)
    )
    return model, make_generic_transforms()
|
|
|
|
def make_generic_transforms():
    """Shared preprocessing: resize to 128x256 (HxW), tensorize, and
    normalize with the standard ImageNet channel statistics."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
|
|
|
|
class ScanDataset(Dataset):
    """Dataset of labeled scans.

    Each item is the transformed side-by-side image found at
    data/gridcrop2/<scan_id>/sbs.jpg plus a binary label
    (1 when 'pos' is among the scan's labels, else 0).
    """

    def __init__(self, scans, transforms):
        self.scans = scans
        self.transforms = transforms

    def __len__(self):
        return len(self.scans)

    def __getitem__(self, idx):
        record = self.scans[idx]
        sbs_path = os.path.join('data', 'gridcrop2', record['scan_id'], 'sbs.jpg')
        image = Image.open(sbs_path).convert('RGB')
        label = 1 if 'pos' in record['labels'] else 0
        return self.transforms(image), label

    def stats(self):
        """Counts of scans carrying the 'pos' and 'neg' labels."""
        return {
            'pos': sum(1 for record in self.scans if 'pos' in record['labels']),
            'neg': sum(1 for record in self.scans if 'neg' in record['labels']),
        }
|
|
|
|
def do_train(cfg):
    """Train cfg['model'] with mixed precision, validating every epoch.

    cfg keys used: train_dataset, val_dataset, batch_size, num_workers,
    model, criterion, optimizer (factory: model -> optimizer), scheduler
    (factory: optimizer -> scheduler), max_epochs, learning_rate.

    Returns (model, best_epoch) where best_epoch is a dict with the stats
    and weights of the epoch with the highest total validation accuracy
    (None if max_epochs == 0). CUDA is required: the loop uses
    torch.compile, float16 autocast, and GradScaler('cuda').
    """
    train_dataset = cfg['train_dataset']
    val_dataset = cfg['val_dataset']

    batch_size = cfg['batch_size']
    num_workers = cfg['num_workers']

    prefetch_factor = 4
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        # Validation runs without gradients, so a larger batch fits in memory.
        batch_size=batch_size*2,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
    )

    model = cfg['model']

    # Set memory fraction for better GPU utilization
    torch.cuda.set_per_process_memory_fraction(0.8)

    # Compile model with optimized settings
    model = torch.compile(model).to(device)

    criterion = cfg['criterion']
    # cfg supplies factories (not instances) so the optimizer binds to the
    # compiled model and the scheduler binds to that optimizer.
    optimizer = cfg['optimizer'](model)
    scheduler = cfg['scheduler'](optimizer)
    max_epochs = cfg['max_epochs']

    # Initialize variables for early stopping
    best_epoch = None

    # Create GradScaler once at the start of training
    scaler = torch.amp.GradScaler('cuda')

    for epoch in range(max_epochs):
        info(f"Starting epoch {epoch+1}/{max_epochs}")
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader, desc=f"Train {epoch+1}/{max_epochs}", unit_scale=batch_size):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(images)
                loss = criterion(outputs, labels).mean()
            # Scaled backward/step/update is the standard AMP sequence.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.detach().item() * len(images)

        # NOTE(review): this averages per batch (len(train_loader)) after
        # accumulating per sample, and then divides by the learning rate —
        # an unusual normalization; confirm it is intentional before reuse.
        running_loss = running_loss / len(train_loader) / cfg['learning_rate']
        info(f'Epoch [{epoch+1}/{max_epochs}], Loss: {running_loss:.5f}')

        # Validation phase
        model.eval()
        pos_correct = 0
        pos_total = 0
        neg_correct = 0
        neg_total = 0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Validating {epoch+1}/{max_epochs}", unit_scale=batch_size):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                # Track per-class hits so both accuracies can be reported.
                pos_correct += ((predicted == labels) & (labels == 1)).sum().item()
                pos_total += (labels == 1).sum().item()
                neg_correct += ((predicted == labels) & (labels == 0)).sum().item()
                neg_total += (labels == 0).sum().item()

        pos_accu = pos_correct / pos_total if pos_total else 0
        neg_accu = neg_correct / neg_total if neg_total else 0
        total_accu = (pos_correct + neg_correct) / (pos_total + neg_total)
        info(f'Pos Accu: {pos_accu:.2%} ({pos_correct}/{pos_total})')
        info(f'Neg Accu: {neg_accu:.2%} ({neg_correct}/{neg_total})')
        info(f'Total Accu: {total_accu:.2%} ({pos_correct + neg_correct}/{pos_total + neg_total})')

        if not best_epoch or total_accu > best_epoch['total_accu']:
            best_epoch = {
                'epoch': epoch,
                'pos_accu': pos_accu,
                'neg_accu': neg_accu,
                'total_accu': total_accu,
                # NOTE(review): .copy() is shallow — the stored tensors still
                # alias the live weights, which later epochs mutate in place;
                # verify whether a deep copy / clone was intended here.
                'model_state_dict': model.state_dict().copy(),
            }
            info(f'New best model found with total accuracy: {best_epoch["total_accu"]:.2%}')

        # ReduceLROnPlateau-style scheduler: stepped with the metric value.
        scheduler.step(total_accu)

    # Load the best model weights
    if best_epoch is not None:
        model.load_state_dict(best_epoch['model_state_dict'])
        info(f'Loaded best model with total accuracy: {best_epoch["total_accu"]:.2%}')

    return model, best_epoch
|
|
|
|
def verify_frame(model, transforms, frame_img, orig_img):
    """Build the margin-aligned side-by-side image and classify it.

    Raises when the side-by-side image cannot be constructed (QR not found).
    NOTE(review): the save/reload through a temporary .jpg applies JPEG
    compression before prediction — presumably to match production inputs;
    confirm before changing.
    """
    combined = make_side_by_side_img_with_margins(frame_img, orig_img)
    if combined is None:
        raise Exception("Failed to create side-by-side image with margins")
    combined = combined.convert('RGB')
    with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp:
        combined.save(tmp.name)
        reloaded = Image.open(tmp.name).convert('RGB')
        return predict(model, transforms, reloaded)
|
|
|
|
def parse_ranges(s):
    """Parse a comma-separated range spec into [start, end] pairs.

    '1-3,5' -> [[1, 3], [5, 5]]; a bare number becomes a single-value range.
    """
    ranges = []
    for token in s.split(','):
        if '-' in token:
            lo, hi = token.split('-')
            ranges.append([int(lo), int(hi)])
        else:
            value = int(token)
            ranges.append([value, value])
    return ranges
|
|
|
|
def in_range(x, val_range):
    """True when int(x) lies inside the inclusive [start, end] pair.

    A falsy *val_range* (None or empty) always yields False.
    """
    if val_range:
        start, end = val_range
        return start <= int(x) <= end
    return False
|
|
|
|
def train_model(model, train_dataset, val_dataset, max_epochs, pos_weight=0.99):
    """Assemble the training config (focal loss with class weighting, AdamW,
    plateau LR scheduler) and run do_train(); returns (model, best_epoch)."""
    info(f"Train count: {len(train_dataset)}, val count: {len(val_dataset)}")
    lr = 0.001
    class_weights = torch.Tensor([1.0 - pos_weight, pos_weight]).to(device)

    def build_optimizer(m):
        return optim.AdamW(
            m.parameters(),
            lr=lr,
            weight_decay=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

    def build_scheduler(opt):
        # Stepped with validation accuracy ('max'); shrink LR after 3 flat epochs.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode='max',
            factor=0.1,
            patience=3,
            min_lr=1e-6,
        )

    cfg = {
        'model': model,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'criterion': FocalLoss(0.25, weight=class_weights),
        'learning_rate': lr,
        'optimizer': build_optimizer,
        'batch_size': 32,
        'max_epochs': max_epochs,
        'scheduler': build_scheduler,
        'num_workers': concurrency,
    }
    return do_train(cfg)
|
|
|
|
def motion_blur_indicator(image):
    """Directional-blur score in (0, 1]: the ratio of the smaller to the
    larger Sobel-gradient variance (x vs y), evaluated on the image and its
    45-degree rotation and multiplied together. Lower values indicate more
    anisotropic gradients (motion blur along one direction)."""
    def directional_ratio(pil_img):
        equalized = cv2.equalizeHist(np.array(pil_img))
        grad_x = cv2.Sobel(equalized, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(equalized, cv2.CV_64F, 0, 1, ksize=3)
        var_x = grad_x.var()
        var_y = grad_y.var()
        return min(var_x, var_y) / max(var_x, var_y)

    return directional_ratio(image) * directional_ratio(image.rotate(45))
|
|
|
|
def calc_clarity(img):
    """Sharpness score: Laplacian variance of the equalized, lightly blurred
    grayscale image, scaled by the motion-blur indicator of the grayscale."""
    gray = img.convert('L')
    softened = gray.filter(ImageFilter.GaussianBlur(radius=1.5))
    equalized = cv2.equalizeHist(np.array(softened))
    laplacian = cv2.Laplacian(equalized, cv2.CV_64F)
    return laplacian.var() * motion_blur_indicator(gray)
|
|
|
|
def load_codes(codes_file):
    """Read one code per line from *codes_file*, stripping whitespace and
    skipping blank lines."""
    with open(codes_file, 'r') as fh:
        stripped = (line.strip() for line in fh)
        return [code for code in stripped if code]
|
|
|
|
def load_cell_labels(dataset_dir):
    """Gather (jpg_path, label) pairs for cell_*.jpg files under
    <dataset_dir>/pos and <dataset_dir>/neg.

    Only cells whose directory contains a metadata.json with a 'code' key
    are kept. File order is shuffled before filtering. Returns
    (pos_labels, neg_labels) with labels 1 and 0 respectively; raises when
    the dataset directory is missing or either class ends up empty.
    """
    if not os.path.exists(dataset_dir):
        raise Exception(f"Dataset directory {dataset_dir} does not exist")
    candidates = []
    for label in ["pos", "neg"]:
        for root, _dirs, files in os.walk(os.path.join(dataset_dir, label)):
            for name in files:
                if name.startswith('cell_') and name.endswith('.jpg'):
                    candidates.append((label, os.path.join(root, name)))
    random.shuffle(candidates)
    pos_labels = []
    neg_labels = []
    for label, jpg_file in candidates:
        md_path = os.path.join(os.path.dirname(jpg_file), "metadata.json")
        if not os.path.exists(md_path):
            continue
        with open(md_path, 'r') as fh:
            metadata = json.load(fh)
        if 'code' not in metadata:
            continue
        if label == "pos":
            pos_labels.append((jpg_file, 1))
        else:
            neg_labels.append((jpg_file, 0))
    total_pos = len(pos_labels)
    total_neg = len(neg_labels)
    info(f"Total positive: {total_pos}, total negative: {total_neg}")
    if not total_pos:
        raise Exception("No positive labels found")
    if not total_neg:
        raise Exception("No negative labels found")
    return pos_labels, neg_labels
|
|
|
|
def crop_side_by_side(img, cells, xcoord, ycoord, jitter=False):
    """Extract the (xcoord, ycoord) cell from each half of a side-by-side
    image and rejoin the two cells into one image.

    With jitter=True, a small random translation is applied to each half and
    a color jitter to the joined result (training-time augmentation).
    """
    cell_w = img.width // (cells * 2)
    cell_h = img.height // cells
    left_cell = img.crop((xcoord * cell_w, ycoord * cell_h,
                          (xcoord + 1) * cell_w, (ycoord + 1) * cell_h))
    right_cell = img.crop(((xcoord + cells) * cell_w, ycoord * cell_h,
                           (xcoord + cells + 1) * cell_w, (ycoord + 1) * cell_h))

    if jitter:
        shift = 0.02
        left_cell = torchvision.transforms.RandomAffine(degrees=0, translate=(shift, shift))(left_cell)
        right_cell = torchvision.transforms.RandomAffine(degrees=0, translate=(shift, shift))(right_cell)
    joined = Image.new('RGB', (cell_w * 2, cell_h))
    joined.paste(left_cell, (0, 0))
    joined.paste(right_cell, (cell_w, 0))
    if jitter:
        joined = torchvision.transforms.ColorJitter(brightness=0.2, saturation=0.2, hue=0.5)(joined)
    return joined
|
|
|
|
def average(values):
    """Arithmetic mean of *values*; raises ZeroDivisionError when empty."""
    total = sum(values)
    return total / len(values)
|
|
|
|
def random_sample(values, count):
    """Sample *count* items without replacement, or return *values* itself
    when it already holds at most *count* items."""
    return values if len(values) <= count else random.sample(values, count)
|
|
|
|
class ClarityPredictor(object):
    """Lazily-loaded binary clarity classifier.

    Bug fix: __init__ previously ignored its model_path argument and always
    used the module-level clarity_model path. The default is now a None
    sentinel that falls back to clarity_model, preserving the old default
    behavior while making the parameter effective.
    """

    def __init__(self, model_path=None):
        # None defers to the module-level default path; any explicit path wins.
        self.model_path = model_path if model_path is not None else clarity_model
        # Model is loaded on first call, so constructing the predictor is cheap
        # and works even when the checkpoint file is absent.
        self.model = None

    def __call__(self, img):
        """Classify *img*: True when predicted class 1 ('clear'), False for
        class 0, or None when the model file does not exist."""
        if not self.model:
            if not os.path.exists(self.model_path):
                return None
            self.model, self.transforms = load_model(self.model_path)
        tensor = self.transforms(img).to(device).unsqueeze(0)
        with torch.no_grad():
            output = self.model(tensor)
        return output.argmax(dim=1).detach().cpu().item() == 1
|
|
|
|
# Shared lazy singleton; the model is only loaded on first use.
clarity_predictor = ClarityPredictor()
|
|
|
|
class BaseMethod(object):
    """Base class for training/evaluation methods loaded via load_method().

    Provides scan loading (with metadata validation and a cached
    relative-clarity filter); subclasses override preprocess()/train() and
    the pre/post load hooks.
    """

    def __init__(self, datadir):
        # Root directory containing a 'scans/<scan_id>/' tree.
        self.datadir = datadir

    def preprocess(self, scans):
        """Hook for subclasses; default does nothing."""
        pass

    def load_scans(self, datadir, sample_rate=1.0):
        """Load all scans under <datadir>/scans in parallel.

        Returns (pos_scans, neg_scans), each sorted by numeric scan_id.
        sample_rate < 1.0 randomly subsamples the scan ids before loading.
        """
        self.datadir = datadir
        # NOTE(review): the Pool is never closed/joined; consider
        # `with Pool(concurrency) as pool:` to avoid leaking workers.
        pool = Pool(concurrency)
        counts = defaultdict(int)
        pos_scans = []
        neg_scans = []
        all_scan_ids = os.listdir(os.path.join(datadir, "scans"))
        if sample_rate < 1.0:
            all_scan_ids = random.sample(all_scan_ids, int(len(all_scan_ids) * sample_rate))
        for x in tqdm(pool.imap(self.load_one_scan, all_scan_ids), total=len(all_scan_ids), desc="Loading dataset"):
            # NOTE(review): 'lables' looks like a typo for 'labels'; also the
            # result dicts only carry 'ok'/'error'/'data', so this key is
            # always None here — verify what the count was meant to capture.
            counts[(x.get('ok'), x.get('error'), x.get('lables'))] += 1
            if x.get('ok'):
                labels = x['data'].get('labels', [])
                # 'pos' takes precedence when a scan carries both labels.
                if 'pos' in labels:
                    pos_scans.append(x['data'])
                elif 'neg' in labels:
                    neg_scans.append(x['data'])
        info(f"Counts: {counts}")
        return sorted(pos_scans, key=lambda x: int(x['scan_id'])), sorted(neg_scans, key=lambda x: int(x['scan_id']))

    def pre_load_one_scan(self, scan_id):
        """Hook: return False to skip a scan before any file access."""
        return True

    def load_one_scan(self, scan_id):
        """Validate and load one scan directory.

        Returns {"ok": True, "data": {...}} on success, or
        {"ok": False, "error": reason} when any required file/field is
        missing or the scan fails the relative-clarity filter. As a side
        effect, computes and caches 'relative_clarity' into metadata.json
        the first time a scan is seen.
        """
        if not self.pre_load_one_scan(scan_id):
            return {
                "ok": False,
                "error": f"pre_load_one_scan: skip",
            }
        scan_dir = os.path.join(self.datadir, 'scans', scan_id)
        if not os.path.exists(scan_dir):
            return {
                "ok": False,
                "error": f"Scan dir does not exist",
            }
        std_qr_file = os.path.join(scan_dir, "std-qr.jpg")
        if not os.path.exists(std_qr_file):
            return {
                "ok": False,
                "error": f"Std QR file does not exist",
            }
        mdfile = os.path.join(scan_dir, "metadata.json")
        if not os.path.exists(mdfile):
            return {
                "ok": False,
                "error": f"Metadata file does not exist",
            }
        with open(mdfile, 'r') as f:
            try:
                md = json.load(f)
            except Exception as e:
                return {
                    "ok": False,
                    "error": f"Error loading metadata file: {e}",
                }
        if not md.get('code'):
            return {
                "ok": False,
                "error": f"Code not found in metadata file",
            }
        if not md.get('labels'):
            return {
                "ok": False,
                "error": f"Labels not found in metadata file",
            }
        # Compute relative clarity (frame vs std QR sharpness) once and cache
        # it back into metadata.json so reruns skip the expensive step.
        if not md.get('relative_clarity'):
            std_qr_img = Image.open(std_qr_file)
            std_qr_clarity = calc_clarity(std_qr_img)
            if not std_qr_clarity:
                return {
                    "ok": False,
                    "error": f"Std QR clarity invalid",
                }
            frame_qr_file = os.path.join(scan_dir, "frame-qr.jpg")
            if not os.path.exists(frame_qr_file):
                return {
                    "ok": False,
                    "error": f"Frame QR file does not exist",
                }
            frame_qr_img = Image.open(frame_qr_file)
            frame_qr_clarity = calc_clarity(frame_qr_img)
            relative_clarity = frame_qr_clarity / std_qr_clarity
            md['relative_clarity'] = relative_clarity
            with open(mdfile, 'w') as f:
                json.dump(md, f, indent=2)
        # Scans whose frame is much blurrier than the reference are rejected.
        if md['relative_clarity'] < 0.5:
            return {
                "ok": False,
                "error": f"Relative clarity too low",
            }
        self.post_load_one_scan(scan_id)
        # NOTE(review): post_load_one_scan is not defined in this class (only
        # pre_load_one_scan is) — subclasses must provide it or this raises
        # AttributeError; confirm.
        return {
            "ok": True,
            "data": {
                "scan_id": scan_id,
                "std_qr_file": std_qr_file,
                "code": md['code'],
                "labels": md['labels'],
            },
        }

    def train(self, dataset, epochs):
        """Subclasses must implement training."""
        raise Exception("Not implemented")
|
|
|
|
def load_method(method_name, datadir):
    """Import methods.<method_name> and instantiate its Method class with
    *datadir*."""
    module = importlib.import_module('methods.' + method_name)
    method_cls = module.Method
    return method_cls(datadir)
|
|
|
|
def balance_pos_and_neg(scans):
    """Return a shuffled subset of *scans* with equal counts of 'pos' and
    'neg' labeled entries (the larger class is randomly truncated)."""
    positives = [s for s in scans if 'pos' in s['labels']]
    negatives = [s for s in scans if 'neg' in s['labels']]
    random.shuffle(positives)
    random.shuffle(negatives)
    keep = min(len(positives), len(negatives))
    kept_pos = positives[:keep]
    kept_neg = negatives[:keep]
    info(f'balanced from pos: {len(positives)} to {len(kept_pos)}')
    info(f'balanced from neg: {len(negatives)} to {len(kept_neg)}')
    balanced = kept_pos + kept_neg
    random.shuffle(balanced)
    return balanced
|