research: roi train

This commit is contained in:
Fam Zheng 2025-03-22 22:57:24 -07:00
parent 9fbc2fb812
commit 03a25595a2
2 changed files with 163 additions and 2 deletions

104
research/roi-train.py Executable file
View File

@ -0,0 +1,104 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import os
from datetime import datetime
class CustomDataset(Dataset):
    """Image classification dataset indexed by a ``labels.txt`` file.

    ``labels.txt`` is read from ``img_dir``. Every non-blank line has the form
    ``<image name> <label>``; the label is the last whitespace-separated field
    and is converted with ``int()`` when a sample is fetched.
    """

    def __init__(self, img_dir, transform=None):
        """Store the configuration and read the label index.

        Args:
            img_dir: Directory containing the images and ``labels.txt``.
            transform: Optional callable applied to each loaded RGB image.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = self._load_labels()

    def _load_labels(self):
        """Parse ``labels.txt`` into a list of ``[image_name, label_text]``.

        Raises:
            ValueError: If a non-blank line does not contain both a name and
                a label.
        """
        labels_path = os.path.join(self.img_dir, 'labels.txt')
        entries = []
        with open(labels_path, 'r') as f:
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                # Split from the right so that image names containing
                # whitespace stay intact; the label is always the last field.
                fields = line.rsplit(maxsplit=1)
                if len(fields) != 2:
                    # Fail while loading rather than in the middle of an epoch.
                    raise ValueError(
                        f"{labels_path}:{lineno}: expected '<image> <label>', "
                        f"got {line!r}"
                    )
                entries.append(fields)
        return entries

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened from ``img_dir``, converted to RGB and passed
        through ``transform`` when one was given; the label is an ``int``.
        """
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() yields a separate image object, so the handle opened by
        # Image.open can be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, int(label)
# Data preprocessing.
# Both pipelines end with the same tensor conversion and per-channel
# normalization; only the resizing / augmentation step differs.
# NOTE(review): the preprocessing script in this commit writes side-by-side
# images that are wider than tall, and the training pipeline flips them
# horizontally - confirm that this target size and the flip are intended.
_INPUT_SIZE = (128, 64)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: random resized crop plus random horizontal flip as augmentation.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation: deterministic resize only, no augmentation.
transform_val = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Load the dataset. Both objects index the same directory and differ only in
# the transform they apply.
img_dir = os.path.abspath("data/roi/train")
train_dataset = CustomDataset(img_dir=img_dir, transform=transform_train)
val_dataset = CustomDataset(img_dir=img_dir, transform=transform_val)

# Hold out a fixed fraction of the samples for validation so that the reported
# accuracy is not measured on images the model was trained on. The permutation
# is seeded to keep the split reproducible across runs.
val_fraction = 0.2
num_samples = len(train_dataset)
num_val = int(num_samples * val_fraction)
if num_val > 0:
    split_rng = torch.Generator().manual_seed(0)
    shuffled = torch.randperm(num_samples, generator=split_rng).tolist()
    train_dataset = torch.utils.data.Subset(train_dataset, shuffled[num_val:])
    val_dataset = torch.utils.data.Subset(val_dataset, shuffled[:num_val])
# With too few samples to split, both loaders keep covering the whole
# directory, as they did before the split was introduced.

# Create the DataLoaders: shuffle for training, fixed order for validation.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load a pretrained ResNet18 and swap its final fully connected layer for a
# 2-output head (binary classification), then move the model to the device.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm against the installed version before changing it.
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=2)
model = model.to(device)

# Loss, optimizer and training length.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
    # Training phase: one optimizer step per batch, accumulating the batch
    # losses so that their mean can be reported for the epoch.
    model.train()
    running_loss = 0.0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

    # Validation phase: count correct top-1 predictions without tracking
    # gradients.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            predicted = model(batch_images).argmax(dim=1)
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()
    print(f'Validation Accuracy: {100 * correct / total:.2f}%')

# Save the trained weights under a timestamped file name.
dt = datetime.now().strftime("%Y%m%d_%H%M%S")
torch.save(model.state_dict(), f'data/roi/resnet18_{dt}.pth')

View File

@ -2,10 +2,13 @@
import argparse import argparse
import os import os
import shutil
import numpy as np
import requests import requests
import json import json
import subprocess import subprocess
import cv2 import cv2
from collections import defaultdict
class RoiResearch(object): class RoiResearch(object):
def __init__(self, token=None): def __init__(self, token=None):
@ -170,13 +173,18 @@ def roi_sim(frame_roi_img, roi_img):
cv2.destroyAllWindows() cv2.destroyAllWindows()
return siml return siml
def frame_roi(frame_file):
    """Run ``qrtool frame_roi`` on a frame image and load the resulting ROI.

    The tool is started with the ``../alg`` directory (relative to this
    script) as its working directory, and the image is then read from
    ``<frame_file>.roi.jpg`` (presumably written there by the tool).

    Args:
        frame_file: Path of the frame image. Callers in this file pass an
            absolute path, which matters because the tool runs with a
            different working directory.

    Returns:
        The ROI image as returned by ``cv2.imread``.

    Raises:
        subprocess.CalledProcessError: If the tool exits with a non-zero
            status.
        FileNotFoundError: If the ROI image cannot be read afterwards.
    """
    alg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "alg")
    # Pass the arguments as a list instead of a shell string so that paths
    # containing spaces or shell metacharacters reach the tool unchanged.
    subprocess.check_call(['./qrtool', 'frame_roi', frame_file], cwd=alg_dir)
    roi_path = frame_file + ".roi.jpg"
    frame_roi_img = cv2.imread(roi_path)
    if frame_roi_img is None:
        # cv2.imread signals an unreadable file by returning None; report it
        # here instead of letting a later OpenCV call fail obscurely.
        raise FileNotFoundError(f"cannot read ROI image: {roi_path}")
    return frame_roi_img
def process_roi_data(id): def process_roi_data(id):
rd = f"data/roi/samples/{id}" rd = f"data/roi/samples/{id}"
frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg") frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg")
roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg") roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg")
roi_img = cv2.imread(roi_file) roi_img = cv2.imread(roi_file)
cmd = f'./qrtool frame_roi {frame_file}' frame_roi_img = frame_roi(frame_file)
subprocess.check_call(cmd, shell=True, cwd=os.path.dirname(os.path.abspath(__file__)) + "/../alg")
frame_roi_img = cv2.imread(frame_file + ".roi.jpg") frame_roi_img = cv2.imread(frame_file + ".roi.jpg")
size = [128, 128] size = [128, 128]
frame_roi_img = cv2.resize(frame_roi_img, size) frame_roi_img = cv2.resize(frame_roi_img, size)
@ -190,13 +198,62 @@ def parse_args():
parser.add_argument("--username", "-u", type=str) parser.add_argument("--username", "-u", type=str)
parser.add_argument("--password", "-p", type=str) parser.add_argument("--password", "-p", type=str)
parser.add_argument("--download", "-d", action='store_true') parser.add_argument("--download", "-d", action='store_true')
parser.add_argument("--preprocess", "-P", action='store_true')
parser.add_argument("--id", "-i", type=int, action='append') parser.add_argument("--id", "-i", type=int, action='append')
return parser.parse_args() return parser.parse_args()
def get_all_samples():
    """Return the names of the sample directories under ``data/roi/samples``.

    Only directories are returned, sorted by name, so that stray files are
    not treated as samples and the processing order is the same on every run.

    Raises:
        FileNotFoundError: If ``data/roi/samples`` does not exist.
    """
    samples_dir = "data/roi/samples"
    return sorted(
        name
        for name in os.listdir(samples_dir)
        if os.path.isdir(os.path.join(samples_dir, name))
    )
def _read_sample_label(json_file):
    """Return 1 if the sample's ``labels`` contain 'pos', 0 if they contain 'neg'."""
    with open(json_file, "r") as f:
        data = json.load(f)
    labels = data['labels']
    if 'pos' in labels:
        return 1
    if 'neg' in labels:
        return 0
    raise Exception("no label found")


def prepare_to_train(id):
    """Turn one downloaded sample into a labelled training image.

    The frame ROI produced by ``frame_roi`` and the sample's ROI image are
    each resized to 128x128 and joined side by side. The result is written to
    ``data/roi/train/<id>.jpg`` and a ``<id>.jpg <label>`` line is appended to
    ``data/roi/train/labels.txt``.

    Args:
        id: Sample identifier; the sample's files are read from
            ``data/roi/samples/<id>``.

    Returns:
        The integer label: 1 for 'pos', 0 for 'neg'.

    Raises:
        Exception: If the sample's JSON has neither a 'pos' nor a 'neg' label.
        FileNotFoundError: If the ROI image or the frame ROI cannot be read.
        OSError: If the training image cannot be written.
    """
    rd = f"data/roi/samples/{id}"

    # Resolve the label first so that unlabelled samples are rejected before
    # the external ROI tool is run.
    label = _read_sample_label(os.path.abspath(f"{rd}/{id}.json"))

    frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg")
    roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg")
    roi_img = cv2.imread(roi_file)
    if roi_img is None:
        raise FileNotFoundError(f"cannot read ROI image: {roi_file}")
    frame_roi_img = frame_roi(frame_file)
    if frame_roi_img is None:
        raise FileNotFoundError(f"cannot read frame ROI for: {frame_file}")

    frame_roi_img = cv2.resize(frame_roi_img, (128, 128))
    roi_img = cv2.resize(roi_img, (128, 128))
    side_by_side = np.concatenate((frame_roi_img, roi_img), axis=1)

    side_by_side_file = os.path.abspath(f"data/roi/train/{id}.jpg")
    # Only record the label once the image is really on disk; otherwise
    # labels.txt would point at a file that does not exist.
    if not cv2.imwrite(side_by_side_file, side_by_side):
        raise OSError(f"cannot write training image: {side_by_side_file}")
    with open(os.path.abspath("data/roi/train/labels.txt"), "a") as f:
        f.write(f"{id}.jpg {label}\n")
    return label
def main(): def main():
args = parse_args() args = parse_args()
if args.download: if args.download:
get_roi_data(args) get_roi_data(args)
if args.preprocess:
all_samples = get_all_samples()
total = len(all_samples)
shutil.rmtree(os.path.abspath("data/roi/train"))
os.makedirs(os.path.abspath("data/roi/train"))
label_count = defaultdict(int)
for i, id in enumerate(all_samples):
print(f"preprocessing {id} ({i + 1}/{total})")
try:
label = prepare_to_train(id)
label_count[label] += 1
except Exception as e:
print(f"error: {e}")
print(f"count by label: {label_count}")
return
if args.id: if args.id:
for id in args.id: for id in args.id:
process_roi_data(id) process_roi_data(id)