# research/roi-train.py -- fine-tune a ResNet18 binary classifier on the
# side-by-side ROI images listed in data/roi/train/labels.txt.

# Fraction of the labelled samples held out for validation.
VAL_FRACTION = 0.2
# Fixed seed so the train/validation split is identical across runs.
SPLIT_SEED = 0
# Size handed to the torchvision resize/crop transforms, which read it as
# (height, width).
# NOTE(review): roi.py's prepare_to_train writes two 128x128 tiles concatenated
# horizontally (wider than tall), so (128, 64) inverts the aspect ratio --
# confirm whether (64, 128) was intended before changing it.
INPUT_SIZE = (128, 64)
# Channel statistics used for normalisation.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]


class CustomDataset(Dataset):
    """Image dataset backed by a directory and its ``labels.txt`` index.

    ``labels.txt`` lives inside ``img_dir`` and holds one sample per line in
    the form ``<image name> <integer label>``. Blank lines are ignored.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # List of [image_name, label_string] pairs, in file order.
        self.img_labels = self._load_labels()

    def _load_labels(self):
        """Parse ``labels.txt`` into a list of ``[image_name, label]`` pairs.

        The label is the last whitespace-separated token, so image names that
        contain spaces are kept intact.

        Raises:
            ValueError: if a non-blank line lacks a label or the label is not
                an integer. Failing here reports the offending line instead
                of crashing in the middle of a training epoch.
        """
        labels_path = os.path.join(self.img_dir, 'labels.txt')
        entries = []
        with open(labels_path, 'r') as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                parts = line.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{labels_path}:{lineno}: expected '<image> <label>', got {line!r}"
                    )
                try:
                    int(parts[1])
                except ValueError as err:
                    raise ValueError(
                        f"{labels_path}:{lineno}: label {parts[1]!r} is not an integer"
                    ) from err
                entries.append(parts)
        return entries

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*; the label is an ``int``."""
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, int(label)


def build_transforms():
    """Return ``(train_transform, val_transform)``.

    Training adds a random resized crop and a random horizontal flip;
    validation only resizes. Both convert to a tensor and normalise.
    """
    normalize = transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.ToTensor(),
        normalize,
    ])
    return transform_train, transform_val


def split_indices(num_samples, val_fraction=VAL_FRACTION, seed=SPLIT_SEED):
    """Split ``range(num_samples)`` into shuffled train and validation indices.

    The shuffle is seeded, so the same sample count always yields the same
    split. With two or more samples, at least one goes to each side.

    Returns:
        ``(train_indices, val_indices)`` as two lists of ints.
    """
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(num_samples, generator=generator).tolist()
    num_val = int(num_samples * val_fraction)
    if num_samples > 1:
        num_val = min(max(1, num_val), num_samples - 1)
    return order[num_val:], order[:num_val]


def evaluate(model, loader, device):
    """Return ``(correct, total)`` prediction counts for *model* over *loader*."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def main():
    """Train the classifier and save its ``state_dict`` under ``data/roi``."""
    img_dir = os.path.abspath("data/roi/train")
    transform_train, transform_val = build_transforms()

    # Two views of the same files: augmented for training, plain for validation.
    train_source = CustomDataset(img_dir=img_dir, transform=transform_train)
    val_source = CustomDataset(img_dir=img_dir, transform=transform_val)
    if len(train_source) == 0:
        raise SystemExit(f"no samples listed in {os.path.join(img_dir, 'labels.txt')}")

    # Hold out a disjoint subset so the accuracy is not measured on the very
    # samples the model was just trained on.
    train_idx, val_idx = split_indices(len(train_source))
    train_dataset = torch.utils.data.Subset(train_source, train_idx)
    val_dataset = torch.utils.data.Subset(val_source, val_idx)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # ImageNet-pretrained ResNet18; `weights=` replaces the deprecated
    # `pretrained=True` and selects the same checkpoint.
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Replace the final fully connected layer with a 2-way output (binary task).
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10

    for epoch in range(num_epochs):
        # Training phase.
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

        # Validation phase.
        correct, total = evaluate(model, val_loader, device)
        if total:
            print(f'Validation Accuracy: {100 * correct / total:.2f}%')
        else:
            print('Validation Accuracy: n/a (no validation samples)')

    dt = datetime.now().strftime("%Y%m%d_%H%M%S")
    torch.save(model.state_dict(), f'data/roi/resnet18_{dt}.pth')


if __name__ == "__main__":
    main()
def frame_roi(frame_file):
    """Run ``qrtool frame_roi`` on *frame_file* and load the image it produces.

    The tool is executed from the ``alg`` directory next to this script and is
    expected to write its result to ``<frame_file>.roi.jpg``.

    Args:
        frame_file: path of the frame image; callers pass an absolute path
            because the tool runs with a different working directory.

    Returns:
        The ROI image as loaded by ``cv2.imread``.

    Raises:
        subprocess.CalledProcessError: if qrtool exits with a non-zero status.
        RuntimeError: if the expected output image cannot be read.
    """
    alg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "alg")
    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through verbatim.
    subprocess.check_call([os.path.join(alg_dir, "qrtool"), "frame_roi", frame_file], cwd=alg_dir)
    roi_path = frame_file + ".roi.jpg"
    frame_roi_img = cv2.imread(roi_path)
    if frame_roi_img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise RuntimeError(f"qrtool did not produce a readable image at {roi_path}")
    return frame_roi_img


def get_all_samples():
    """Return the names of the sample directories under ``data/roi/samples``.

    Only directories are returned (stray files are skipped), sorted so that
    processing order, and therefore ``labels.txt``, is reproducible.
    """
    root = "data/roi/samples"
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )


def _label_from_tags(labels):
    """Map a sample's ``labels`` entry to a class index: 'pos' -> 1, 'neg' -> 0."""
    if 'pos' in labels:
        return 1
    if 'neg' in labels:
        return 0
    raise ValueError("no label found")


def prepare_to_train(id):
    """Build one training image for sample *id* and append it to ``labels.txt``.

    Reads ``<id>.json`` for the label, extracts the frame ROI with qrtool,
    resizes both the frame ROI and the reference ROI to 128x128, and writes
    them side by side to ``data/roi/train/<id>.jpg``.

    Returns:
        The class index written for the sample (1 for 'pos', 0 for 'neg').

    Raises:
        ValueError: if the sample's labels contain neither 'pos' nor 'neg'.
        RuntimeError: if an input image cannot be read or the output image
            cannot be written.
    """
    rd = f"data/roi/samples/{id}"

    # Resolve the label first so unlabeled samples fail before qrtool is run.
    json_file = os.path.abspath(f"{rd}/{id}.json")
    with open(json_file, "r") as f:
        data = json.load(f)
    label = _label_from_tags(data['labels'])

    frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg")
    roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg")
    roi_img = cv2.imread(roi_file)
    if roi_img is None:
        raise RuntimeError(f"cannot read ROI image {roi_file}")
    frame_roi_img = frame_roi(frame_file)

    frame_roi_img = cv2.resize(frame_roi_img, (128, 128))
    roi_img = cv2.resize(roi_img, (128, 128))
    side_by_side = np.concatenate((frame_roi_img, roi_img), axis=1)

    side_by_side_file = os.path.abspath(f"data/roi/train/{id}.jpg")
    if not cv2.imwrite(side_by_side_file, side_by_side):
        # Do not list an image in labels.txt that was never written.
        raise RuntimeError(f"cannot write {side_by_side_file}")
    with open(os.path.abspath("data/roi/train/labels.txt"), "a") as f:
        f.write(f"{id}.jpg {label}\n")
    return label


def main():
    """Command-line entry point: download, preprocess, or inspect samples."""
    args = parse_args()
    if args.download:
        get_roi_data(args)
    if args.preprocess:
        all_samples = get_all_samples()
        total = len(all_samples)
        train_dir = os.path.abspath("data/roi/train")
        # Start from an empty output directory. On the first run it does not
        # exist yet, and an unconditional rmtree would raise.
        if os.path.isdir(train_dir):
            shutil.rmtree(train_dir)
        os.makedirs(train_dir, exist_ok=True)
        label_count = defaultdict(int)
        for i, sample_id in enumerate(all_samples):
            print(f"preprocessing {sample_id} ({i + 1}/{total})")
            try:
                label = prepare_to_train(sample_id)
            except Exception as e:
                # Best effort: report the broken sample and keep going.
                print(f"error: {e}")
            else:
                label_count[label] += 1
        print(f"count by label: {label_count}")
        return
    if args.id:
        for sample_id in args.id:
            process_roi_data(sample_id)