#!/usr/bin/env python3
"""Fine-tune an ImageNet-pretrained ResNet-50 as a binary (2-class) classifier.

Images are read from ``data/roi/train`` and labelled by a text file passed via
``--labels-file`` containing one ``<image-name> <label>`` pair per line. The
labelled set is split 50/50 into train/validation subsets; after training the
final weights are written to ``data/roi/models/``.
"""
import argparse
import os
import random
from collections import defaultdict
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from kornia.losses.focal import FocalLoss
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision import models, transforms
from torchvision.models import ResNet50_Weights

# Plateau-based early stopping (see _should_stop_early); disabled by default.
ENABLE_EARLY_STOPPING = False

# Channel mean/std used by both the train and validation Normalize transforms.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class SideBySideDataset(Dataset):
    """Image dataset described by a whitespace-separated labels file.

    Args:
        img_dir: Directory that the image names in the labels file are
            relative to.
        labels_file: Path to a text file with one ``name label`` pair per
            line; blank lines are skipped.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, img_dir, labels_file, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = self._load_labels(labels_file)

    def _load_labels(self, labels_file):
        """Return ``[image_name, label_str]`` pairs parsed from *labels_file*.

        Raises:
            ValueError: if a non-blank line does not split into exactly two
                fields.
        """
        ret = []
        with open(labels_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                img_name, label = line.strip().split()
                ret.append([img_name, label])
        return ret

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for *idx*, with *label* converted to ``int``."""
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager closes the underlying file handle right away;
        # convert() loads the pixels into a new image first.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, int(label)


def parse_args():
    """Parse command-line arguments (``--labels-file`` is required)."""
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('--labels-file', required=True, type=str, help='Path to the labels file')
    return parser.parse_args()


def _percent(numerator, denominator):
    """Return ``100 * numerator / denominator``, or NaN when *denominator* is 0.

    Guards against a validation split that happens to contain no samples of a
    class (or no samples at all), which would otherwise raise ZeroDivisionError.
    """
    return 100 * numerator / denominator if denominator else float('nan')


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader* and return the mean batch loss.

    The criterion output is summed per batch before backprop, so the returned
    value is the average of the per-batch summed losses.
    """
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        # Inputs must live on the same device as the model's parameters.
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss_sum = criterion(outputs, labels).sum()
        loss_sum.backward()
        optimizer.step()
        # .item() yields a plain float; accumulating the tensor itself would
        # chain autograd nodes across the whole epoch.
        running_loss += loss_sum.item()
    return running_loss / max(len(loader), 1)


def _evaluate(model, loader, device):
    """Return ``(pos_correct, pos_total, neg_correct, neg_total)`` over *loader*.

    Label 1 is counted as positive and label 0 as negative; the prediction is
    the index of the highest model output per sample.
    """
    model.eval()
    pos_correct = pos_total = neg_correct = neg_total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct = predicted == labels
            is_pos = labels == 1
            is_neg = labels == 0
            pos_correct += (correct & is_pos).sum().item()
            pos_total += is_pos.sum().item()
            neg_correct += (correct & is_neg).sum().item()
            neg_total += is_neg.sum().item()
    return pos_correct, pos_total, neg_correct, neg_total


def _should_stop_early(history, window=5, threshold=1.0):
    """Return True when the recent validation accuracies have plateaued.

    A plateau means the population variance of the last *window* values is
    below *threshold* AND the last four values are not strictly increasing.
    Expects ``len(history) >= max(window, 4)``.
    """
    recent = history[-window:]
    avg = sum(recent) / window
    variance = sum((x - avg) ** 2 for x in recent) / window
    print(f"variance={variance:.4f}")
    rising = history[-1] > history[-2] > history[-3] > history[-4]
    if variance < threshold and not rising:
        print(f"Early stopping condition met: {variance:.4f}")
        return True
    return False


def main():
    """Train the classifier, report validation accuracy, and save the weights."""
    args = parse_args()

    # Data preprocessing. The augmentation steps are kept commented out.
    transform_train = transforms.Compose([
        # transforms.RandomResizedCrop((128, 64)),  # random crop
        # transforms.RandomHorizontalFlip(),        # random horizontal flip
        # transforms.Resize((256, 128)),            # resize
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    transform_val = transforms.Compose([
        # transforms.Resize((256, 128)),  # resize
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Load the dataset twice so each split gets its own transform: a
    # random_split over a single dataset would make the validation subset
    # share transform_train.
    img_dir = os.path.abspath("data/roi/train")
    train_base = SideBySideDataset(
        img_dir=img_dir,
        labels_file=args.labels_file,
        transform=transform_train,
    )
    val_base = SideBySideDataset(
        img_dir=img_dir,
        labels_file=args.labels_file,
        transform=transform_val,
    )

    train_ratio = 0.5
    train_size = int(train_ratio * len(train_base))
    val_size = len(train_base) - train_size
    # Split indices once, then view each base dataset through its half.
    train_split, val_split = random_split(range(len(train_base)), [train_size, val_size])
    train_dataset = Subset(train_base, train_split.indices)
    val_dataset = Subset(val_base, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # Validation counts do not depend on order, so no shuffling is needed.
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    # Replace the final fully connected layer with a 2-output (binary) head.
    model.fc = nn.Linear(model.fc.in_features, 2)

    # Move the model to the GPU if one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # criterion = nn.CrossEntropyLoss()
    pos_weight = 0.9
    # Per-class weights: class 0 -> 1 - pos_weight, class 1 -> pos_weight.
    # Created on `device` so it matches the logits when training on GPU.
    class_weights = torch.tensor([1 - pos_weight, pos_weight], device=device)
    # NOTE(review): the loss is .sum()-ed in _train_one_epoch, which assumes
    # FocalLoss returns per-sample values (reduction='none'); confirm for the
    # installed kornia version.
    criterion = FocalLoss(0.25, weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.09, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    num_epochs = 15
    last_accu = 0
    prev_accu = []
    pos_correct = pos_total = neg_correct = neg_total = 0
    for epoch in range(num_epochs):
        # Training phase.
        print(f"Start training epoch {epoch+1}/{num_epochs}")
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        scheduler.step()

        # Validation phase.
        pos_correct, pos_total, neg_correct, neg_total = _evaluate(model, val_loader, device)
        last_accu = _percent(pos_correct + neg_correct, pos_total + neg_total)
        prev_accu.append(last_accu)

        if ENABLE_EARLY_STOPPING and epoch > 5 and _should_stop_early(prev_accu):
            break

    print(f'Pos Accu: {_percent(pos_correct, pos_total):.2f}%')
    print(f'Neg Accu: {_percent(neg_correct, neg_total):.2f}%')

    dt = datetime.now().strftime("%Y%m%d_%H%M%S")
    # torch.save does not create missing parent directories.
    os.makedirs('data/roi/models', exist_ok=True)
    torch.save(model.state_dict(), f'data/roi/models/resnet_{dt}_{last_accu:.2f}.pth')


if __name__ == "__main__":
    main()