themblem/detection/thirdTool/qr_box_detect.py
2024-09-01 21:51:50 +01:00

57 lines
2.2 KiB
Python

import torch.nn as nn
import torch
import numpy as np
from .yolo5x_qr.models.common import DetectMultiBackend
from .yolo5x_qr.utils.augmentations import (letterbox)
from .yolo5x_qr.utils.torch_utils import select_device, smart_inference_mode
from .yolo5x_qr.utils.general import (LOGGER, Profile, check_img_size, cv2,
non_max_suppression, scale_boxes)
class QR_Box_detect(nn.Module):
    """Detect a single QR-code bounding box in an RGB image via a YOLOv5 backend.

    Wraps a ``DetectMultiBackend`` YOLOv5 model configured to return at most
    one detection (``max_det=1``) above a high confidence threshold.
    """

    def __init__(self, model_path=None, device='cpu'):
        """Load the YOLOv5 weights onto the selected device.

        Args:
            model_path: Path to the YOLOv5 weights file.
            device: Device string passed to ``select_device`` (e.g. 'cpu', '0').
        """
        super().__init__()
        # Inference / NMS hyper-parameters.
        self.conf_thres = 0.80     # keep only high-certainty boxes
        self.iou_thres = 0.45      # IoU threshold for non-max suppression
        self.classes = None        # no class filtering
        self.max_det = 1           # at most one QR box per image
        self.agnostic_nms = False
        self.model_path = model_path
        # NOTE(review): this fork's select_device takes a model_path kwarg —
        # confirm against .yolo5x_qr.utils.torch_utils.
        self.device = select_device(device, model_path=self.model_path)
        self.model = DetectMultiBackend(weights=self.model_path, device=self.device)

    def forward(self, input, imgsz=512):
        """Run QR-box detection on one RGB image.

        Args:
            input: RGB image (PIL image or array-like convertible via np.array).
            imgsz: Target inference size; rounded to a multiple of the model stride.

        Returns:
            Tuple ``(corners, conf)`` where ``corners`` is a (2, 2) numpy array of
            the box's (x1, y1), (x2, y2) corners in original-image coordinates and
            ``conf`` is a (1, 1) numpy array with the confidence score; or
            ``(None, None)`` when nothing is detected or a RuntimeError occurs.
        """
        stride, names, pt = self.model.stride, self.model.names, self.model.pt
        imgsz = check_img_size(imgsz, s=stride)  # round imgsz to a stride multiple
        # Convert RGB input to OpenCV's BGR layout.
        im0 = cv2.cvtColor(np.array(input), cv2.COLOR_RGB2BGR)
        # Aspect-ratio-preserving letterbox resize.
        # Fix: pad to the model's actual stride instead of a hard-coded 32,
        # matching the stride used by check_img_size above.
        im = letterbox(im0, imgsz, stride=stride, auto=True)[0]
        im = im.transpose((2, 0, 1))[::-1]  # HWC -> CHW, BGR -> RGB
        im = np.ascontiguousarray(im)       # transpose/flip yield views; make contiguous
        im = torch.from_numpy(im).to(self.model.device)
        im = im.half() if self.model.fp16 else im.float()
        im /= 255  # 0-255 -> 0.0-1.0
        if len(im.shape) == 3:
            im = im[None]  # add batch dimension
        try:
            pred = self.model.model(im, augment=False, visualize=False)
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                       self.classes, self.agnostic_nms,
                                       max_det=self.max_det)
            det = pred[0]
            if len(det):
                # Rescale boxes from the padded inference size back to im0 size.
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                # max_det=1 guarantees exactly one row: reshape its 4 coords
                # into two (x, y) corner points.
                return det[:, :4].view(2, 2).cpu().numpy(), det[:, 4:5].cpu().numpy()
            return None, None
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
            return None, None