research: Add /api/roi-verify

This commit is contained in:
Fam Zheng 2025-03-25 13:24:25 +08:00
parent c02b9e342a
commit 3d96d2a7ad
8 changed files with 217 additions and 44 deletions

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
set -e set -e
DOCKER_IMAGE=registry.gitlab.com/euphon/themblem:research-$(git rev-parse --short HEAD) DOCKER_IMAGE=registry.cn-shenzhen.aliyuncs.com/emblem/themblem:research-$(git rev-parse --short HEAD)
docker build --network=host -t $DOCKER_IMAGE . docker build --network=host -t $DOCKER_IMAGE .
docker push $DOCKER_IMAGE docker push $DOCKER_IMAGE

Binary file not shown.

View File

@ -2,23 +2,36 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models from torchvision import transforms, models
from PIL import Image from PIL import Image
import os import os
from datetime import datetime from datetime import datetime
from collections import defaultdict
import random
class CustomDataset(Dataset): class CustomDataset(Dataset):
def __init__(self, img_dir, transform=None): def __init__(self, img_dir, transform=None, limit=500):
self.img_dir = img_dir self.img_dir = img_dir
self.transform = transform self.transform = transform
self.img_labels = self._load_labels() self.img_labels = self._load_labels(limit)
def _load_labels(self): def _load_labels(self, limit):
# 假设标签存储在labels.txt文件中每行格式为图片名 标签 cats = defaultdict(list)
with open(os.path.join(self.img_dir, 'labels.txt'), 'r') as f: with open(os.path.join(self.img_dir, 'labels.txt'), 'r') as f:
lines = f.readlines() lines = f.readlines()
return [line.strip().split() for line in lines if line.strip()] for line in lines:
if not line.strip():
continue
img_name, label = line.strip().split()
cats[label].append([img_name, label])
min_samples = min(len(v) for v in cats.values())
min_samples = min(limit, min_samples)
ret = []
for k, v in cats.items():
ret.extend(random.sample(v, min_samples))
#ret.extend(v)
return ret
def __len__(self): def __len__(self):
return len(self.img_labels) return len(self.img_labels)
@ -34,26 +47,33 @@ class CustomDataset(Dataset):
# 数据预处理 # 数据预处理
transform_train = transforms.Compose([ transform_train = transforms.Compose([
transforms.RandomResizedCrop((128, 64)), # 随机裁剪 #transforms.RandomResizedCrop((128, 64)), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转 #transforms.RandomHorizontalFlip(), # 随机水平翻转
#transforms.Resize((256, 128)), # 调整大小
transforms.ToTensor(), # 转换为Tensor transforms.ToTensor(), # 转换为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
]) ])
transform_val = transforms.Compose([ transform_val = transforms.Compose([
transforms.Resize((128, 64)), # 调整大小 #transforms.Resize((256, 128)), # 调整大小
transforms.ToTensor(), # 转换为Tensor transforms.ToTensor(), # 转换为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
]) ])
# 加载数据集 # 加载数据集
img_dir = os.path.abspath("data/roi/train") img_dir = os.path.abspath("data/roi/train")
train_dataset = CustomDataset(img_dir=img_dir, transform=transform_train)
val_dataset = CustomDataset(img_dir=img_dir, transform=transform_val)
# 创建DataLoader full_dataset = CustomDataset(img_dir=img_dir, transform=transform_train, limit=5000)
train_ratio = 0.5
val_ratio = 0.5
train_size = int(train_ratio * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
# 加载预训练的ResNet18模型 # 加载预训练的ResNet18模型
model = models.resnet18(pretrained=True) model = models.resnet18(pretrained=True)
@ -67,12 +87,19 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device) model = model.to(device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.Adam(model.parameters(), lr=0.0001)
#optimizer = torch.optim.SGD( model.parameters(), lr=0.0001, momentum=0.09, weight_decay=1e-4)
num_epochs = 10 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
num_epochs = 15
last_accu = 0
prev_accu = []
for epoch in range(num_epochs): for epoch in range(num_epochs):
# 训练阶段 # 训练阶段
print(f"Start training epoch {epoch+1}/{num_epochs}")
model.train() model.train()
running_loss = 0.0 running_loss = 0.0
for images, labels in train_loader: for images, labels in train_loader:
@ -86,6 +113,8 @@ for epoch in range(num_epochs):
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}') print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
scheduler.step()
# 验证阶段 # 验证阶段
model.eval() model.eval()
correct = 0 correct = 0
@ -98,7 +127,16 @@ for epoch in range(num_epochs):
total += labels.size(0) total += labels.size(0)
correct += (predicted == labels).sum().item() correct += (predicted == labels).sum().item()
print(f'Validation Accuracy: {100 * correct / total:.2f}%') last_accu = 100 * correct / total
prev_accu.append(last_accu)
if epoch > 5:
avg = sum(prev_accu[-5:]) / 5
variance = sum((x - avg) ** 2 for x in prev_accu[-5:]) / 5
print(f"variance={variance:.4f}")
if variance < 1 and not (prev_accu[-1] > prev_accu[-2] and prev_accu[-2] > prev_accu[-3] and prev_accu[-3] > prev_accu[-4]):
print(f"Early stopping condition met: {variance:.4f}")
break
print(f'Validation Accuracy: {last_accu:.2f}%')
dt = datetime.now().strftime("%Y%m%d_%H%M%S") dt = datetime.now().strftime("%Y%m%d_%H%M%S")
torch.save(model.state_dict(), f'data/roi/resnet18_{dt}.pth') torch.save(model.state_dict(), f'data/roi/models/resnet18_{dt}_{last_accu:.2f}.pth')

30
research/roi-verify.py Executable file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env python3
from PIL import Image
import os
import argparse
import sys
from roi_lib import *
def parse_args():
    """Parse the command-line options of the ROI verification script.

    Returns:
        argparse.Namespace with ``model`` (model path) and ``image``
        (image file); both options are required strings.
    """
    required_options = (
        ('--model', 'model path'),
        ('--image', 'image file'),
    )
    cli = argparse.ArgumentParser(description='ROI prediction')
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
# Entry point
def main():
    """Classify a single image and report whether it passes ROI verification.

    Returns:
        0 when the predicted class is 1 ("verify ok"), 1 otherwise
        ("verify ng"). The value is used as the process exit status.
    """
    args = parse_args()
    image_path = args.image
    classifier = load_model(args.model)
    predicted_class, probabilities = predict(classifier, preprocess_image(image_path))
    print(f'{image_path} predicted={predicted_class} prob={probabilities}')

    verified = predicted_class == 1
    print("verify ok" if verified else "verify ng")
    return 0 if verified else 1


if __name__ == '__main__':
    sys.exit(main())

View File

@ -9,6 +9,7 @@ import json
import subprocess import subprocess
import cv2 import cv2
from collections import defaultdict from collections import defaultdict
from multiprocessing import Pool
class RoiResearch(object): class RoiResearch(object):
def __init__(self, token=None): def __init__(self, token=None):
@ -204,22 +205,28 @@ def parse_args():
parser.add_argument("--id", "-i", type=int, action='append') parser.add_argument("--id", "-i", type=int, action='append')
return parser.parse_args() return parser.parse_args()
def get_all_samples(): def all_sample_ids():
return os.listdir("data/roi/samples") return os.listdir("data/roi/samples")
def prepare_to_train(id): def prepare_to_train(id):
rd = f"data/roi/samples/{id}" rd = f"data/roi/samples/{id}"
side_by_side_file = os.path.join(rd, f"{id}-side-by-side.jpg")
if not os.path.exists(side_by_side_file):
frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg") frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg")
roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg") roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg")
if not os.path.exists(frame_file) or not os.path.exists(roi_file):
print(f"skipping {id}, no frame or roi")
return
roi_img = cv2.imread(roi_file) roi_img = cv2.imread(roi_file)
try:
frame_roi_img = frame_roi(frame_file) frame_roi_img = frame_roi(frame_file)
except Exception as e:
print(f"failed to get frame_roi for {id}: {e}")
return
frame_roi_img = cv2.resize(frame_roi_img, (128, 128)) frame_roi_img = cv2.resize(frame_roi_img, (128, 128))
roi_img = cv2.resize(roi_img, (128, 128)) roi_img = cv2.resize(roi_img, (128, 128))
side_by_side = np.concatenate((frame_roi_img, roi_img), axis=1) side_by_side = np.concatenate((frame_roi_img, roi_img), axis=1)
side_by_side_file = os.path.abspath(f"{rd}/{id}-side-by-side.jpg") cv2.imwrite(side_by_side_file, side_by_side)
# show_img(side_by_side, "side_by_side")
# cv2.waitKey(0)
label_file = os.path.abspath(f"{rd}/label.txt")
json_file = os.path.abspath(f"{rd}/{id}.json") json_file = os.path.abspath(f"{rd}/{id}.json")
with open(json_file, "r") as f: with open(json_file, "r") as f:
data = json.load(f) data = json.load(f)
@ -230,9 +237,9 @@ def prepare_to_train(id):
elif 'neg' in labels: elif 'neg' in labels:
label = 0 label = 0
else: else:
raise Exception("no label found") print(f"no label found for {id}")
side_by_side_file = os.path.abspath(f"data/roi/train/{id}.jpg") return
cv2.imwrite(side_by_side_file, side_by_side) shutil.copy(side_by_side_file, os.path.abspath(f"data/roi/train/{id}.jpg"))
with open(os.path.abspath(f"data/roi/train/labels.txt"), "a") as f: with open(os.path.abspath(f"data/roi/train/labels.txt"), "a") as f:
f.write(f"{id}.jpg {label}\n") f.write(f"{id}.jpg {label}\n")
return label return label
@ -242,19 +249,10 @@ def main():
if args.download: if args.download:
get_roi_data(args) get_roi_data(args)
if args.preprocess: if args.preprocess:
all_samples = get_all_samples()
total = len(all_samples)
shutil.rmtree(os.path.abspath("data/roi/train")) shutil.rmtree(os.path.abspath("data/roi/train"))
os.makedirs(os.path.abspath("data/roi/train")) os.makedirs(os.path.abspath("data/roi/train"))
label_count = defaultdict(int) with Pool(processes=10) as pool:
for i, id in enumerate(all_samples): pool.map(prepare_to_train, all_sample_ids())
print(f"preprocessing {id} ({i + 1}/{total})")
try:
label = prepare_to_train(id)
label_count[label] += 1
except Exception as e:
print(f"error: {e}")
print(f"count by label: {label_count}")
return return
if args.id: if args.id:
for id in args.id: for id in args.id:

56
research/roi_bench.py Executable file
View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
import argparse
import os
import random
import sys

from PIL import Image

from roi_lib import *
def parse_args():
    """Parse the command-line options of the ROI benchmark script.

    Returns:
        argparse.Namespace with ``model`` (model path) and ``label``
        (label file); both options are required strings.
    """
    cli = argparse.ArgumentParser(description='ROI prediction')
    common = {'type': str, 'required': True}
    cli.add_argument('--model', help='model path', **common)
    cli.add_argument('--label', help='label file', **common)
    namespace = cli.parse_args()
    return namespace
def _read_testset(label_path):
    """Read ``<image name> <label>`` pairs from *label_path*, skipping blank or malformed lines."""
    testset = []
    with open(label_path, 'r') as f:
        for line in f:
            fields = line.split()
            if len(fields) != 2:
                continue
            testset.append((fields[0], fields[1]))
    return testset


# Entry point
def main():
    """Benchmark a model against a label file and print accuracy statistics.

    Every usable line of the label file names an image (resolved relative to
    the label file's directory) and its expected class. Samples are evaluated
    in random order; a running accuracy is printed per sample, followed by a
    summary with the two misclassification counters.

    Returns:
        0 after a completed run, 1 when the label file contains no usable
        sample (previously this case crashed on an unassigned variable).
    """
    args = parse_args()
    # Load the model
    model = load_model(args.model)
    testset = _read_testset(args.label)
    if not testset:
        print(f'no usable samples found in {args.label}')
        return 1
    random.shuffle(testset)

    image_dir = os.path.dirname(args.label)
    correct = 0
    zero_to_one = 0
    one_to_zero = 0
    done = 0
    accu_so_far = 0.0
    for image_name, label in testset:
        image_path = os.path.abspath(os.path.join(image_dir, image_name))
        image_tensor = preprocess_image(image_path)
        predicted_class, probabilities = predict(model, image_tensor)
        done += 1
        if str(predicted_class) == label:
            correct += 1
        elif label == '0':
            zero_to_one += 1
        else:
            # Any label other than '0' is counted as a missed positive.
            one_to_zero += 1
        accu_so_far = correct / done
        print(f'{image_path} label={label} predicted={predicted_class} prob={probabilities} accuracy so far: {accu_so_far:.3f}')
    print(f'total={done} correct={correct} accuracy={accu_so_far:.3f}')
    print(f'zero_to_one={zero_to_one} one_to_zero={one_to_zero}')
    return 0


if __name__ == '__main__':
    # `sys` comes from the file's import block.
    sys.exit(main())

35
research/roi_lib.py Executable file
View File

@ -0,0 +1,35 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os
import argparse
import random
def load_model(model_path):
    """Build the two-class ResNet18 classifier and load trained weights.

    Args:
        model_path: Path of a state-dict file saved with ``torch.save``.

    Returns:
        The model in eval mode, placed on CUDA when available, else on CPU.
    """
    # The checkpoint replaces every weight, so skip the ImageNet download that
    # pretrained=True would trigger (it also fails on hosts without network).
    model = models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # two output classes
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)  # load the trained weights
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    return model
def preprocess_image(image_path):
    """Convert an image into a normalized, batched tensor for the classifier.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or a binary
            file object), or an already opened ``PIL.Image.Image``.

    Returns:
        Tensor of the RGB image after ``ToTensor`` and ``Normalize``, with a
        leading batch dimension of size 1. The image is not resized.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])
    if isinstance(image_path, Image.Image):
        # Already decoded by the caller; just force three channels.
        image = image_path.convert('RGB')
    else:
        # convert() returns an independent copy, so the source can be closed.
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
    return transform(image).unsqueeze(0)  # add the batch dimension
def predict(model, image_tensor):
    """Run the classifier on one preprocessed image.

    Args:
        model: A ``torch.nn.Module`` producing one row of class scores.
        image_tensor: Batched input tensor holding a single sample.

    Returns:
        Tuple ``(predicted_class, probabilities)``: the index of the highest
        score as an int, and the softmax probabilities as a list of floats.
    """
    # load_model() may have moved the model to CUDA while the input was built
    # on the CPU; run the forward pass on the model's own device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)
    with torch.no_grad():  # inference only, no gradients needed
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)  # index of the winning class
        probabilities = torch.nn.functional.softmax(output, dim=1)
    return predicted.item(), probabilities.squeeze().tolist()

View File

@ -11,6 +11,7 @@ from io import BytesIO
import tempfile import tempfile
import cv2 import cv2
import numpy as np import numpy as np
from roi_lib import *
app = Flask(__name__) app = Flask(__name__)
@ -236,5 +237,20 @@ def analyze():
analyze_datapoint(data) analyze_datapoint(data)
return data return data
@app.route('/api/roi-verify', methods=['POST'])
def roi_verify():
    """Classify an uploaded side-by-side ROI image.

    Expects a multipart POST with the image in the ``side_by_side`` file field.

    Returns:
        Dict with ``ok`` (always True), ``predicted_class`` and
        ``probabilities`` as returned by ``predict``.
    """
    upload = request.files['side_by_side']
    # preprocess_image() opens the image by path, so spool the upload to disk.
    # The default delete-on-close removes the temp file when the block exits
    # (the previous delete=False left one file behind per request).
    # Reopening the file by name while it is still open relies on POSIX semantics.
    with tempfile.NamedTemporaryFile(suffix=".jpg") as f:
        upload.save(f.name)
        image_tensor = preprocess_image(f.name)
    # NOTE(review): `model` is not defined in the visible part of this file; it
    # has to exist at module level (e.g. created once with load_model()) - confirm.
    predicted_class, probabilities = predict(model, image_tensor)
    return {
        "ok": True,
        "predicted_class": predicted_class,
        "probabilities": probabilities,
    }
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=26966, debug=True) app.run(host='0.0.0.0', port=26966, debug=True)