themblem/research/roi-train.py
2025-03-26 06:53:58 +08:00

151 lines
5.2 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from PIL import Image
import os
from datetime import datetime
from collections import defaultdict
import random
import argparse
class SideBySideDataset(Dataset):
    """Image-classification dataset described by a plain-text labels file.

    Every non-blank line of the labels file holds an image file name followed
    by an integer class label, separated by whitespace. Image names are
    resolved relative to ``img_dir`` when a sample is fetched.
    """

    def __init__(self, img_dir, labels_file, transform=None):
        """Store the configuration and read the labels file.

        Args:
            img_dir: Directory that the image names in the labels file are
                joined onto.
            labels_file: Path of the text file listing ``<image> <label>``
                pairs.
            transform: Optional callable applied to each RGB PIL image before
                it is returned.

        Raises:
            ValueError: If a line of the labels file is malformed.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = self._load_labels(labels_file)

    def _load_labels(self, labels_file):
        """Parse ``labels_file`` into a list of ``[img_name, label]`` pairs.

        Blank lines are skipped. The label is the last whitespace-separated
        token of the line; everything before it is the image name, so names
        containing spaces are accepted. Labels are kept as strings (they are
        converted in ``__getitem__``) but are validated here so that a bad
        file fails before training starts rather than in the middle of it.

        Raises:
            ValueError: If a line has no label or the label is not an integer.
        """
        entries = []
        with open(labels_file, 'r') as f:
            # Iterate the file object directly instead of materialising every
            # line with readlines().
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                # Split on the LAST whitespace run only, so an image name may
                # itself contain spaces.
                fields = line.rsplit(None, 1)
                if len(fields) != 2:
                    raise ValueError(
                        f"{labels_file}:{lineno}: expected '<image> <label>', got {line!r}"
                    )
                img_name, label = fields
                try:
                    int(label)
                except ValueError:
                    raise ValueError(
                        f"{labels_file}:{lineno}: label {label!r} is not an integer"
                    ) from None
                entries.append([img_name, label])
        return entries

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from ``img_dir``, converted to RGB and passed
        through ``transform`` when one was given; the label is returned as an
        ``int``.
        """
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, int(label)
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads the process command line, which is the
            behaviour existing callers get.

    Returns:
        An ``argparse.Namespace`` whose ``labels_file`` attribute holds the
        value of the required ``--labels-file`` option.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('--labels-file', required=True, type=str, help='Path to the labels file')
    return parser.parse_args(argv)
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def _accuracy_plateaued(history, window=5, max_variance=1.0):
    """Return ``(variance, plateaued)`` for the most recent accuracies.

    ``variance`` is the population variance of the last ``window`` values.
    ``plateaued`` is true when that variance is below ``max_variance`` and the
    last four values are not strictly increasing, i.e. accuracy has flattened
    out and is not still climbing. ``history`` must hold at least ``window``
    values.
    """
    recent = history[-window:]
    mean = sum(recent) / window
    variance = sum((x - mean) ** 2 for x in recent) / window
    still_improving = history[-1] > history[-2] > history[-3] > history[-4]
    return variance, variance < max_variance and not still_improving


def main():
    """Fine-tune a pretrained ResNet18 as a two-class classifier and save it.

    Images are read from ``data/roi/train`` (relative to the working
    directory) using the labels file given on the command line. The samples
    are split 50/50 at random into a training and a validation set, the model
    is trained for up to 15 epochs with early stopping once the validation
    accuracy plateaus, and the final weights are written to
    ``data/roi/models/resnet18_<timestamp>_<accuracy>.pth``.

    Raises:
        ValueError: If the labels file lists fewer than two samples, so that
            one of the two splits would be empty.
    """
    args = parse_args()

    # Data preprocessing. Normalisation uses the ImageNet channel statistics
    # that go with the pretrained torchvision weights.
    # NOTE(review): no resize is applied, so batching presumably relies on all
    # images already sharing one size -- confirm against the data.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        # Augmentations kept for reference, currently disabled:
        # transforms.RandomResizedCrop((128, 64)),  # random crop
        # transforms.RandomHorizontalFlip(),        # random horizontal flip
        # transforms.Resize((256, 128)),            # resize
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        # transforms.Resize((256, 128)),            # resize
        transforms.ToTensor(),
        normalize,
    ])

    # Load the dataset. Two instances over the same files are used so that the
    # validation split goes through transform_val rather than inheriting the
    # training transform (which matters as soon as augmentation is enabled).
    img_dir = os.path.abspath("data/roi/train")
    train_source = SideBySideDataset(
        img_dir=img_dir,
        labels_file=args.labels_file,
        transform=transform_train,
    )
    val_source = SideBySideDataset(
        img_dir=img_dir,
        labels_file=args.labels_file,
        transform=transform_val,
    )

    train_ratio = 0.5
    train_size = int(train_ratio * len(train_source))
    val_size = len(train_source) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Need at least 2 labelled samples to build train/val splits, "
            f"got {len(train_source)} from {args.labels_file}"
        )
    train_dataset, val_split = random_split(train_source, [train_size, val_size])
    # Same random indices as the split above, but read via the val transform.
    val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # Accuracy does not depend on sample order, so validation is not shuffled.
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Load the pretrained ResNet18 model.
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; left as-is to stay compatible with the installed
    # version -- confirm before changing.
    model = models.resnet18(pretrained=True)
    # Replace the final fully connected layer with a 2-way (binary) classifier.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)

    # Move the model to the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # Alternative optimizer kept for reference:
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.09, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Create the output directory up front so that a missing or unwritable
    # location is discovered before training, not after it has finished.
    model_dir = 'data/roi/models'
    os.makedirs(model_dir, exist_ok=True)

    num_epochs = 15
    last_accu = 0
    prev_accu = []
    for epoch in range(num_epochs):
        # Training phase.
        print(f"Start training epoch {epoch+1}/{num_epochs}")
        mean_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')
        scheduler.step()

        # Validation phase. The accuracy is reported before the early-stopping
        # check so that it is also printed for the epoch that stops training.
        last_accu = _evaluate(model, val_loader, device)
        prev_accu.append(last_accu)
        print(f'Validation Accuracy: {last_accu:.2f}%')

        if epoch > 5:
            variance, plateaued = _accuracy_plateaued(prev_accu)
            print(f"variance={variance:.4f}")
            if plateaued:
                print(f"Early stopping condition met: {variance:.4f}")
                break

    dt = datetime.now().strftime("%Y%m%d_%H%M%S")
    torch.save(model.state_dict(), f'{model_dir}/resnet18_{dt}_{last_accu:.2f}.pth')
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()