#!/usr/bin/env python3
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision
from PIL import Image, ImageFilter
import os
from datetime import datetime
from collections import defaultdict
import argparse
from kornia.losses.focal import FocalLoss
from kornia.augmentation import ColorJiggle
import cv2
import numpy as np
import random
import re
import shutil
import json
import tempfile
import time
import base64
import subprocess
from loguru import logger
from multiprocessing import Pool, cpu_count, set_start_method
from tqdm import tqdm
import importlib

concurrency = max(1, cpu_count() - 2)


def info(msg):
    logger.info(msg)


def debug(msg):
    logger.debug(msg)


cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
default_model = 'best_model_ep82_pos98.94_neg96.13_20250720_222102.pt'
clarity_model = 'models/clarity-ep15-pos88.14-neg92.23-20250518_164155.pt'
torch.set_float32_matmul_precision('high')


def batch_generator(labels, batch_size):
    for i in range(0, len(labels), batch_size):
        yield labels[i:i + batch_size]


def make_side_by_side_img(left, right):
    min_width = min(left.width, right.width)
    min_height = min(left.height, right.height)
    left = left.resize((min_width, min_height))
    right = right.resize((min_width, min_height))
    ret = Image.new('RGB', (min_width * 2, min_height))
    ret.paste(left, (0, 0))
    ret.paste(right, (min_width, 0))
    return ret


def warp_with_margin_ratio(orig, edge, corners, margin_ratio):
    # Map the QR corners to an edge x edge square, inset by margin_ratio on all sides
    src_points = np.float32(corners)
    dst_points = np.float32([
        [edge * margin_ratio, edge * margin_ratio],
        [edge * (1 - margin_ratio), edge * margin_ratio],
        [edge * (1 - margin_ratio), edge * (1 - margin_ratio)],
        [edge * margin_ratio, edge * (1 - margin_ratio)],
    ])
    M = cv2.getPerspectiveTransform(src_points, dst_points)
    warped = cv2.warpPerspective(np.array(orig), M, (edge, edge))
    return Image.fromarray(warped)


def find_min_margin_ratio(img, corners):
    # Smallest relative distance from any QR corner to the image border
    min_margin = None
    for i in range(4):
        x, y = corners[i][0], corners[i][1]
        this_min = min(
            x / img.width,
            (img.width - x) / img.width,
            y / img.height,
            (img.height - y) / img.height,
        )
        if min_margin is None or this_min < min_margin:
            min_margin = this_min
    return min_margin


def make_side_by_side_img_with_margins(frame_img, std_img):
    std_qr, std_corners = find_qr(std_img)
    frame_qr, frame_corners = find_qr(frame_img)
    if std_corners is None or frame_corners is None:
        return None
    # Use a reasonable size: the larger dimension, or at least 512
    edge_length = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512)
    std_margin_ratio = find_min_margin_ratio(std_img, std_corners)
    frame_margin_ratio = find_min_margin_ratio(frame_img, frame_corners)
    # Use the minimum margin ratio so both QR areas are captured correctly
    margin_ratio = min(std_margin_ratio, frame_margin_ratio, 0.05)
    std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
    frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)
    # Horizontal layout: frame on the left, std on the right;
    # each warped image is edge_length x edge_length
    ret = Image.new('RGB', (edge_length * 2, edge_length))
    ret.paste(frame_warped, (0, 0))
    ret.paste(std_warped, (edge_length, 0))
    return ret


def get_top_margin(img, margin_ratio):
    # Crop the top strip of height img.height * margin_ratio by pasting
    # the image onto a shorter canvas
    ret = Image.new('RGB', (img.width, int(img.height * margin_ratio)))
    ret.paste(img, (0, 0))
    return ret
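
# Illustrative usage sketch (not part of the original pipeline): combine the
# helpers above to build a registered side-by-side pair from two files on
# disk. The file paths are hypothetical placeholders; when QR detection fails
# in either image, make_side_by_side_img_with_margins() returns None, so this
# sketch falls back to the naive horizontal concatenation.
def demo_make_pair(frame_path='frame.jpg', std_path='std.jpg'):
    frame_img = Image.open(frame_path).convert('RGB')
    std_img = Image.open(std_path).convert('RGB')
    pair = make_side_by_side_img_with_margins(frame_img, std_img)
    if pair is None:
        pair = make_side_by_side_img(frame_img, std_img)
    return pair
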
def make_top_margin_stacked_img(frame_img, std_img):
    std_qr, std_corners = find_qr(std_img)
    frame_qr, frame_corners = find_qr(frame_img)
    if std_corners is None or frame_corners is None:
        return None
    edge_length = min(std_img.width, std_img.height)
    margin_ratio = find_min_margin_ratio(std_img, std_corners)
    std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
    frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)
    std_top_margin = get_top_margin(std_warped, margin_ratio)
    frame_top_margin = get_top_margin(frame_warped, margin_ratio)
    outheight = int(edge_length * margin_ratio * 2)
    ret = Image.new('RGB', (edge_length, outheight))
    ret.paste(std_top_margin, (0, 0))
    ret.paste(frame_top_margin, (0, outheight // 2))
    return ret


def make_stripe_img(left, right, nstripes):
    min_width = min(left.width, right.width)
    min_height = min(left.height, right.height)
    ret = Image.new('RGB', (min_width * 2, min_height))
    left_stripe_width = left.width // nstripes
    right_stripe_width = right.width // nstripes
    stripe_width = min_width // nstripes
    for i in range(nstripes):
        left_stripe = left.crop((i * left_stripe_width, 0, (i + 1) * left_stripe_width, left.height))
        left_stripe = left_stripe.resize((stripe_width, min_height))
        right_stripe = right.crop((i * right_stripe_width, 0, (i + 1) * right_stripe_width, right.height))
        right_stripe = right_stripe.resize((stripe_width, min_height))
        ret.paste(left_stripe, (i * stripe_width * 2, 0))
        ret.paste(right_stripe, (i * stripe_width * 2 + stripe_width, 0))
    return ret


def predict_multi(model, transforms, images, ncells=1):
    results_per_img = ncells * ncells
    ret = []
    with torch.no_grad():
        tensors = []
        for image in images:
            for xcoord in range(ncells):
                for ycoord in range(ncells):
                    img = crop_side_by_side(image, ncells, xcoord, ycoord)
                    img_tensor = transforms(img).to(device)
                    tensors.append(img_tensor)
        output = model(torch.stack(tensors, dim=0))
        sub_probs = torch.nn.functional.softmax(output, dim=1).cpu().numpy().tolist()
        for i in range(len(images)):
            probs = sub_probs[i * results_per_img:(i + 1) * results_per_img]
            neg_sum = sum(x[0] for x in probs)
            pos_sum = sum(x[1] for x in probs)
            predicted_class = 1 if pos_sum > neg_sum else 0
            ret.append((predicted_class, probs))
    return ret


def predict(model, transforms, image, ncells=1):
    r = predict_multi(model, transforms, [image], ncells)
    return r[0]


qr_detector = cv2.wechat_qrcode_WeChatQRCode()


def find_qr(img, scale=1.0):
    # Convert PIL Image to OpenCV format
    orig_size = img.size
    if scale < 1.0:
        img_w, img_h = img.size
        new_size = (int(img_w * scale), int(img_h * scale))
        resized_img = img.resize(new_size)
    else:
        new_size = orig_size
        resized_img = img
    img_cv = cv2.cvtColor(np.array(resized_img), cv2.COLOR_RGB2BGR)
    qr, corners = qr_detector.detectAndDecode(img_cv)
    if not qr:
        if scale > 0.05:
            return find_qr(img, scale * 2 / 3)
        return None, None
    corners = np.array(corners[0], dtype=np.float32)
    corners /= scale
    return qr[0], corners


def extract_qr(img):
    qr, corners = find_qr(img)
    if not qr:
        raise Exception('No QR code found')
    corners = np.array(corners, dtype=np.float32)
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    # Define target rectangle corners (clockwise from top-left)
    min_x = min(corners[:, 0])
    max_x = max(corners[:, 0])
    min_y = min(corners[:, 1])
    max_y = max(corners[:, 1])
    width = max_x - min_x
    height = max_y - min_y
    width = height = int(min(width, height))
    dst_corners = np.array([
        [0, 0],
        [width, 0],
        [width, height],
        [0, height],
    ], dtype=np.float32)
    matrix = cv2.getPerspectiveTransform(corners, dst_corners)
    warped = cv2.warpPerspective(img_cv, matrix, (width, height))
    r = Image.fromarray(cv2.cvtColor(warped, cv2.COLOR_BGR2RGB))
    return qr, r
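
# Minimal usage sketch for the QR helpers above, assuming a hypothetical
# input file 'scan.jpg'. find_qr() retries detection at progressively smaller
# scales (scale * 2/3, down to ~0.05) and divides the detected corners by the
# scale so they land back in the original image's pixel space.
def demo_find_qr(path='scan.jpg'):
    img = Image.open(path).convert('RGB')
    code, corners = find_qr(img)
    if code is None:
        return None
    # corners is a 4x2 float32 array in original-image coordinates
    return code, find_min_margin_ratio(img, corners)
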
def load_model(model_path):
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model = checkpoint['model']  # Load the complete model structure
    model.load_state_dict(checkpoint['model_state_dict'])  # Load the weights
    model.eval()
    model = model.to(device)
    transforms = checkpoint['transforms']
    return model, transforms


def save_model(model, transforms, save_path, metadata=None):
    checkpoint = {
        'model': model,  # Save the complete model structure
        'model_state_dict': model.state_dict(),
        'transforms': transforms,
    }
    if metadata:
        checkpoint['metadata'] = metadata
    torch.save(checkpoint, save_path)


def make_model(model_name):
    model_makers = {
        'resnet': make_resnet,
        'resnet18': make_resnet18,
        'resnet101': make_resnet101,
        'regnet': make_regnet,
        'convnext': make_convnext,
        'efficientnet': make_efficientnet,
        'densenet': make_densenet,
        'mobilenet': make_mobilenet,
    }
    return model_makers[model_name]()


def make_mobilenet():
    weights = torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    model = torchvision.models.mobilenet_v3_small(weights=weights)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(576, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_densenet():
    weights = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
    model = torchvision.models.densenet121(weights=weights)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1024, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_efficientnet():
    weights = torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1
    model = torchvision.models.efficientnet_b3(weights=weights)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1536, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_convnext():
    weights = torchvision.models.ConvNeXt_Base_Weights.IMAGENET1K_V1
    model = torchvision.models.convnext_base(weights=weights)
    model.classifier = nn.Sequential(
        nn.Flatten(1),  # Flatten first (from [B, 1024, 1, 1] to [B, 1024])
        nn.LayerNorm(1024, eps=1e-6),
        nn.Dropout(p=0.4),
        nn.Linear(1024, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_resnet101():
    weights = torchvision.models.ResNet101_Weights.IMAGENET1K_V1
    model = torchvision.models.resnet101(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(2048, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_resnet18():
    weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    model = torchvision.models.resnet18(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_resnet():
    weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
    model = torchvision.models.resnet50(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(2048, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()
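
# Sketch of the factory/checkpoint round trip, assuming 'demo.pt' is a
# hypothetical scratch path. make_model() returns a (model, transforms) pair,
# and save_model()/load_model() persist and restore both together, so
# inference code never rebuilds the preprocessing pipeline by hand.
def demo_checkpoint_roundtrip(tmp_path='demo.pt'):
    model, transforms = make_model('resnet18')
    save_model(model, transforms, tmp_path, metadata={'note': 'demo'})
    return load_model(tmp_path)
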
def make_regnet():
    weights = torchvision.models.RegNet_Y_3_2GF_Weights.IMAGENET1K_V1
    model = torchvision.models.regnet_y_3_2gf(weights=weights)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(1512, 512),
        nn.GELU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 128),
        nn.GELU(),
        nn.Dropout(p=0.1),
        nn.Linear(128, 2),
    )
    return model, make_generic_transforms()


def make_generic_transforms():
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


class ScanDataset(Dataset):
    def __init__(self, scans, transforms):
        self.transforms = transforms
        self.scans = scans

    def __len__(self):
        return len(self.scans)

    def __getitem__(self, idx):
        scan = self.scans[idx]
        scan_id = scan['scan_id']
        gridcrop2_dir = os.path.join('data', 'gridcrop2', scan_id)
        sbs_file = os.path.join(gridcrop2_dir, 'sbs.jpg')
        sbs_img = Image.open(sbs_file).convert('RGB')
        return self.transforms(sbs_img), 1 if 'pos' in scan['labels'] else 0

    def stats(self):
        return {
            'pos': sum(1 for scan in self.scans if 'pos' in scan['labels']),
            'neg': sum(1 for scan in self.scans if 'neg' in scan['labels']),
        }


def do_train(cfg):
    train_dataset = cfg['train_dataset']
    val_dataset = cfg['val_dataset']
    batch_size = cfg['batch_size']
    num_workers = cfg['num_workers']
    prefetch_factor = 4
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
    )
    model = cfg['model']
    if cuda_available:
        # Set memory fraction for better GPU utilization
        torch.cuda.set_per_process_memory_fraction(0.8)
    # Compile model with optimized settings
    model = torch.compile(model).to(device)
    criterion = cfg['criterion']
    optimizer = cfg['optimizer'](model)
    scheduler = cfg['scheduler'](optimizer)
    max_epochs = cfg['max_epochs']
    # Initialize variables for early stopping
    best_epoch = None
    # Create GradScaler once at the start of training
    scaler = torch.amp.GradScaler('cuda')
    for epoch in range(max_epochs):
        info(f"Starting epoch {epoch+1}/{max_epochs}")
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Train {epoch+1}/{max_epochs}", unit_scale=batch_size):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(images)
                loss = criterion(outputs, labels).mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.detach().item() * len(images)
        # Mean per-sample loss, scaled by the inverse learning rate for logging
        running_loss = running_loss / len(train_loader.dataset) / cfg['learning_rate']
        info(f'Epoch [{epoch+1}/{max_epochs}], Loss: {running_loss:.5f}')
        # Validation phase
        model.eval()
        pos_correct = 0
        pos_total = 0
        neg_correct = 0
        neg_total = 0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Validating {epoch+1}/{max_epochs}", unit_scale=batch_size):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                pos_correct += ((predicted == labels) & (labels == 1)).sum().item()
                pos_total += (labels == 1).sum().item()
                neg_correct += ((predicted == labels) & (labels == 0)).sum().item()
                neg_total += (labels == 0).sum().item()
        pos_accu = pos_correct / pos_total if pos_total else 0
        neg_accu = neg_correct / neg_total if neg_total else 0
        total_accu = (pos_correct + neg_correct) / (pos_total + neg_total)
        info(f'Pos Accu: {pos_accu:.2%} ({pos_correct}/{pos_total})')
        info(f'Neg Accu: {neg_accu:.2%} ({neg_correct}/{neg_total})')
        info(f'Total Accu: {total_accu:.2%} ({pos_correct + neg_correct}/{pos_total + neg_total})')
        if not best_epoch or total_accu > best_epoch['total_accu']:
            best_epoch = {
                'epoch': epoch,
                'pos_accu': pos_accu,
                'neg_accu': neg_accu,
                'total_accu': total_accu,
                # Clone tensors: state_dict().copy() is a shallow dict copy whose
                # tensors stay aliased to the live weights and keep training
                'model_state_dict': {k: v.detach().clone() for k, v in model.state_dict().items()},
            }
            info(f'New best model found with total accuracy: {best_epoch["total_accu"]:.2%}')
        scheduler.step(total_accu)
    # Load the best model weights
    if best_epoch is not None:
        model.load_state_dict(best_epoch['model_state_dict'])
        info(f'Loaded best model with total accuracy: {best_epoch["total_accu"]:.2%}')
    return model, best_epoch
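
# Hedged sketch of driving do_train() via train_model(): balance the classes,
# split, and wrap each split in a ScanDataset. The 90/10 split and the
# 'resnet18' backbone are illustrative assumptions, and do_train() as written
# requires a CUDA device plus the data/gridcrop2/<scan_id>/sbs.jpg files that
# ScanDataset reads.
def demo_train(scans, max_epochs=20):
    scans = balance_pos_and_neg(scans)
    split = int(len(scans) * 0.9)
    model, transforms = make_model('resnet18')
    train_ds = ScanDataset(scans[:split], transforms)
    val_ds = ScanDataset(scans[split:], transforms)
    model, best_epoch = train_model(model, train_ds, val_ds, max_epochs)
    return model, transforms, best_epoch
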
def verify_frame(model, transforms, frame_img, orig_img):
    side_by_side_img = make_side_by_side_img_with_margins(frame_img, orig_img)
    if side_by_side_img is None:
        raise Exception("Failed to create side-by-side image with margins")
    side_by_side_img = side_by_side_img.convert('RGB')
    with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
        # Save and reload through a temporary JPEG before running prediction
        side_by_side_img.save(f.name)
        return predict(model, transforms, Image.open(f.name).convert('RGB'))


def parse_ranges(s):
    ret = []
    for tr in s.split(','):
        if '-' in tr:
            begin, end = tr.split('-')
            ret.append([int(begin), int(end)])
        else:
            ret.append([int(tr), int(tr)])
    return ret


def in_range(x, val_range):
    if not val_range:
        return False
    start, end = val_range
    return start <= int(x) <= end


def train_model(model, train_dataset, val_dataset, max_epochs, pos_weight=0.99):
    info(f"Train count: {len(train_dataset)}, val count: {len(val_dataset)}")
    learning_rate = 0.001
    cfg = {
        'model': model,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'criterion': FocalLoss(0.25, weight=torch.Tensor([1.0 - pos_weight, pos_weight]).to(device)),
        'learning_rate': learning_rate,
        'optimizer': lambda model: optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,
        ),
        'batch_size': 32,
        'max_epochs': max_epochs,
        'scheduler': lambda optimizer: torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.1, patience=3, min_lr=1e-6,
        ),
        'num_workers': concurrency,
    }
    return do_train(cfg)


def motion_blur_indicator(image):
    # Motion blur smears gradients along one direction, so the ratio of the
    # weaker to the stronger Sobel variance drops below 1; sampling at 0 and
    # 45 degrees covers both axis-aligned and diagonal blur
    def f(img):
        equalized = cv2.equalizeHist(np.array(img))
        sobelx = cv2.Sobel(equalized, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(equalized, cv2.CV_64F, 0, 1, ksize=3)
        var_x = sobelx.var()
        var_y = sobely.var()
        return min(var_x, var_y) / max(var_x, var_y)
    return f(image) * f(image.rotate(45))


def calc_clarity(img):
    gray = img.convert('L')
    blurred = gray.filter(ImageFilter.GaussianBlur(radius=1.5))
    equalized = cv2.equalizeHist(np.array(blurred))
    lap = cv2.Laplacian(equalized, cv2.CV_64F)
    return lap.var() * motion_blur_indicator(gray)


def load_codes(codes_file):
    with open(codes_file, 'r') as f:
        return [line.strip() for line in f.readlines() if line.strip()]


def load_cell_labels(dataset_dir):
    if not os.path.exists(dataset_dir):
        raise Exception(f"Dataset directory {dataset_dir} does not exist")
    neg_labels = []
    pos_labels = []
    all_files = []
    for label in ["pos", "neg"]:
        for r, ds, fs in os.walk(os.path.join(dataset_dir, label)):
            for f in fs:
                if f.startswith('cell_') and f.endswith('.jpg'):
                    fp = os.path.join(r, f)
                    all_files.append((label, fp))
    random.shuffle(all_files)
    for label, fp in all_files:
        jpg_file = fp
        json_file = os.path.join(os.path.dirname(fp), "metadata.json")
        if not os.path.exists(json_file):
            continue
        with open(json_file, 'r') as f:
            md = json.load(f)
        if 'code' not in md:
            continue
        code = md['code']
        if label == "pos":
            pos_labels.append((jpg_file, 1))
        else:
            neg_labels.append((jpg_file, 0))
    total_pos = len(pos_labels)
    total_neg = len(neg_labels)
    info(f"Total positive: {total_pos}, total negative: {total_neg}")
    if not total_pos:
        raise Exception("No positive labels found")
    if not total_neg:
        raise Exception("No negative labels found")
    return pos_labels, neg_labels
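
# Illustrative sketch: the relative-clarity ratio that
# BaseMethod.load_one_scan() caches into metadata.json, extracted as a
# standalone helper. The two paths are hypothetical; a ratio below 0.5 is
# what the loader rejects as too blurry relative to the reference scan.
def demo_relative_clarity(frame_qr_path='frame-qr.jpg', std_qr_path='std-qr.jpg'):
    std_clarity = calc_clarity(Image.open(std_qr_path))
    frame_clarity = calc_clarity(Image.open(frame_qr_path))
    return frame_clarity / std_clarity
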
def crop_side_by_side(img, cells, xcoord, ycoord, jitter=False):
    width = img.width // (cells * 2)
    height = img.height // cells
    left = img.crop((xcoord * width, ycoord * height, (xcoord + 1) * width, (ycoord + 1) * height))
    right = img.crop(((xcoord + cells) * width, ycoord * height, (xcoord + cells + 1) * width, (ycoord + 1) * height))
    if jitter:
        movement = 0.02
        left = torchvision.transforms.RandomAffine(degrees=0, translate=(movement, movement))(left)
        right = torchvision.transforms.RandomAffine(degrees=0, translate=(movement, movement))(right)
    ret = Image.new('RGB', (width * 2, height))
    ret.paste(left, (0, 0))
    ret.paste(right, (width, 0))
    if jitter:
        ret = torchvision.transforms.ColorJitter(brightness=0.2, saturation=0.2, hue=0.5)(ret)
    return ret


def average(values):
    return sum(values) / len(values)


def random_sample(values, count):
    if len(values) <= count:
        return values
    return random.sample(values, count)


class ClarityPredictor(object):
    def __init__(self, model_path=clarity_model):
        self.model_path = model_path
        self.model = None

    def __call__(self, img):
        # Lazy-load the model on first use
        if not self.model:
            if not os.path.exists(self.model_path):
                return None
            self.model, self.transforms = load_model(self.model_path)
        tensor = self.transforms(img).to(device).unsqueeze(0)
        with torch.no_grad():
            output = self.model(tensor)
        return output.argmax(dim=1).detach().cpu().item() == 1


clarity_predictor = ClarityPredictor()


class BaseMethod(object):
    def __init__(self, datadir):
        self.datadir = datadir

    def preprocess(self, scans):
        pass

    def load_scans(self, datadir, sample_rate=1.0):
        self.datadir = datadir
        counts = defaultdict(int)
        pos_scans = []
        neg_scans = []
        all_scan_ids = os.listdir(os.path.join(datadir, "scans"))
        if sample_rate < 1.0:
            all_scan_ids = random.sample(all_scan_ids, int(len(all_scan_ids) * sample_rate))
        with Pool(concurrency) as pool:
            for x in tqdm(pool.imap(self.load_one_scan, all_scan_ids), total=len(all_scan_ids), desc="Loading dataset"):
                counts[(x.get('ok'), x.get('error'), x.get('labels'))] += 1
                if x.get('ok'):
                    labels = x['data'].get('labels', [])
                    if 'pos' in labels:
                        pos_scans.append(x['data'])
                    elif 'neg' in labels:
                        neg_scans.append(x['data'])
        info(f"Counts: {counts}")
        return (sorted(pos_scans, key=lambda x: int(x['scan_id'])),
                sorted(neg_scans, key=lambda x: int(x['scan_id'])))

    def pre_load_one_scan(self, scan_id):
        return True

    def post_load_one_scan(self, scan_id):
        # Hook for subclasses; called after a scan passes all checks
        pass

    def load_one_scan(self, scan_id):
        if not self.pre_load_one_scan(scan_id):
            return {"ok": False, "error": "pre_load_one_scan: skip"}
        scan_dir = os.path.join(self.datadir, 'scans', scan_id)
        if not os.path.exists(scan_dir):
            return {"ok": False, "error": "Scan dir does not exist"}
        std_qr_file = os.path.join(scan_dir, "std-qr.jpg")
        if not os.path.exists(std_qr_file):
            return {"ok": False, "error": "Std QR file does not exist"}
        mdfile = os.path.join(scan_dir, "metadata.json")
        if not os.path.exists(mdfile):
            return {"ok": False, "error": "Metadata file does not exist"}
        with open(mdfile, 'r') as f:
            try:
                md = json.load(f)
            except Exception as e:
                return {"ok": False, "error": f"Error loading metadata file: {e}"}
        if not md.get('code'):
            return {"ok": False, "error": "Code not found in metadata file"}
        if not md.get('labels'):
            return {"ok": False, "error": "Labels not found in metadata file"}
        if not md.get('relative_clarity'):
            std_qr_img = Image.open(std_qr_file)
            std_qr_clarity = calc_clarity(std_qr_img)
            if not std_qr_clarity:
                return {"ok": False, "error": "Std QR clarity invalid"}
            frame_qr_file = os.path.join(scan_dir, "frame-qr.jpg")
            if not os.path.exists(frame_qr_file):
                return {"ok": False, "error": "Frame QR file does not exist"}
            frame_qr_img = Image.open(frame_qr_file)
            frame_qr_clarity = calc_clarity(frame_qr_img)
            # Cache the computed ratio back into the metadata file
            md['relative_clarity'] = frame_qr_clarity / std_qr_clarity
            with open(mdfile, 'w') as f:
                json.dump(md, f, indent=2)
        if md['relative_clarity'] < 0.5:
            return {"ok": False, "error": "Relative clarity too low"}
        self.post_load_one_scan(scan_id)
        return {
            "ok": True,
            "data": {
                "scan_id": scan_id,
                "std_qr_file": std_qr_file,
                "code": md['code'],
                "labels": md['labels'],
            },
        }

    def train(self, dataset, epochs):
        raise NotImplementedError("Subclasses must implement train()")


def load_method(method_name, datadir):
    mod = importlib.import_module('methods.' + method_name)
    return mod.Method(datadir)


def balance_pos_and_neg(scans):
    pos_scans = [s for s in scans if 'pos' in s['labels']]
    neg_scans = [s for s in scans if 'neg' in s['labels']]
    random.shuffle(pos_scans)
    random.shuffle(neg_scans)
    min_count = min(len(pos_scans), len(neg_scans))
    pos = pos_scans[:min_count]
    neg = neg_scans[:min_count]
    info(f'balanced from pos: {len(pos_scans)} to {len(pos)}')
    info(f'balanced from neg: {len(neg_scans)} to {len(neg)}')
    ret = pos + neg
    random.shuffle(ret)
    return ret
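
# End-to-end usage sketch, assuming a hypothetical method module
# 'methods/baseline.py' that subclasses BaseMethod and implements train().
# It loads and balances the labeled scans, then hands them to the method.
def demo_pipeline(method_name='baseline', datadir='data', epochs=5):
    method = load_method(method_name, datadir)
    pos_scans, neg_scans = method.load_scans(datadir)
    scans = balance_pos_and_neg(pos_scans + neg_scans)
    return method.train(scans, epochs)
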