themblem/research/roi-train.py
2025-03-25 13:29:57 +08:00

105 lines
3.5 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import os
from datetime import datetime
class CustomDataset(Dataset):
    """Image dataset for binary classification, backed by a directory of images.

    Expects a ``labels.txt`` file inside *img_dir*; each non-empty line has
    the form ``<image filename> <integer label>``. The label is taken as the
    last whitespace-separated token, so image filenames may contain spaces.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = self._load_labels()

    def _load_labels(self):
        # Labels are stored in labels.txt, one "<filename> <label>" per line.
        # rsplit(maxsplit=1) takes the label from the right, so filenames
        # containing spaces survive intact (a plain split() would shatter
        # them into multiple tokens and break __getitem__'s unpacking).
        with open(os.path.join(self.img_dir, 'labels.txt'), 'r') as f:
            lines = f.readlines()
        return [line.strip().rsplit(maxsplit=1) for line in lines if line.strip()]

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, int(label)
# Preprocessing pipelines. The normalization statistics are the standard
# ImageNet mean/std, matching the ImageNet-pretrained backbone used below.
_imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])

# Training pipeline: random crop + horizontal flip for augmentation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop((128, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# Validation pipeline: deterministic resize only, no augmentation.
transform_val = transforms.Compose([
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
    _imagenet_normalize,
])
# Dataset and loader setup.
# NOTE(review): train and validation datasets both read data/roi/train, so
# the "validation" accuracy below is measured on training images — confirm
# whether a separate held-out split should be used instead.
BATCH_SIZE = 32

img_dir = os.path.abspath("data/roi/train")
train_dataset = CustomDataset(img_dir=img_dir, transform=transform_train)
val_dataset = CustomDataset(img_dir=img_dir, transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Model: ImageNet-pretrained ResNet18 with its classifier head swapped for a
# fresh 2-way linear layer (binary classification).
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)

# Run on GPU when one is available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
    # --- Training phase ---
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

    # --- Validation phase ---
    # NOTE(review): val_loader is built from the same directory as the
    # training data, so this accuracy is on training images — confirm
    # whether a held-out split exists.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # argmax over class logits; avoids the legacy `.data` attribute
            # (gradient tracking is already disabled by torch.no_grad()).
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty validation set (ZeroDivisionError otherwise).
    if total:
        print(f'Validation Accuracy: {100 * correct / total:.2f}%')

# Save a timestamped checkpoint so repeated runs don't overwrite each other.
dt = datetime.now().strftime("%Y%m%d_%H%M%S")
torch.save(model.state_dict(), f'data/roi/resnet18_{dt}.pth')