themblem/emblem5/ai/common.py

806 lines
28 KiB
Python

#!/usr/bin/env python3
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision
from PIL import Image, ImageFilter
import os
from datetime import datetime
from collections import defaultdict
import argparse
from kornia.losses.focal import FocalLoss
from kornia.augmentation import ColorJiggle
import cv2
import numpy as np
import random
import re
import shutil
import json
import tempfile
import time
import base64
import subprocess
from loguru import logger
from collections import defaultdict
from multiprocessing import Pool, cpu_count, set_start_method
from tqdm import tqdm
import importlib
# Worker-pool size: leave two cores free for the OS / main process.
concurrency = max(1, cpu_count() - 2)
def info(msg):
    """Emit *msg* at INFO level through the module-wide loguru logger."""
    logger.info(msg)
def debug(msg):
    """Emit *msg* at DEBUG level through the module-wide loguru logger."""
    logger.debug(msg)
cuda_available = torch.cuda.is_available()
# All models and tensors are placed on the GPU when one is present.
device = torch.device('cuda' if cuda_available else 'cpu')
# Checkpoint paths; see load_model()/save_model() for the on-disk format.
default_model = 'best_model_ep82_pos98.94_neg96.13_20250720_222102.pt'
clarity_model = 'models/clarity-ep15-pos88.14-neg92.23-20250518_164155.pt'
# Allow TF32 matmuls for extra throughput on Ampere+ GPUs.
torch.set_float32_matmul_precision('high')
def batch_generator(labels, batch_size):
    """Yield successive slices of *labels*, each at most *batch_size* long."""
    start = 0
    while start < len(labels):
        yield labels[start:start + batch_size]
        start += batch_size
def make_side_by_side_img(left, right):
    """Resize both images to their shared minimum size and paste them
    next to each other (left first) on a double-width RGB canvas."""
    w = min(left.width, right.width)
    h = min(left.height, right.height)
    canvas = Image.new('RGB', (w * 2, h))
    canvas.paste(left.resize((w, h)), (0, 0))
    canvas.paste(right.resize((w, h)), (w, 0))
    return canvas
def warp_with_margin_ratio(orig, edge, corners, margin_ratio):
    """Perspective-warp *orig* so *corners* land on an edge x edge square
    inset by *margin_ratio* on every side; returns a PIL image."""
    near = edge * margin_ratio
    far = edge * (1 - margin_ratio)
    src = np.float32(corners)
    dst = np.float32([
        [near, near],
        [far, near],
        [far, far],
        [near, far],
    ])
    transform = cv2.getPerspectiveTransform(src, dst)
    warped = cv2.warpPerspective(np.array(orig), transform, (edge, edge))
    return Image.fromarray(warped)
def find_min_margin_ratio(img, corners):
    """Smallest distance from any of the four corner points to an image
    border, expressed as a fraction of the image's width/height."""
    def edge_ratio(point):
        x, y = point[0], point[1]
        return min(
            x / img.width,
            (img.width - x) / img.width,
            y / img.height,
            (img.height - y) / img.height,
        )
    return min(edge_ratio(corners[i]) for i in range(4))
def make_side_by_side_img_with_margins(frame_img, std_img):
    """Warp both images around their QR codes and paste them side by side.

    Returns a (2*edge) x edge RGB image — frame on the left, std on the
    right — or None when a QR code cannot be located in either image.
    """
    std_qr, std_corners = find_qr(std_img)
    frame_qr, frame_corners = find_qr(frame_img)
    if std_corners is None or frame_corners is None:
        return None
    # Use a reasonable size - use the larger dimension or at least 512
    edge_length = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512)
    std_margin_ratio = find_min_margin_ratio(std_img, std_corners)
    frame_margin_ratio = find_min_margin_ratio(frame_img, frame_corners)
    # Use the minimum margin ratio to ensure both QR areas are captured correctly
    # (capped at 0.05 so the margin never dominates the output).
    margin_ratio = min(std_margin_ratio, frame_margin_ratio, 0.05)
    std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
    frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)
    # Create horizontal layout: frame on left, std on right
    # Each warped image is edge_length x edge_length
    ret = Image.new('RGB', (edge_length * 2, edge_length))
    ret.paste(frame_warped, (0, 0))  # frame on left
    ret.paste(std_warped, (edge_length, 0))  # std on right
    return ret
def get_top_margin(img, margin_ratio):
    """Return the top *margin_ratio* fraction of *img* as a new image."""
    strip_height = int(img.height * margin_ratio)
    strip = Image.new('RGB', (img.width, strip_height))
    strip.paste(img, (0, 0))
    return strip
def make_top_margin_stacked_img(frame_img, std_img):
    """Stack the top-margin strips of both warped images vertically
    (std on top, frame below); None when no QR is found in either image.
    """
    std_qr, std_corners = find_qr(std_img)
    frame_qr, frame_corners = find_qr(frame_img)
    if std_corners is None or frame_corners is None:
        return None
    edge_length = min(std_img.width, std_img.height)
    # NOTE(review): the margin ratio is derived from the std image only but
    # is applied to BOTH warps — confirm this asymmetry is intended.
    margin_ratio = find_min_margin_ratio(std_img, std_corners)
    std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
    frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)
    std_top_margin = get_top_margin(std_warped, margin_ratio)
    frame_top_margin = get_top_margin(frame_warped, margin_ratio)
    # Output is two margin strips stacked: total height = 2 * margin height.
    outheight = int(edge_length * margin_ratio * 2)
    ret = Image.new('RGB', (edge_length, outheight))
    ret.paste(std_top_margin, (0, 0))
    ret.paste(frame_top_margin, (0, outheight // 2))
    return ret
def make_stripe_img(left, right, nstripes):
    """Interleave *nstripes* vertical stripes from each image into one
    double-width image (left stripe, then right stripe, repeating)."""
    w = min(left.width, right.width)
    h = min(left.height, right.height)
    out = Image.new('RGB', (w * 2, h))
    src_left_w = left.width // nstripes
    src_right_w = right.width // nstripes
    dst_w = w // nstripes
    for idx in range(nstripes):
        a = left.crop((idx * src_left_w, 0, (idx + 1) * src_left_w, left.height))
        b = right.crop((idx * src_right_w, 0, (idx + 1) * src_right_w, right.height))
        out.paste(a.resize((dst_w, h)), (idx * dst_w * 2, 0))
        out.paste(b.resize((dst_w, h)), (idx * dst_w * 2 + dst_w, 0))
    return out
def predict_multi(model, transforms, images, ncells=1):
    """Classify each side-by-side image by scoring every cell of an
    ncells x ncells grid and majority-voting on the summed probabilities.

    Returns a list with one (predicted_class, probs) per input image, where
    probs is the list of [neg, pos] softmax rows for that image's cells.
    """
    # Fix: each image contributes exactly ncells*ncells softmax rows (one
    # per grid cell). The previous stride of ncells*ncells*2 sliced the
    # wrong rows for every image after the first in a batch (single-image
    # calls worked only because the over-long slice was clamped).
    results_per_img = ncells * ncells
    ret = []
    with torch.no_grad():
        tensors = []
        for image in images:
            for xcoord in range(ncells):
                for ycoord in range(ncells):
                    img = crop_side_by_side(image, ncells, xcoord, ycoord)
                    img_tensor = transforms(img).to(device)
                    tensors.append(img_tensor)
        # One forward pass over all cells of all images.
        output = model(torch.stack(tensors, dim=0))
        sub_probs = torch.nn.functional.softmax(output, dim=1).cpu().numpy().tolist()
        for i in range(len(images)):
            probs = sub_probs[i * results_per_img:(i + 1) * results_per_img]
            neg_sum = sum(x[0] for x in probs)
            pos_sum = sum(x[1] for x in probs)
            predicted_class = 1 if pos_sum > neg_sum else 0
            ret.append((predicted_class, probs))
    return ret
def predict(model, transforms, image, ncells=1):
    """Single-image convenience wrapper around predict_multi()."""
    return predict_multi(model, transforms, [image], ncells)[0]
# Process-wide singleton: constructing the WeChat QR detector is expensive,
# so it is created lazily on first use and reused afterwards.
_qr_detector = None
def _get_qr_detector():
    """Return the shared WeChat QR detector, creating it on first call."""
    global _qr_detector
    if _qr_detector is None:
        _qr_detector = cv2.wechat_qrcode_WeChatQRCode()
    return _qr_detector
def find_qr(img, scale=1.0):
    """Locate and decode a QR code in a PIL image.

    Tries detection at *scale*; on failure recurses at 2/3 of the current
    scale down to 0.05 (downscaling helps the detector on large photos).
    Returns (decoded_text, corners) with corners mapped back to
    original-image coordinates, or (None, None) when nothing is found.
    """
    # Convert PIL Image to OpenCV format
    orig_size = img.size
    if scale < 1.0:
        img_w, img_h = img.size
        new_w = int(img_w * scale)
        new_h = int(img_h * scale)
        new_size = (new_w, new_h)
        resized_img = img.resize(new_size)
    else:
        new_size = orig_size
        resized_img = img
    img_cv = cv2.cvtColor(np.array(resized_img), cv2.COLOR_RGB2BGR)
    qr_detector = _get_qr_detector()
    # detectAndDecode returns (decoded_texts, corner_arrays); an empty
    # text sequence means no QR was found at this scale.
    qr, corners = qr_detector.detectAndDecode(img_cv)
    if not qr:
        if scale > 0.05:
            return find_qr(img, scale * 2 / 3)
        else:
            return None, None
    # Keep only the first detection; divide by scale to undo the resize.
    corners = np.array(corners[0], dtype=np.float32)
    corners /= scale
    return qr[0], corners
def extract_qr(img):
    """Decode the QR code in *img* and return (text, crop) where *crop* is
    the perspective-rectified, square QR region as a PIL image.

    Raises:
        Exception: when no QR code is found.
    """
    qr, corners = find_qr(img)
    if not qr:
        raise Exception('No QR code found')
    corners = np.array(corners, dtype=np.float32)
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    # Define target rectangle corners (clockwise from top-left)
    min_x = min(corners[:, 0])
    max_x = max(corners[:, 0])
    min_y = min(corners[:, 1])
    max_y = max(corners[:, 1])
    width = max_x - min_x
    height = max_y - min_y
    # Force a square output using the smaller of the two extents.
    width = height = int(min(width, height))
    dst_corners = np.array([
        [0, 0],
        [width, 0],
        [width, height],
        [0, height]
    ], dtype=np.float32)
    matrix = cv2.getPerspectiveTransform(corners, dst_corners)
    warped = cv2.warpPerspective(img_cv, matrix, (width, height))
    r = Image.fromarray(cv2.cvtColor(warped, cv2.COLOR_BGR2RGB))
    return qr, r
def load_model(model_path):
    """Load a checkpoint written by save_model(); returns (model, transforms).

    The model is restored onto the module-level device in eval mode.
    NOTE(review): weights_only=False unpickles arbitrary Python objects —
    only load checkpoints from trusted sources.
    """
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model = checkpoint['model']  # Load the complete model structure
    model.load_state_dict(checkpoint['model_state_dict'])  # Load the weights
    model.eval()
    model = model.to(device)
    transforms = checkpoint['transforms']
    return model, transforms
def save_model(model, transforms, save_path, metadata=None):
    """Serialize *model* (full structure plus weights) together with its
    preprocessing *transforms* to *save_path*; the counterpart of
    load_model().  Truthy *metadata* is stored under the 'metadata' key."""
    payload = {
        'model': model,  # full module object, so load_model needs no class
        'model_state_dict': model.state_dict(),
        'transforms': transforms,
    }
    if metadata:
        payload['metadata'] = metadata
    torch.save(payload, save_path)
def make_model(model_name):
    """Build the (model, transforms) pair registered under *model_name*.

    Raises KeyError for unknown names.
    """
    factories = {
        'resnet': make_resnet,
        'resnet18': make_resnet18,
        'resnet101': make_resnet101,
        'regnet': make_regnet,
        'convnext': make_convnext,
        'efficientnet': make_efficientnet,
        'densenet': make_densenet,
        'mobilenet': make_mobilenet,
    }
    factory = factories[model_name]
    return factory()
def make_mobilenet():
    """MobileNetV3-Small (ImageNet weights) with a fresh 2-class head."""
    net = torchvision.models.mobilenet_v3_small(
        weights=torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    )
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(576, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return net, make_generic_transforms()
def make_densenet():
    """DenseNet-121 (ImageNet weights) with a fresh 2-class MLP head.

    Fix: the original referenced the undefined name `models` (only
    `torchvision` is imported at module level), so calling it raised
    NameError; it now uses `torchvision.models` like the sibling factories.
    """
    weights = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
    model = torchvision.models.densenet121(weights=weights)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1024, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2)
    )
    return model, make_generic_transforms()
def make_efficientnet():
    """EfficientNet-B3 (ImageNet weights) with a fresh 2-class MLP head."""
    backbone = torchvision.models.efficientnet_b3(
        weights=torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1
    )
    head = [
        nn.Dropout(p=0.4),
        nn.Linear(1536, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    ]
    backbone.classifier = nn.Sequential(*head)
    return backbone, make_generic_transforms()
def make_convnext():
    """ConvNeXt-Base (ImageNet weights) with a fresh 2-class MLP head."""
    model = torchvision.models.convnext_base(
        weights=torchvision.models.ConvNeXt_Base_Weights.IMAGENET1K_V1
    )
    model.classifier = nn.Sequential(
        nn.Flatten(1),  # flatten first: [B, 1024, 1, 1] -> [B, 1024]
        nn.LayerNorm(1024, eps=1e-6),
        nn.Dropout(p=0.4),
        nn.Linear(1024, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()
def make_resnet101():
    """ResNet-101 (ImageNet weights) with a fresh 2-class MLP head."""
    net = torchvision.models.resnet101(
        weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V1
    )
    layers = [
        nn.Dropout(p=0.4),
        nn.Linear(2048, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    ]
    net.fc = nn.Sequential(*layers)
    return net, make_generic_transforms()
def make_resnet18():
    """ResNet-18 (ImageNet weights) with a compact 2-class head."""
    net = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    )
    net.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(128, 2),
    )
    return net, make_generic_transforms()
def make_resnet():
    """ResNet-50 (ImageNet weights) with a fresh 2-class MLP head."""
    net = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
    )
    head = [
        nn.Dropout(p=0.4),
        nn.Linear(2048, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    ]
    net.fc = nn.Sequential(*head)
    return net, make_generic_transforms()
def make_regnet():
    """RegNetY-3.2GF (ImageNet weights) with a fresh 2-class MLP head.

    Fix: the original referenced the undefined name `models` (only
    `torchvision` is imported at module level), so calling it raised
    NameError; it now uses `torchvision.models` like the sibling factories.
    """
    weights = torchvision.models.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1
    model = torchvision.models.regnet_y_3_2gf(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1512, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.1),
        nn.Linear(128, 2)
    )
    return model, make_generic_transforms()
def make_generic_transforms():
    """Shared preprocessing: resize to 128x256, tensorize, ImageNet-normalize."""
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
class ScanDataset(Dataset):
    """Dataset of side-by-side scan crops labeled pos (1) / neg (0).

    Each item is read from data/gridcrop2/<scan_id>/sbs.jpg and passed
    through the supplied transforms.
    """
    def __init__(self, scans, transforms):
        self.transforms = transforms
        self.scans = scans

    def __len__(self):
        return len(self.scans)

    def __getitem__(self, idx):
        entry = self.scans[idx]
        image_path = os.path.join('data', 'gridcrop2', entry['scan_id'], 'sbs.jpg')
        image = Image.open(image_path).convert('RGB')
        label = 1 if 'pos' in entry['labels'] else 0
        return self.transforms(image), label

    def stats(self):
        """Counts of positive and negative scans in this dataset."""
        tally = {'pos': 0, 'neg': 0}
        for entry in self.scans:
            if 'pos' in entry['labels']:
                tally['pos'] += 1
            if 'neg' in entry['labels']:
                tally['neg'] += 1
        return tally
def do_train(cfg):
    """Run the training loop described by *cfg*; returns (model, best_epoch).

    Expected cfg keys: 'train_dataset', 'val_dataset', 'batch_size',
    'num_workers', 'model', 'criterion', 'optimizer' (factory taking the
    model), 'scheduler' (factory taking the optimizer), 'max_epochs',
    'learning_rate'.

    Trains with CUDA AMP (float16) and gradient scaling, validates after
    each epoch, tracks the epoch with the best total accuracy, and reloads
    those best weights into the model before returning.
    """
    train_dataset = cfg['train_dataset']
    val_dataset = cfg['val_dataset']
    batch_size = cfg['batch_size']
    num_workers = cfg['num_workers']
    prefetch_factor = 4
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
    )
    model = cfg['model']
    # Cap GPU memory so other processes on the card are not starved.
    # Fix: guarded — calling this without CUDA raises at runtime.
    if cuda_available:
        torch.cuda.set_per_process_memory_fraction(0.8)
    # Compile model with optimized settings
    model = torch.compile(model).to(device)
    criterion = cfg['criterion']
    optimizer = cfg['optimizer'](model)
    scheduler = cfg['scheduler'](optimizer)
    max_epochs = cfg['max_epochs']
    # Snapshot of the best-validation epoch seen so far (for early restore).
    best_epoch = None
    # Create GradScaler once at the start of training (state persists).
    scaler = torch.amp.GradScaler('cuda')
    for epoch in range(max_epochs):
        info(f"Starting epoch {epoch+1}/{max_epochs}")
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Train {epoch+1}/{max_epochs}", unit_scale=batch_size):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(images)
                loss = criterion(outputs, labels).mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.detach().item() * len(images)
        # NOTE(review): normalized by batch count and learning rate rather
        # than sample count — a relative progress metric, not a true mean.
        running_loss = running_loss / len(train_loader) / cfg['learning_rate']
        info(f'Epoch [{epoch+1}/{max_epochs}], Loss: {running_loss:.5f}')
        # Validation phase
        model.eval()
        pos_correct = 0
        pos_total = 0
        neg_correct = 0
        neg_total = 0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Validating {epoch+1}/{max_epochs}", unit_scale=batch_size):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                pos_correct += ((predicted == labels) & (labels == 1)).sum().item()
                pos_total += (labels == 1).sum().item()
                neg_correct += ((predicted == labels) & (labels == 0)).sum().item()
                neg_total += (labels == 0).sum().item()
        pos_accu = pos_correct / pos_total if pos_total else 0
        neg_accu = neg_correct / neg_total if neg_total else 0
        total_accu = (pos_correct + neg_correct) / (pos_total + neg_total)
        info(f'Pos Accu: {pos_accu:.2%} ({pos_correct}/{pos_total})')
        info(f'Neg Accu: {neg_accu:.2%} ({neg_correct}/{neg_total})')
        info(f'Total Accu: {total_accu:.2%} ({pos_correct + neg_correct}/{pos_total + neg_total})')
        if not best_epoch or total_accu > best_epoch['total_accu']:
            best_epoch = {
                'epoch': epoch,
                'pos_accu': pos_accu,
                'neg_accu': neg_accu,
                'total_accu': total_accu,
                # Fix: state_dict().copy() is shallow — the tensors keep
                # training in place, silently corrupting the "best"
                # snapshot. Clone each tensor so the snapshot is frozen.
                'model_state_dict': {
                    k: v.detach().clone() for k, v in model.state_dict().items()
                },
            }
            info(f'New best model found with total accuracy: {best_epoch["total_accu"]:.2%}')
        scheduler.step(total_accu)
    # Load the best model weights before returning.
    if best_epoch is not None:
        model.load_state_dict(best_epoch['model_state_dict'])
        info(f'Loaded best model with total accuracy: {best_epoch["total_accu"]:.2%}')
    return model, best_epoch
def verify_frame(model, transforms, frame_img, orig_img):
    """Classify a (frame, original) image pair with the 3x3 grid model.

    Builds the margin-normalized side-by-side composite and returns
    (predicted_class, probs) from predict().

    Raises:
        Exception: when the composite cannot be built (QR not found).
    """
    side_by_side_img = make_side_by_side_img_with_margins(frame_img, orig_img)
    if side_by_side_img is None:
        raise Exception("Failed to create side-by-side image with margins")
    side_by_side_img = side_by_side_img.convert('RGB')
    # NOTE(review): the save/reopen round-trip applies JPEG compression —
    # presumably to match training-data artifacts; confirm before removing.
    with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
        side_by_side_img.save(f.name)
        return predict(model, transforms, Image.open(f.name).convert('RGB'), ncells=3)
def parse_ranges(s):
    """Parse a spec like '1-3,5' into inclusive pairs: [[1, 3], [5, 5]].

    A bare number N becomes the degenerate range [N, N].
    """
    ranges = []
    for token in s.split(','):
        begin, sep, end = token.partition('-')
        if sep:
            ranges.append([int(begin), int(end)])
        else:
            ranges.append([int(begin), int(begin)])
    return ranges
def in_range(x, val_range):
    """True when int(x) lies in the inclusive [start, end] pair *val_range*;
    a falsy range (None/empty) always yields False."""
    if val_range:
        start, end = val_range
        return start <= int(x) <= end
    return False
def train_model(model, train_dataset, val_dataset, max_epochs, pos_weight=0.99):
    """Train *model* with focal loss, AdamW, and a plateau LR schedule.

    *pos_weight* skews the focal-loss class weighting toward the positive
    class.  Returns (model, best_epoch) from do_train().
    """
    info(f"Train count: {len(train_dataset)}, val count: {len(val_dataset)}")
    learning_rate = 0.001
    class_weights = torch.Tensor([1.0 - pos_weight, pos_weight]).to(device)

    def build_optimizer(m):
        return optim.AdamW(
            m.parameters(),
            lr=learning_rate,
            weight_decay=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

    def build_scheduler(opt):
        # 'max' mode: we step on validation accuracy, not loss.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode='max',
            factor=0.1,
            patience=3,
            min_lr=1e-6,
        )

    return do_train({
        'model': model,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'criterion': FocalLoss(0.25, weight=class_weights),
        'learning_rate': learning_rate,
        'optimizer': build_optimizer,
        'scheduler': build_scheduler,
        'batch_size': 32,
        'max_epochs': max_epochs,
        'num_workers': concurrency,
    })
def motion_blur_indicator(image):
    """Estimate directional (motion) blur for a grayscale PIL image.

    Compares Sobel gradient variance along x vs y on a histogram-equalized
    copy: isotropic content scores near 1, strongly one-directional blur
    scores near 0.  The test is repeated on a 45-degree rotation so
    diagonal blur is also caught, and the two ratios are multiplied.
    """
    def f(img):
        equalized = cv2.equalizeHist(np.array(img))
        sobelx = cv2.Sobel(equalized, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(equalized, cv2.CV_64F, 0, 1, ksize=3)
        var_x = sobelx.var()
        var_y = sobely.var()
        # NOTE(review): a perfectly flat image yields 0/0 -> nan here;
        # confirm downstream callers tolerate a nan score.
        ratio = min(var_x, var_y) / max(var_x, var_y)
        return ratio
    return f(image) * f(image.rotate(45))
def calc_clarity(img):
    """Sharpness score: Laplacian variance of the blurred, equalized
    grayscale image, down-weighted by the motion-blur indicator."""
    gray = img.convert('L')
    smoothed = gray.filter(ImageFilter.GaussianBlur(radius=1.5))
    equalized = cv2.equalizeHist(np.array(smoothed))
    laplacian_var = cv2.Laplacian(equalized, cv2.CV_64F).var()
    return laplacian_var * motion_blur_indicator(gray)
def load_codes(codes_file):
    """Read one code per line from *codes_file*, stripping surrounding
    whitespace and dropping blank lines."""
    with open(codes_file, 'r') as fh:
        return [line.strip() for line in fh if line.strip()]
def load_cell_labels(dataset_dir):
    """Collect (jpg_path, label) pairs for cell crops under *dataset_dir*.

    Walks dataset_dir/pos/** and dataset_dir/neg/** for cell_*.jpg files;
    each file's directory must hold a metadata.json containing a 'code'
    key, otherwise the file is skipped.  Returns (pos_labels, neg_labels)
    with labels 1 and 0 respectively, in shuffled order.

    Raises:
        Exception: if the directory is missing or either class is empty.
    """
    if not os.path.exists(dataset_dir):
        raise Exception(f"Dataset directory {dataset_dir} does not exist")
    candidates = []
    for label in ["pos", "neg"]:
        for root, _dirs, files in os.walk(os.path.join(dataset_dir, label)):
            candidates.extend(
                (label, os.path.join(root, name))
                for name in files
                if name.startswith('cell_') and name.endswith('.jpg')
            )
    random.shuffle(candidates)
    pos_labels = []
    neg_labels = []
    for label, jpg_file in candidates:
        json_file = os.path.join(os.path.dirname(jpg_file), "metadata.json")
        if not os.path.exists(json_file):
            continue
        with open(json_file, 'r') as fh:
            md = json.load(fh)
        if 'code' not in md:
            continue
        if label == "pos":
            pos_labels.append((jpg_file, 1))
        else:
            neg_labels.append((jpg_file, 0))
    total_pos = len(pos_labels)
    total_neg = len(neg_labels)
    info(f"Total positive: {total_pos}, total negative: {total_neg}")
    if not total_pos:
        raise Exception("No positive labels found")
    if not total_neg:
        raise Exception("No negative labels found")
    return pos_labels, neg_labels
def crop_side_by_side(img, cells, xcoord, ycoord, jitter=False):
    """Cut cell (xcoord, ycoord) from each half of a side-by-side image
    divided into a cells x cells grid, and re-join the two cells.

    With jitter=True, each half gets a small random translation and the
    combined crop a color jitter (training-time augmentation).
    """
    cell_w = img.width // (cells * 2)
    cell_h = img.height // cells
    left_cell = img.crop(
        (xcoord * cell_w, ycoord * cell_h, (xcoord + 1) * cell_w, (ycoord + 1) * cell_h)
    )
    right_cell = img.crop(
        ((xcoord + cells) * cell_w, ycoord * cell_h, (xcoord + cells + 1) * cell_w, (ycoord + 1) * cell_h)
    )
    if jitter:
        shift = torchvision.transforms.RandomAffine(degrees=0, translate=(0.02, 0.02))
        left_cell = shift(left_cell)
        right_cell = shift(right_cell)
    combined = Image.new('RGB', (cell_w * 2, cell_h))
    combined.paste(left_cell, (0, 0))
    combined.paste(right_cell, (cell_w, 0))
    if jitter:
        combined = torchvision.transforms.ColorJitter(brightness=0.2, saturation=0.2, hue=0.5)(combined)
    return combined
def average(values):
    """Arithmetic mean of *values*; raises ZeroDivisionError when empty."""
    total = sum(values)
    return total / len(values)
def random_sample(values, count):
    """Return *values* unchanged when it has at most *count* items,
    otherwise *count* items drawn at random without replacement."""
    if len(values) > count:
        return random.sample(values, count)
    return values
class ClarityPredictor(object):
    """Lazy wrapper around the clarity-classifier checkpoint.

    The model file is loaded on first call.  Calling returns True/False for
    the predicted class, or None when the checkpoint file does not exist.
    """
    def __init__(self, model_path=clarity_model):
        # Fix: the constructor ignored its model_path argument and always
        # stored the module default, making custom checkpoint paths a no-op.
        self.model_path = model_path
        self.model = None

    def __call__(self, img):
        if not self.model:
            if not os.path.exists(self.model_path):
                return None
            self.model, self.transforms = load_model(self.model_path)
        tensor = self.transforms(img).to(device).unsqueeze(0)
        with torch.no_grad():
            output = self.model(tensor)
        return output.argmax(dim=1).detach().cpu().item() == 1
# Shared module-level predictor; the underlying model loads lazily on first use.
clarity_predictor = ClarityPredictor()
class BaseMethod(object):
    """Base class for training methods: parallel scan loading, validation,
    and clarity filtering.  Subclasses override the hooks and train()."""

    def __init__(self, datadir):
        self.datadir = datadir

    def preprocess(self, scans):
        """Hook: subclasses may preprocess scans before training. No-op here."""
        pass

    def load_scans(self, datadir, sample_rate=1.0):
        """Load every scan under datadir/scans in a worker pool.

        Returns (pos_scans, neg_scans), each sorted by numeric scan_id.
        sample_rate < 1.0 loads a random subset of that fraction.
        """
        self.datadir = datadir
        counts = defaultdict(int)
        pos_scans = []
        neg_scans = []
        all_scan_ids = os.listdir(os.path.join(datadir, "scans"))
        if sample_rate < 1.0:
            all_scan_ids = random.sample(all_scan_ids, int(len(all_scan_ids) * sample_rate))
        # Fix: use the pool as a context manager so worker processes are
        # closed and joined (the pool was previously never released).
        with Pool(concurrency) as pool:
            for x in tqdm(pool.imap(self.load_one_scan, all_scan_ids), total=len(all_scan_ids), desc="Loading dataset"):
                # Fix: the stats key read x.get('lables') — a typo that was
                # always None. Count the actual labels of loaded scans
                # (as a tuple, since lists cannot be dict keys).
                labels_key = tuple(x['data'].get('labels', [])) if x.get('ok') else None
                counts[(x.get('ok'), x.get('error'), labels_key)] += 1
                if x.get('ok'):
                    labels = x['data'].get('labels', [])
                    if 'pos' in labels:
                        pos_scans.append(x['data'])
                    elif 'neg' in labels:
                        neg_scans.append(x['data'])
        info(f"Counts: {counts}")
        return sorted(pos_scans, key=lambda x: int(x['scan_id'])), sorted(neg_scans, key=lambda x: int(x['scan_id']))

    def pre_load_one_scan(self, scan_id):
        """Hook: return False to skip a scan before any file I/O. Default: keep."""
        return True

    def post_load_one_scan(self, scan_id):
        """Hook invoked after a scan passes all checks.

        Fix: this was called by load_one_scan but never defined on the base
        class, so any subclass that did not provide it crashed with
        AttributeError.  Default is a no-op.
        """
        pass

    def load_one_scan(self, scan_id):
        """Validate one scan directory; returns {'ok': True, 'data': ...}
        or {'ok': False, 'error': ...}.

        Side effect: on first sight of a scan, computes the frame/std QR
        clarity ratio and caches it as 'relative_clarity' in metadata.json.
        """
        if not self.pre_load_one_scan(scan_id):
            return {
                "ok": False,
                "error": f"pre_load_one_scan: skip",
            }
        scan_dir = os.path.join(self.datadir, 'scans', scan_id)
        if not os.path.exists(scan_dir):
            return {
                "ok": False,
                "error": f"Scan dir does not exist",
            }
        std_qr_file = os.path.join(scan_dir, "std-qr.jpg")
        if not os.path.exists(std_qr_file):
            return {
                "ok": False,
                "error": f"Std QR file does not exist",
            }
        mdfile = os.path.join(scan_dir, "metadata.json")
        if not os.path.exists(mdfile):
            return {
                "ok": False,
                "error": f"Metadata file does not exist",
            }
        with open(mdfile, 'r') as f:
            try:
                md = json.load(f)
            except Exception as e:
                return {
                    "ok": False,
                    "error": f"Error loading metadata file: {e}",
                }
        if not md.get('code'):
            return {
                "ok": False,
                "error": f"Code not found in metadata file",
            }
        if not md.get('labels'):
            return {
                "ok": False,
                "error": f"Labels not found in metadata file",
            }
        if not md.get('relative_clarity'):
            # First encounter: compute clarity of both QR crops, cache ratio.
            std_qr_img = Image.open(std_qr_file)
            std_qr_clarity = calc_clarity(std_qr_img)
            if not std_qr_clarity:
                return {
                    "ok": False,
                    "error": f"Std QR clarity invalid",
                }
            frame_qr_file = os.path.join(scan_dir, "frame-qr.jpg")
            if not os.path.exists(frame_qr_file):
                return {
                    "ok": False,
                    "error": f"Frame QR file does not exist",
                }
            frame_qr_img = Image.open(frame_qr_file)
            frame_qr_clarity = calc_clarity(frame_qr_img)
            relative_clarity = frame_qr_clarity / std_qr_clarity
            md['relative_clarity'] = relative_clarity
            with open(mdfile, 'w') as f:
                json.dump(md, f, indent=2)
        if md['relative_clarity'] < 0.5:
            return {
                "ok": False,
                "error": f"Relative clarity too low",
            }
        self.post_load_one_scan(scan_id)
        return {
            "ok": True,
            "data": {
                "scan_id": scan_id,
                "std_qr_file": std_qr_file,
                "code": md['code'],
                "labels": md['labels'],
            },
        }

    def train(self, dataset, epochs):
        """Subclasses must implement the actual training."""
        raise Exception("Not implemented")
def load_method(method_name, datadir):
    """Import methods.<method_name> and instantiate its Method class."""
    module = importlib.import_module('methods.' + method_name)
    method_cls = module.Method
    return method_cls(datadir)
def balance_pos_and_neg(scans):
    """Return a shuffled list containing equal numbers of pos and neg scans
    (each class randomly downsampled to the smaller class's size)."""
    pos_scans = [s for s in scans if 'pos' in s['labels']]
    neg_scans = [s for s in scans if 'neg' in s['labels']]
    random.shuffle(pos_scans)
    random.shuffle(neg_scans)
    keep = min(len(pos_scans), len(neg_scans))
    pos = pos_scans[:keep]
    neg = neg_scans[:keep]
    info(f'balanced from pos: {len(pos_scans)} to {len(pos)}')
    info(f'balanced from neg: {len(neg_scans)} to {len(neg)}')
    balanced = pos + neg
    random.shuffle(balanced)
    return balanced