drop detection
@ -1,5 +1,4 @@
|
|||||||
stages:
|
stages:
|
||||||
- download-models
|
|
||||||
- test-and-build
|
- test-and-build
|
||||||
- build-docker
|
- build-docker
|
||||||
- deploy
|
- deploy
|
||||||
@ -10,24 +9,6 @@ cache:
|
|||||||
- emtest/target
|
- emtest/target
|
||||||
- venv
|
- venv
|
||||||
|
|
||||||
download-models:
|
|
||||||
stage: download-models
|
|
||||||
tags:
|
|
||||||
- derby
|
|
||||||
before_script:
|
|
||||||
- source scripts/dev-setup
|
|
||||||
script:
|
|
||||||
- make download-models
|
|
||||||
artifacts:
|
|
||||||
paths:
|
|
||||||
- detection/model
|
|
||||||
except:
|
|
||||||
- main
|
|
||||||
cache:
|
|
||||||
key: models
|
|
||||||
paths:
|
|
||||||
- detection/model
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
stage: test-and-build
|
stage: test-and-build
|
||||||
except:
|
except:
|
||||||
@ -40,8 +21,6 @@ test:
|
|||||||
- make opencv -j$(nproc --ignore=2)
|
- make opencv -j$(nproc --ignore=2)
|
||||||
- make -C alg qrtool -j$(nproc --ignore=2)
|
- make -C alg qrtool -j$(nproc --ignore=2)
|
||||||
- make test
|
- make test
|
||||||
dependencies:
|
|
||||||
- download-models
|
|
||||||
|
|
||||||
build-alg:
|
build-alg:
|
||||||
stage: test-and-build
|
stage: test-and-build
|
||||||
@ -83,7 +62,6 @@ build-docker:
|
|||||||
dependencies:
|
dependencies:
|
||||||
- build-web
|
- build-web
|
||||||
- build-alg
|
- build-alg
|
||||||
- download-models
|
|
||||||
except:
|
except:
|
||||||
- main
|
- main
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ RUN apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y $(cat
|
|||||||
RUN pip3 install --no-cache-dir torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0
|
RUN pip3 install --no-cache-dir torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0
|
||||||
ADD requirements.txt requirements.txt
|
ADD requirements.txt requirements.txt
|
||||||
RUN pip3 install --no-cache-dir -r requirements.txt
|
RUN pip3 install --no-cache-dir -r requirements.txt
|
||||||
ADD detection /emblem/detection
|
|
||||||
ADD alg /emblem/alg
|
ADD alg /emblem/alg
|
||||||
ADD api /emblem/api
|
ADD api /emblem/api
|
||||||
ADD web /emblem/web
|
ADD web /emblem/web
|
||||||
|
|||||||
18
Makefile
@ -21,15 +21,6 @@ WEB_FILES = \
|
|||||||
-type f)\
|
-type f)\
|
||||||
)
|
)
|
||||||
|
|
||||||
DETECTION_FILES = \
|
|
||||||
$(addprefix build/, \
|
|
||||||
$(shell find -L \
|
|
||||||
detection \
|
|
||||||
-type f \
|
|
||||||
-not -name '*.pyc' \
|
|
||||||
) \
|
|
||||||
)
|
|
||||||
|
|
||||||
SCRIPTS_FILES = \
|
SCRIPTS_FILES = \
|
||||||
$(addprefix build/, \
|
$(addprefix build/, \
|
||||||
$(shell find -L \
|
$(shell find -L \
|
||||||
@ -57,8 +48,8 @@ ALG_FILES = \
|
|||||||
) \
|
) \
|
||||||
)
|
)
|
||||||
|
|
||||||
docker-build: build/Dockerfile build/packages.txt build/requirements.txt download-models \
|
docker-build: build/Dockerfile build/packages.txt build/requirements.txt \
|
||||||
build/nginx.conf $(WEB_FILES) $(API_FILES) $(ALG_FILES) $(DETECTION_FILES) $(SCRIPTS_FILES) $(DATASET_FILES)
|
build/nginx.conf $(WEB_FILES) $(API_FILES) $(ALG_FILES) $(SCRIPTS_FILES) $(DATASET_FILES)
|
||||||
find build
|
find build
|
||||||
docker build --network=host -t $(IMAGE) build
|
docker build --network=host -t $(IMAGE) build
|
||||||
|
|
||||||
@ -101,7 +92,6 @@ deploy-roi-worker:
|
|||||||
test: FORCE
|
test: FORCE
|
||||||
cd emtest && cargo test -- --nocapture
|
cd emtest && cargo test -- --nocapture
|
||||||
cd api; ./manage.py migrate && ./manage.py test tests
|
cd api; ./manage.py migrate && ./manage.py test tests
|
||||||
make -C detection test
|
|
||||||
|
|
||||||
OPENCV_TAG := 4.9.0
|
OPENCV_TAG := 4.9.0
|
||||||
opencv/src/LICENSE:
|
opencv/src/LICENSE:
|
||||||
@ -148,7 +138,3 @@ opencv.js: opencv/src/LICENSE FORCE
|
|||||||
|
|
||||||
alg/qrtool:
|
alg/qrtool:
|
||||||
make -C alg qrtool
|
make -C alg qrtool
|
||||||
|
|
||||||
|
|
||||||
download-models: FORCE
|
|
||||||
cd detection && python3 app.py --download-models-only
|
|
||||||
|
|||||||
@ -0,0 +1,17 @@
|
|||||||
|
# Generated by Django 3.2.25 on 2025-04-24 08:06
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('products', '0101_auto_20250302_1419'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='codebatch',
|
||||||
|
name='detection_service',
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -170,8 +170,6 @@ class CodeBatch(models.Model):
|
|||||||
num_digits = models.IntegerField(default=10, verbose_name="尾号长度")
|
num_digits = models.IntegerField(default=10, verbose_name="尾号长度")
|
||||||
is_active = models.BooleanField(default=True, verbose_name="已激活")
|
is_active = models.BooleanField(default=True, verbose_name="已激活")
|
||||||
name = models.CharField(null=True, max_length=255, db_index=True, verbose_name="名称")
|
name = models.CharField(null=True, max_length=255, db_index=True, verbose_name="名称")
|
||||||
detection_service = models.TextField(null=True, blank=True,
|
|
||||||
verbose_name="指定检测后端微服务地址(可选) ")
|
|
||||||
scan_redirect_url = models.TextField(null=True, blank=True,
|
scan_redirect_url = models.TextField(null=True, blank=True,
|
||||||
verbose_name="自定义扫码重定向URL(可选)")
|
verbose_name="自定义扫码重定向URL(可选)")
|
||||||
enable_auto_torch = models.BooleanField(default=False, verbose_name="自动打开闪光灯")
|
enable_auto_torch = models.BooleanField(default=False, verbose_name="自动打开闪光灯")
|
||||||
|
|||||||
@ -642,88 +642,25 @@ class QrVerifyView(BaseView):
|
|||||||
if batch.qr_angle_allowed_error < 0.01:
|
if batch.qr_angle_allowed_error < 0.01:
|
||||||
messages.append(f"batch.qr_angle_allowed_error {batch.qr_angle_allowed_error} < 0.01, not checking angle")
|
messages.append(f"batch.qr_angle_allowed_error {batch.qr_angle_allowed_error} < 0.01, not checking angle")
|
||||||
return
|
return
|
||||||
if angle is not None:
|
if angle is None:
|
||||||
diff = abs(batch.qr_angle - float(angle))
|
raise Exception("Angle check failed, angle is missing?")
|
||||||
if diff > batch.qr_angle_allowed_error:
|
diff = abs(batch.qr_angle - float(angle))
|
||||||
messages.append(f"Angle check failed, captured {angle} but expecting {batch.qr_angle} with error margin {batch.qr_angle_allowed_error}")
|
if diff > batch.qr_angle_allowed_error:
|
||||||
raise Exception("Angle check failed")
|
messages.append(f"Angle check failed, captured {angle} but expecting {batch.qr_angle} with error margin {batch.qr_angle_allowed_error}")
|
||||||
return
|
raise Exception("Angle check failed")
|
||||||
|
|
||||||
ds = batch.detection_service or "http://localhost:6006"
|
|
||||||
for api in ['/dot_detection_top', '/dot_detection_bottom']:
|
|
||||||
try:
|
|
||||||
messages.append(f"trying api {api}")
|
|
||||||
r = requests.post(ds + api, files={
|
|
||||||
'file': open(filename, 'rb'),
|
|
||||||
},
|
|
||||||
data={
|
|
||||||
'threshold': str(batch.qr_angle_allowed_error),
|
|
||||||
'angle': str(batch.qr_angle),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
r = r.json()
|
|
||||||
messages.append(f"{api} response: {r}")
|
|
||||||
data = r['data']
|
|
||||||
if data["status"] == "OK":
|
|
||||||
messages.append("status is OK, angle check succeeded")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
messages.append(f"API {api} error: {e}")
|
|
||||||
pass
|
|
||||||
messages.append(f"All angle check api failed")
|
|
||||||
raise Exception("Angle detection failed")
|
|
||||||
|
|
||||||
def feature_comparison_check(self, sc, batch, img_fn, messages, threshold):
|
|
||||||
if not threshold:
|
|
||||||
threshold = batch.feature_comparison_threshold
|
|
||||||
if threshold < 0.01:
|
|
||||||
messages.append(f"batch.feature_comparison_threshold {batch.feature_comparison_threshold} < 0.01, not comparing feature")
|
|
||||||
return
|
|
||||||
feature_roi = self.get_feature_roi(sc.code)
|
|
||||||
if not feature_roi:
|
|
||||||
messages.append(f"feature roi not found for code {sc.code}, skiping")
|
|
||||||
return
|
|
||||||
ds = batch.detection_service or "http://localhost:6006"
|
|
||||||
api_name = "/qr_roi_cloud_comparison"
|
|
||||||
url = ds + api_name
|
|
||||||
feature_roi_len = len(feature_roi)
|
|
||||||
qrtool_path = os.path.abspath("../alg/qrtool")
|
|
||||||
if not qrtool_path:
|
|
||||||
raise Exception("Cannot find qrtool")
|
|
||||||
cwd = os.path.dirname(qrtool_path)
|
|
||||||
cmd = [qrtool_path, 'topleft', img_fn]
|
|
||||||
messages.append(" ".join(cmd))
|
|
||||||
subprocess.check_call(cmd, cwd=cwd)
|
|
||||||
messages.append(f"calling: {url}, local file {img_fn}, feature roi size {feature_roi_len}, threshold {threshold}")
|
|
||||||
r = requests.post(url, files={
|
|
||||||
'ter_file': open(img_fn + ".topleft.jpg", 'rb'),
|
|
||||||
'std_file': feature_roi,
|
|
||||||
},
|
|
||||||
data={
|
|
||||||
'threshold': str(threshold),
|
|
||||||
})
|
|
||||||
j = r.json()
|
|
||||||
messages.append(f"response: {j}")
|
|
||||||
data = j.get('data')
|
|
||||||
if data.get("status") != "OK" or data.get('comparison_result').lower() not in ['pass', 'passed']:
|
|
||||||
messages.append(f"Feature comparison failed")
|
|
||||||
raise Exception(f"feature comparison check failed: {j}")
|
|
||||||
messages.append(f"Feature comparison succeeded")
|
|
||||||
|
|
||||||
def do_v5_qr_verify(self, img_fn, messages):
|
def do_v5_qr_verify(self, img_fn, messages):
|
||||||
try:
|
resp = requests.post(
|
||||||
resp = requests.post(
|
"https://themblem.com/api/v5/qr_verify",
|
||||||
"https://themblem.com/api/v5/qr_verify",
|
files={
|
||||||
files={
|
'frame': open(img_fn, 'rb'),
|
||||||
'frame': open(img_fn, 'rb'),
|
},
|
||||||
},
|
)
|
||||||
)
|
rd = resp.json()
|
||||||
rd = resp.json()
|
messages.append(f"v5 qr_verify response: {rd}")
|
||||||
messages.append(f"v5 qr_verify response: {rd}")
|
ok = rd.get("result", {}).get("predicted_class", 0) == 1
|
||||||
return rd.get("result", {}).get("predicted_class", 0) == 1
|
if not ok:
|
||||||
except Exception as e:
|
raise Exception("v5 qr verify failed")
|
||||||
messages.append(f"v5 qr verify failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
image_name = ''
|
image_name = ''
|
||||||
@ -776,10 +713,7 @@ class QrVerifyView(BaseView):
|
|||||||
sd.tenant = tenant
|
sd.tenant = tenant
|
||||||
sd.batch = sc.batch
|
sd.batch = sc.batch
|
||||||
self.dot_angle_check(sc.batch, tf.name, messages, request.data.get("angle"))
|
self.dot_angle_check(sc.batch, tf.name, messages, request.data.get("angle"))
|
||||||
if self.do_v5_qr_verify(tf.name, messages):
|
self.do_v5_qr_verify(tf.name, messages)
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.feature_comparison_check(sd, sc.batch, tf.name, messages, threshold)
|
|
||||||
sd.succeeded = True
|
sd.succeeded = True
|
||||||
article_id = None
|
article_id = None
|
||||||
if product.article:
|
if product.article:
|
||||||
|
|||||||
@ -1 +0,0 @@
|
|||||||
/.git
|
|
||||||
8
detection/.gitignore
vendored
@ -1,8 +0,0 @@
|
|||||||
__pycache__
|
|
||||||
.*.swp
|
|
||||||
.*.swo
|
|
||||||
.*.swn
|
|
||||||
*.pyc
|
|
||||||
/tmp/*
|
|
||||||
.idea
|
|
||||||
.DS_Store
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
stages:
|
|
||||||
- build
|
|
||||||
- test
|
|
||||||
- push
|
|
||||||
- deploy
|
|
||||||
|
|
||||||
cache:
|
|
||||||
- key: venv
|
|
||||||
paths:
|
|
||||||
- venv
|
|
||||||
|
|
||||||
build:
|
|
||||||
stage: build
|
|
||||||
script:
|
|
||||||
- ./scripts/ci build
|
|
||||||
|
|
||||||
test:
|
|
||||||
stage: test
|
|
||||||
script:
|
|
||||||
- ./scripts/ci test
|
|
||||||
|
|
||||||
push:
|
|
||||||
stage: push
|
|
||||||
script:
|
|
||||||
- ./scripts/ci push
|
|
||||||
|
|
||||||
deploy-dev:
|
|
||||||
stage: deploy
|
|
||||||
script:
|
|
||||||
- ./scripts/ci deploy-dev
|
|
||||||
|
|
||||||
deploy-prod:
|
|
||||||
stage: deploy
|
|
||||||
when: manual
|
|
||||||
script:
|
|
||||||
- ./scripts/ci deploy-prod
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
test:
|
|
||||||
./tests/run.py
|
|
||||||
@ -1,59 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : api_debug.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2023/8/26 17:55 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
import requests
|
|
||||||
import cv2
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
std_feature = os.path.join('/project/emblem_detection/tests/data/std_roi_encoding.json')
|
|
||||||
with open(std_feature, "r") as input_file:
|
|
||||||
loaded_data = json.load(input_file)
|
|
||||||
# 提取 qr_encoding 数据
|
|
||||||
std_feature_encoding = loaded_data['data']["roi_encoding"]
|
|
||||||
|
|
||||||
imagePath = "/project/emblem_detection/tests/data/1291182811394_pos_iPhone12Pro.jpg"
|
|
||||||
# imagePath = "/project/emblem_detection/tests/data/1291182811394_pos_LYA-AL00.jpg"
|
|
||||||
img = cv2.imread(imagePath)
|
|
||||||
imgfile = cv2.imencode('.jpg', img)[1]
|
|
||||||
files = { "file" : ('filename.jpg',imgfile, 'image/jpg')}
|
|
||||||
# form = {"threshold":'50',"angle":'45'}
|
|
||||||
form = {"threshold": '50', "std_roi_feature": json.dumps(std_feature_encoding)}
|
|
||||||
|
|
||||||
r = requests.post('http://localhost:6006/qr_roi_cloud_feature_comparison', files=files,
|
|
||||||
data=form).json()
|
|
||||||
|
|
||||||
print(r)
|
|
||||||
|
|
||||||
# r = requests.post('http://localhost:6006/qr_decode', files=files).json()
|
|
||||||
#
|
|
||||||
# print(r)
|
|
||||||
#
|
|
||||||
# r = requests.post('http://localhost:6006/dot_detection', files=files, data=form).json()
|
|
||||||
#
|
|
||||||
# print(r)
|
|
||||||
#
|
|
||||||
# r = requests.post('http://localhost:6006/dot_detection_top', files=files, data=form).json()
|
|
||||||
#
|
|
||||||
# print(r)
|
|
||||||
#
|
|
||||||
# r = requests.post('http://localhost:6006/dot_detection_bottom', files=files, data=form).json()
|
|
||||||
#
|
|
||||||
# print(r)
|
|
||||||
#
|
|
||||||
# r = requests.post('http://localhost:6006/qr_roi_cloud_feature_extraction', files=files).json()
|
|
||||||
#
|
|
||||||
# print(r['data']['status'])
|
|
||||||
266
detection/app.py
@ -1,266 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
from flask import Flask, request
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
import qr_detection
|
|
||||||
from PIL import Image
|
|
||||||
from dot_detection import dots_angle_measure_dl_sr, dots_angle_measure_tr_sr
|
|
||||||
from thirdTool import qr_box_detect
|
|
||||||
from common import cropped_image
|
|
||||||
from qr_verify_Tool.qr_orb import ORB_detect, roi_siml
|
|
||||||
import json
|
|
||||||
from qr_verify_Tool.affinity import roi_affinity_siml
|
|
||||||
from qr_verify_Tool.roi_img_process import *
|
|
||||||
from qr_verify_Tool.models import *
|
|
||||||
from utils import get_model
|
|
||||||
from thirdTool import qrsrgan
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
||||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
||||||
|
|
||||||
class Detection(object):
|
|
||||||
def get_models(self):
|
|
||||||
self.qr_cloud_detect_model = get_model('qr_cloud_detect_20230928.pt')
|
|
||||||
self.qr_roi_cloud_detect_model = get_model('qr_roi_cloud_detect_20230928.pt')
|
|
||||||
self.qr_net_g_2000_model = get_model('net_g_2000.pth')
|
|
||||||
self.roi_net_g_model = get_model("roi_net_g_20240306.pth")
|
|
||||||
|
|
||||||
def init_models(self):
|
|
||||||
device = os.environ.get("CUDA_DEVICE", "cpu")
|
|
||||||
self.qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=self.qr_cloud_detect_model, device=device)
|
|
||||||
self.qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=self.qr_roi_cloud_detect_model, device=device)
|
|
||||||
|
|
||||||
self.dot_realesrgan = qrsrgan.RealsrGan(model_path=self.qr_net_g_2000_model, device=device)
|
|
||||||
self.roi_generator = GeneratorUNet()
|
|
||||||
self.roi_generator.load_state_dict(torch.load(self.roi_net_g_model, map_location=torch.device('cpu')))
|
|
||||||
|
|
||||||
detection = Detection()
|
|
||||||
|
|
||||||
@app.route('/upload', methods=['POST', 'GET'])
|
|
||||||
def upload():
|
|
||||||
f = request.files.get('file')
|
|
||||||
upload_path = os.path.join("tmp/tmp." + f.filename.split(".")[-1])
|
|
||||||
# secure_filename(f.filename)) #注意:没有的文件夹一定要先创建,不然会提示没有该路径
|
|
||||||
f.save(upload_path)
|
|
||||||
return upload_path
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
#用于防伪图签特征对比
|
|
||||||
入参:std_file---备案特征采集区域文件(输入原始roi带黑框图像)
|
|
||||||
ter_file---终端特征采集文件(输入仿射扭正后1/4的左上角图像)
|
|
||||||
threshold---图像对比阈值
|
|
||||||
返回:
|
|
||||||
{"comparison_result": 对比结果,'similarity_value':相似度值, "status":'对比方法调用状态'}
|
|
||||||
'''
|
|
||||||
@app.route('/qr_roi_cloud_comparison', methods=['POST', 'GET'])
|
|
||||||
def qr_roi_cloud_comparison():
|
|
||||||
def do_api():
|
|
||||||
'''
|
|
||||||
#获取图片url,此处的url为request参数
|
|
||||||
'''
|
|
||||||
std_file = request.files.get('std_file') #备案特征采集区域文件
|
|
||||||
ter_file = request.files.get('ter_file') #终端特征采集文件
|
|
||||||
threshold = float(request.form.get('threshold')) #对比阈值
|
|
||||||
|
|
||||||
std_roi_img = Image.open(std_file)
|
|
||||||
ter_img = Image.open(ter_file)
|
|
||||||
|
|
||||||
'''备案特征区域数据处理'''
|
|
||||||
std_roi_im_fix = roi_img_processing(std_roi_img)
|
|
||||||
|
|
||||||
#提取特征提取区域
|
|
||||||
ter_roi_points, qr_roi_conf = detection.qr_roi_detect_now(ter_img)
|
|
||||||
ter_roi_img = cropped_image(ter_img,ter_roi_points,shift=-3,size=1)
|
|
||||||
# 图像还原计算
|
|
||||||
ter_roi_im_re = detection.roi_generator(ter_roi_img)
|
|
||||||
|
|
||||||
# roi特征区域处理
|
|
||||||
ter_roi_im_re_fix = roi_img_processing(ter_roi_im_re)
|
|
||||||
|
|
||||||
if ter_roi_im_re_fix is None:
|
|
||||||
return {"comparison_result": None, 'similarity_value': None, "status": 'False', "reason": "cannot find roi region after enhancing"}
|
|
||||||
|
|
||||||
similarity = roi_affinity_siml(ter_roi_im_re_fix,std_roi_im_fix)
|
|
||||||
|
|
||||||
comparison_result = "pass" if similarity > threshold else "failed"
|
|
||||||
return {"comparison_result": comparison_result,'similarity_value':similarity, "status":'OK'}
|
|
||||||
|
|
||||||
try:
|
|
||||||
return {"data": do_api()}
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
'''
|
|
||||||
网点角度顶部位置检测
|
|
||||||
'''
|
|
||||||
@app.route('/dot_detection_top', methods=['POST', 'GET'])
|
|
||||||
def dot_detection_top():
|
|
||||||
|
|
||||||
def do_api():
|
|
||||||
|
|
||||||
image_file = request.files.get('file')
|
|
||||||
threshold = request.form.get('threshold')
|
|
||||||
angle = request.form.get('angle')
|
|
||||||
img = Image.open(image_file)
|
|
||||||
points, conf = detection.qr_box_detect_now(img)
|
|
||||||
|
|
||||||
def angle_ok(a):
|
|
||||||
if a is None:
|
|
||||||
return False
|
|
||||||
return abs(a - float(angle)) <= float(threshold)
|
|
||||||
|
|
||||||
if points is None :
|
|
||||||
raise Exception("Qr_Decode error")
|
|
||||||
else:
|
|
||||||
rotation_region = qr_detection.rotation_region_crop(img, points)
|
|
||||||
img_ = qr_detection.rotation_waffine(Image.fromarray(rotation_region), img)
|
|
||||||
img_ = np.array(img_)
|
|
||||||
|
|
||||||
x1_1 = int(points[0, 0] * 1)
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.25))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.20))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.18))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.15))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.14))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.13))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.12))
|
|
||||||
if y1_1 < 0:
|
|
||||||
y1_1 = int(points[0, 1] * 1 - (points[1, 1] * 1 - points[0, 1] * 1) * (0.11))
|
|
||||||
if y1_1 >= 0:
|
|
||||||
x1_2 = int(x1_1 + (points[1, 0] * 1 - points[0, 0] * 1) * 0.5)
|
|
||||||
y1_2 = int(y1_1 + (points[1, 1] * 1 - points[0, 1] * 1) * 0.1)
|
|
||||||
else:
|
|
||||||
return {"status":"False", "reason": "coordinate invalid"}
|
|
||||||
|
|
||||||
dots_region_top = img_[y1_1:y1_2, x1_1:x1_2]
|
|
||||||
|
|
||||||
# 采用传统插值方式实现图像增强,进行网线角度检测
|
|
||||||
tr_sr_lines_arctan, _ = dots_angle_measure_tr_sr(dots_region_top, save_dots=False,
|
|
||||||
res_code=None)
|
|
||||||
if angle_ok(tr_sr_lines_arctan):
|
|
||||||
return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}
|
|
||||||
|
|
||||||
#采用模型生成实现图像增强,进行网线角度检测
|
|
||||||
dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, detection.dot_realesrgan, save_dots=False,
|
|
||||||
res_code=None)
|
|
||||||
if dl_sr_lines_arctan is not None:
|
|
||||||
if angle_ok(dl_sr_lines_arctan):
|
|
||||||
return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "OK"}
|
|
||||||
else:
|
|
||||||
reason = f"Angle {dl_sr_lines_arctan} not within threshold {threshold} from reference angle {angle}"
|
|
||||||
return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "False", "reason": reason}
|
|
||||||
return {"status": "False", "reason": "Failed to find dot angle"}
|
|
||||||
try:
|
|
||||||
return {"data": do_api()}
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}
|
|
||||||
'''
|
|
||||||
网点角度底部位置检测
|
|
||||||
'''
|
|
||||||
@app.route('/dot_detection_bottom', methods=['POST', 'GET'])
|
|
||||||
def dot_detection_bottom():
|
|
||||||
|
|
||||||
def do_api():
|
|
||||||
image_file = request.files.get('file')
|
|
||||||
threshold = request.form.get('threshold')
|
|
||||||
angle = request.form.get('angle')
|
|
||||||
img = Image.open(image_file)
|
|
||||||
points, conf = detection.qr_box_detect_now(img)
|
|
||||||
|
|
||||||
def angle_ok(a):
|
|
||||||
if a is None:
|
|
||||||
return False
|
|
||||||
return abs(a - float(angle)) <= float(threshold)
|
|
||||||
|
|
||||||
if points is None :
|
|
||||||
raise Exception("Qr_Decode error")
|
|
||||||
else:
|
|
||||||
rotation_region = qr_detection.rotation_region_crop(img, points)
|
|
||||||
img_ = qr_detection.rotation_waffine(Image.fromarray(rotation_region), img)
|
|
||||||
img_ = np.array(img_)
|
|
||||||
x2_2 = int(points[1, 0] * 0.8)
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.25))
|
|
||||||
max_y2 = img_.shape[0]
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.20))
|
|
||||||
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.18))
|
|
||||||
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.15))
|
|
||||||
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.14))
|
|
||||||
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.13))
|
|
||||||
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.12))
|
|
||||||
|
|
||||||
if y2_2 > max_y2:
|
|
||||||
y2_2 = int(points[1, 1] * 1 + (points[1, 1] * 1 - points[0, 1] * 1) * (0.11))
|
|
||||||
|
|
||||||
if y2_2 <= max_y2:
|
|
||||||
x2_1 = int(x2_2 - (points[1, 0] * 1 - points[0, 0] * 1) * 0.5)
|
|
||||||
y2_1 = int(y2_2 - (points[1, 1] * 1 - points[0, 1] * 1) * 0.1)
|
|
||||||
else:
|
|
||||||
return {"status": "False", "reason": "coordinate invalid"}
|
|
||||||
|
|
||||||
dots_region_bottom = img_[y2_1:y2_2, x2_1:x2_2]
|
|
||||||
|
|
||||||
# 采用传统插值方式实现图像增强,进行网线角度检测
|
|
||||||
tr_sr_lines_arctan, _ = dots_angle_measure_tr_sr(dots_region_bottom, save_dots=False,
|
|
||||||
res_code=None)
|
|
||||||
if angle_ok(tr_sr_lines_arctan):
|
|
||||||
return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}
|
|
||||||
|
|
||||||
#采用模型生成实现图像增强,进行网线角度检测
|
|
||||||
dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, detection.dot_realesrgan, save_dots=False,
|
|
||||||
res_code=None)
|
|
||||||
if dl_sr_lines_arctan is not None:
|
|
||||||
if angle_ok(dl_sr_lines_arctan):
|
|
||||||
return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "OK"}
|
|
||||||
else:
|
|
||||||
reason = f"Angle {dl_sr_lines_arctan} not within threshold {threshold} from reference angle {angle}"
|
|
||||||
return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "False", "reason": reason}
|
|
||||||
return {"status": "False", "reason": "Failed to find dot angle"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
return {"data": do_api()}
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
@app.route('/', methods=['GET'])
|
|
||||||
def index():
|
|
||||||
return 'emblem detection api'
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--host", "-l", default="0.0.0.0")
|
|
||||||
parser.add_argument("--port", "-p", type=int, default=6006)
|
|
||||||
parser.add_argument("--download-models-only", action="store_true")
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = parse_args()
|
|
||||||
detection.get_models()
|
|
||||||
if args.download_models_only:
|
|
||||||
print("Models loaded successfully")
|
|
||||||
return
|
|
||||||
detection.init_models()
|
|
||||||
print(f"Starting server at {args.host}:{args.port}")
|
|
||||||
app.run(host=args.host, port=args.port, debug=True)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@ -1,100 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : common.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2022/4/21 23:46 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
from matplotlib.ticker import NullLocator
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
'''色彩增益加权的AutoMSRCR算法'''
|
|
||||||
def singleScaleRetinex(img, sigma):
|
|
||||||
retinex = np.log10(img) - np.log10(cv2.GaussianBlur(img, (0, 0), sigma))
|
|
||||||
|
|
||||||
return retinex
|
|
||||||
|
|
||||||
def multiScaleRetinex(img, sigma_list):
|
|
||||||
retinex = np.zeros_like(img)
|
|
||||||
for sigma in sigma_list:
|
|
||||||
retinex += singleScaleRetinex(img, sigma)
|
|
||||||
|
|
||||||
retinex = retinex / len(sigma_list)
|
|
||||||
|
|
||||||
return retinex
|
|
||||||
|
|
||||||
def automatedMSRCR(img, sigma_list):
|
|
||||||
img = np.float64(img) + 1.0
|
|
||||||
|
|
||||||
img_retinex = multiScaleRetinex(img, sigma_list)
|
|
||||||
|
|
||||||
for i in range(img_retinex.shape[2]):
|
|
||||||
unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100), return_counts=True)
|
|
||||||
for u, c in zip(unique, count):
|
|
||||||
if u == 0:
|
|
||||||
zero_count = c
|
|
||||||
break
|
|
||||||
|
|
||||||
low_val = unique[0] / 100.0
|
|
||||||
high_val = unique[-1] / 100.0
|
|
||||||
for u, c in zip(unique, count):
|
|
||||||
if u < 0 and c < zero_count * 0.1:
|
|
||||||
low_val = u / 100.0
|
|
||||||
if u > 0 and c < zero_count * 0.1:
|
|
||||||
high_val = u / 100.0
|
|
||||||
break
|
|
||||||
|
|
||||||
img_retinex[:, :, i] = np.maximum(np.minimum(img_retinex[:, :, i], high_val), low_val)
|
|
||||||
|
|
||||||
img_retinex[:, :, i] = (img_retinex[:, :, i] - np.min(img_retinex[:, :, i])) / \
|
|
||||||
(np.max(img_retinex[:, :, i]) - np.min(img_retinex[:, :, i])) \
|
|
||||||
* 255
|
|
||||||
|
|
||||||
img_retinex = np.uint8(img_retinex)
|
|
||||||
|
|
||||||
return img_retinex
|
|
||||||
|
|
||||||
'''图像处理过程保存'''
|
|
||||||
def save_figure(fig_context,fig_name,res_code):
|
|
||||||
fig = plt.figure()
|
|
||||||
ax = fig.subplots(1)
|
|
||||||
ax.imshow(fig_context, aspect='equal')
|
|
||||||
result_path = os.path.join("tmp", '{}_{}_{}.jpg'.format(fig_name,res_code,time.time()))
|
|
||||||
plt.axis("off")
|
|
||||||
plt.gca().xaxis.set_major_locator(NullLocator())
|
|
||||||
plt.gca().yaxis.set_major_locator(NullLocator())
|
|
||||||
# filename = result_path.split("/")[-1].split(".")[0]
|
|
||||||
plt.savefig(result_path, quality=95, bbox_inches="tight", pad_inches=0.0)
|
|
||||||
plt.close()
|
|
||||||
return result_path
|
|
||||||
|
|
||||||
def cropped_image(img, points, shift, size):
|
|
||||||
x_min, y_min = points[0]
|
|
||||||
x_max, y_max = points[1]
|
|
||||||
|
|
||||||
x_min, y_min, x_max, y_max = int(x_min) - shift, int(y_min) - shift, int(x_max) + shift, int(y_max) + shift
|
|
||||||
quarter_width = (x_max - x_min) // size
|
|
||||||
quarter_height = (y_max - y_min) // size
|
|
||||||
|
|
||||||
# 裁剪图像
|
|
||||||
img = Image.fromarray(np.uint8(img))
|
|
||||||
cropped_im = img.crop((x_min, y_min, x_min + quarter_width, y_min + quarter_height))
|
|
||||||
|
|
||||||
|
|
||||||
return cropped_im
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def return_img_stream(img_local_path):
|
|
||||||
"""
|
|
||||||
工具函数:
|
|
||||||
获取本地图片流
|
|
||||||
:param img_local_path:文件单张图片的本地绝对路径
|
|
||||||
:return: 图片流
|
|
||||||
"""
|
|
||||||
import base64
|
|
||||||
img_stream = ''
|
|
||||||
with open(img_local_path, 'rb') as img_f:
|
|
||||||
img_stream = img_f.read()
|
|
||||||
img_stream = base64.b64encode(img_stream)
|
|
||||||
# print('img_stream:{}'.format('data:image/png;base64,' + str(img_stream)))
|
|
||||||
return str(img_stream).split("\'")[1]
|
|
||||||
@ -1,137 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : dot_detection.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2022/4/21 17:53 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
from common import automatedMSRCR, save_figure
|
|
||||||
from component.utile import return_img_stream
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
'''
|
|
||||||
采用realesrgan模型实现图像增强后进行网线角度检测
|
|
||||||
'''
|
|
||||||
def dots_angle_measure_dl_sr(dots_region,realesrgan,save_dots=False,res_code=None):
|
|
||||||
|
|
||||||
original_img_stream = None
|
|
||||||
process_dots_img_stream = None
|
|
||||||
detection_dots_img_stream = None
|
|
||||||
lines_arctan = None
|
|
||||||
if save_dots:
|
|
||||||
result_path = save_figure(dots_region, 'original', res_code)
|
|
||||||
original_img_stream = return_img_stream(result_path)
|
|
||||||
# 超分计算处理
|
|
||||||
process_dots = realesrgan(dots_region)
|
|
||||||
|
|
||||||
|
|
||||||
if save_dots:
|
|
||||||
b, g, r = cv2.split(process_dots)
|
|
||||||
process_dots_ = cv2.merge([r, g, b])
|
|
||||||
result_path = save_figure(process_dots_, 'process_dots', res_code)
|
|
||||||
process_dots_img_stream = return_img_stream(result_path)
|
|
||||||
|
|
||||||
# 斑点检测
|
|
||||||
found, corners = cv2.findCirclesGrid(process_dots, (3, 3), cv2.CALIB_CB_SYMMETRIC_GRID)
|
|
||||||
|
|
||||||
if corners is None or corners.any() != None and corners.shape[0] < 2:
|
|
||||||
# print("---------------------------------------------------------------")
|
|
||||||
# print("corners status,", found)
|
|
||||||
# print("---------------------------------------------------------------")
|
|
||||||
return lines_arctan, [original_img_stream, process_dots_img_stream]
|
|
||||||
|
|
||||||
elif corners.shape[0] >= 2:
|
|
||||||
A1 = corners[0]
|
|
||||||
A2 = corners[1]
|
|
||||||
if (A1 == A2).all():
|
|
||||||
# 斑点检测
|
|
||||||
found, corners = cv2.findCirclesGrid(process_dots, (4, 4), cv2.CALIB_CB_SYMMETRIC_GRID)
|
|
||||||
A1 = corners[0]
|
|
||||||
A2 = corners[1]
|
|
||||||
if (A1 == A2).all():
|
|
||||||
return lines_arctan, [original_img_stream, process_dots_img_stream]
|
|
||||||
B1 = corners[1]
|
|
||||||
B2 = np.array([[0, corners[1][:, 1]]],dtype=object)
|
|
||||||
|
|
||||||
kLine1 = (A2[:, 1] - A1[:, 1]) / (A2[:, 0] - A1[:, 0])
|
|
||||||
kLine2 = (B2[:, 1] - B1[:, 1]) / (B2[:, 0] - B1[:, 0])
|
|
||||||
tan_k = (kLine2 - kLine1) / (1 + kLine2 * kLine1)
|
|
||||||
lines_arctan = math.atan(tan_k)
|
|
||||||
lines_arctan = float('%.2f' % abs(-lines_arctan * 180.0 / 3.1415926))
|
|
||||||
if save_dots:
|
|
||||||
dbg_image_circles = process_dots.copy()
|
|
||||||
dbg_image_circles = cv2.drawChessboardCorners(dbg_image_circles, (3, 3), corners, found)
|
|
||||||
result_path = save_figure(dbg_image_circles, 'detection_dots', res_code)
|
|
||||||
detection_dots_img_stream = return_img_stream(result_path)
|
|
||||||
|
|
||||||
return lines_arctan, [original_img_stream, process_dots_img_stream, detection_dots_img_stream]
|
|
||||||
'''
|
|
||||||
采用传统方式实现图像增强后进行网线角度检测
|
|
||||||
'''
|
|
||||||
def dots_angle_measure_tr_sr(dots_region,save_dots=False,res_code=None):
|
|
||||||
|
|
||||||
original_img_stream = None
|
|
||||||
process_dots_img_stream = None
|
|
||||||
detection_dots_img_stream = None
|
|
||||||
lines_arctan = None
|
|
||||||
if save_dots:
|
|
||||||
result_path = save_figure(dots_region, 'original', res_code)
|
|
||||||
original_img_stream = return_img_stream(result_path)
|
|
||||||
# 锐化插值处理
|
|
||||||
sigma_list = [15, 80, 200]
|
|
||||||
process_dots = automatedMSRCR(
|
|
||||||
dots_region,
|
|
||||||
sigma_list
|
|
||||||
)
|
|
||||||
size = 4
|
|
||||||
process_dots = cv2.resize(process_dots, None, fx=size, fy=size, interpolation=cv2.INTER_CUBIC)
|
|
||||||
|
|
||||||
if save_dots:
|
|
||||||
result_path = save_figure(process_dots, ' process_dots', res_code)
|
|
||||||
process_dots_img_stream = return_img_stream(result_path)
|
|
||||||
|
|
||||||
# 斑点检测
|
|
||||||
found, corners = cv2.findCirclesGrid(process_dots, (3, 3), cv2.CALIB_CB_SYMMETRIC_GRID)
|
|
||||||
|
|
||||||
if corners is None or corners.any() != None and corners.shape[0] < 2:
|
|
||||||
# print("---------------------------------------------------------------")
|
|
||||||
# print("corners status,", found)
|
|
||||||
# print("---------------------------------------------------------------")
|
|
||||||
return lines_arctan, [original_img_stream, process_dots_img_stream]
|
|
||||||
|
|
||||||
elif corners.shape[0] >= 2:
|
|
||||||
A1 = corners[0]
|
|
||||||
A2 = corners[1]
|
|
||||||
if (A1 == A2).all():
|
|
||||||
# 斑点检测
|
|
||||||
found, corners = cv2.findCirclesGrid(process_dots, (4, 4), cv2.CALIB_CB_SYMMETRIC_GRID)
|
|
||||||
A1 = corners[0]
|
|
||||||
A2 = corners[1]
|
|
||||||
if (A1 == A2).all():
|
|
||||||
return lines_arctan, [original_img_stream, process_dots_img_stream]
|
|
||||||
B1 = corners[1]
|
|
||||||
B2 = np.array([[0, corners[1][:, 1]]],dtype=object)
|
|
||||||
|
|
||||||
kLine1 = (A2[:, 1] - A1[:, 1]) / (A2[:, 0] - A1[:, 0])
|
|
||||||
kLine2 = (B2[:, 1] - B1[:, 1]) / (B2[:, 0] - B1[:, 0])
|
|
||||||
tan_k = (kLine2 - kLine1) / (1 + kLine2 * kLine1)
|
|
||||||
lines_arctan = math.atan(tan_k)
|
|
||||||
lines_arctan = float('%.2f' % abs(-lines_arctan * 180.0 / 3.1415926))
|
|
||||||
if save_dots:
|
|
||||||
dbg_image_circles = process_dots.copy()
|
|
||||||
dbg_image_circles = cv2.drawChessboardCorners(dbg_image_circles, (3, 3), corners, found)
|
|
||||||
result_path = save_figure(dbg_image_circles, 'detection_dots', res_code)
|
|
||||||
detection_dots_img_stream = return_img_stream(result_path)
|
|
||||||
|
|
||||||
return lines_arctan, [original_img_stream, process_dots_img_stream, detection_dots_img_stream]
|
|
||||||
|
Before Width: | Height: | Size: 1.8 MiB |
@ -1,447 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : qr_detection.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2022/4/22 09:08 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
from PIL import Image, ImageEnhance
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
|
|
||||||
def solution_1_1(img, detector):
|
|
||||||
# 亮度处理
|
|
||||||
birght_img = ImageEnhance.Brightness(img)
|
|
||||||
birght_img = birght_img.enhance(2)
|
|
||||||
birght_img = np.array(birght_img)
|
|
||||||
res, points = detector.detectAndDecode(birght_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_1_2(img, detector):
|
|
||||||
# 亮度处理
|
|
||||||
birght_img = ImageEnhance.Brightness(img)
|
|
||||||
birght_img = birght_img.enhance(3)
|
|
||||||
birght_img = np.array(birght_img)
|
|
||||||
res, points = detector.detectAndDecode(birght_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_2_1(img, detector):
|
|
||||||
# #对比度增强 + 亮度处理
|
|
||||||
contrast_img = ImageEnhance.Contrast(img)
|
|
||||||
contrast_img = contrast_img.enhance(1.5)
|
|
||||||
birght_img = ImageEnhance.Brightness(contrast_img)
|
|
||||||
birght_img = birght_img.enhance(2)
|
|
||||||
birght_img = np.array(birght_img)
|
|
||||||
res, points = detector.detectAndDecode(birght_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_2_2(img, detector):
|
|
||||||
# #对比度增强 + 亮度处理
|
|
||||||
contrast_img = ImageEnhance.Contrast(img)
|
|
||||||
contrast_img = contrast_img.enhance(1.5)
|
|
||||||
birght_img = ImageEnhance.Brightness(contrast_img)
|
|
||||||
birght_img = birght_img.enhance(3)
|
|
||||||
birght_img = np.array(birght_img)
|
|
||||||
res, points = detector.detectAndDecode(birght_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(birght_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_3_1(img, detector):
|
|
||||||
# # 亮度处理 + 对比度增强
|
|
||||||
|
|
||||||
birght_img = ImageEnhance.Brightness(img)
|
|
||||||
birght_img = birght_img.enhance(2)
|
|
||||||
contrast_img = ImageEnhance.Contrast(birght_img)
|
|
||||||
contrast_img = contrast_img.enhance(1.5)
|
|
||||||
contrast_img = np.array(contrast_img)
|
|
||||||
res, points = detector.detectAndDecode(contrast_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(contrast_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(contrast_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_3_2(img, detector):
|
|
||||||
# 亮度处理 + 对比度增强
|
|
||||||
|
|
||||||
birght_img = ImageEnhance.Brightness(img)
|
|
||||||
birght_img = birght_img.enhance(3)
|
|
||||||
contrast_img = ImageEnhance.Contrast(birght_img)
|
|
||||||
contrast_img = contrast_img.enhance(1.5)
|
|
||||||
contrast_img = np.array(contrast_img)
|
|
||||||
res, points = detector.detectAndDecode(contrast_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(contrast_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(contrast_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_4_1(img, detector):
|
|
||||||
# 亮度处理 + 对比度增强 + 锐化
|
|
||||||
birght_img = ImageEnhance.Brightness(img)
|
|
||||||
birght_img = birght_img.enhance(2)
|
|
||||||
contrast_img = ImageEnhance.Contrast(birght_img)
|
|
||||||
contrast_img = contrast_img.enhance(1.5)
|
|
||||||
sharpness_img = ImageEnhance.Sharpness(contrast_img)
|
|
||||||
sharpness_img = sharpness_img.enhance(1.5)
|
|
||||||
sharpness_img = np.array(sharpness_img)
|
|
||||||
res, points = detector.detectAndDecode(sharpness_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(sharpness_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(sharpness_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_4_2(img, detector):
|
|
||||||
# 亮度处理 + 对比度增强 + 锐化
|
|
||||||
birght_img = ImageEnhance.Brightness(img)
|
|
||||||
birght_img = birght_img.enhance(3)
|
|
||||||
contrast_img = ImageEnhance.Contrast(birght_img)
|
|
||||||
contrast_img = contrast_img.enhance(1.5)
|
|
||||||
sharpness_img = ImageEnhance.Sharpness(contrast_img)
|
|
||||||
sharpness_img = sharpness_img.enhance(1.5)
|
|
||||||
sharpness_img = np.array(sharpness_img)
|
|
||||||
res, points = detector.detectAndDecode(sharpness_img)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(sharpness_img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
samll_img = cv2.resize(np.array(sharpness_img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
return res, points[0] * 4
|
|
||||||
return res, points[0] * 1
|
|
||||||
|
|
||||||
|
|
||||||
def solution_5(img, detector):
|
|
||||||
# 缩放X4
|
|
||||||
samll_img = cv2.resize(np.array(img), None, fx=1 / 4, fy=1 / 4, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 4
|
|
||||||
|
|
||||||
|
|
||||||
def solution_6(img, detector):
|
|
||||||
# 缩放X8
|
|
||||||
samll_img = cv2.resize(np.array(img), None, fx=1 / 8, fy=1 / 8, interpolation=cv2.INTER_CUBIC)
|
|
||||||
input = np.array(samll_img)
|
|
||||||
res, points = detector.detectAndDecode(input)
|
|
||||||
if len(res) == 0:
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return res, points[0] * 8
|
|
||||||
|
|
||||||
|
|
||||||
def rotation_waffine(rotation_region, input):
|
|
||||||
# 亮度处理
|
|
||||||
birght_img = ImageEnhance.Brightness(rotation_region)
|
|
||||||
birght_img = birght_img.enhance(5)
|
|
||||||
|
|
||||||
# 灰度二值化
|
|
||||||
img_grey = cv2.cvtColor(np.array(birght_img), cv2.COLOR_BGR2GRAY)
|
|
||||||
ret2, thresh = cv2.threshold(img_grey, 0, 255,
|
|
||||||
cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
|
||||||
|
|
||||||
# 腐蚀
|
|
||||||
# OpenCV定义的结构矩形元素
|
|
||||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (50, 50))
|
|
||||||
eroded = cv2.erode(thresh, kernel)
|
|
||||||
|
|
||||||
# canny边缘检测
|
|
||||||
eroded_canny = cv2.Canny(eroded, 100, 300)
|
|
||||||
contours, hierarchy = cv2.findContours(eroded_canny, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
# canny_img = np.zeros((np.array(birght_img).shape[0], np.array(birght_img).shape[1], 3), np.uint8) + 255
|
|
||||||
# canny_img = cv2.drawContours(canny_img, contours, -1, (0, 255, 0), 3)
|
|
||||||
|
|
||||||
# 寻找最大元素
|
|
||||||
k = 0
|
|
||||||
index = 0
|
|
||||||
if len(contours) != 0:
|
|
||||||
for i in range(len(contours)):
|
|
||||||
j = contours[i].size
|
|
||||||
if j > k:
|
|
||||||
k = j
|
|
||||||
index = i
|
|
||||||
else:
|
|
||||||
return input
|
|
||||||
|
|
||||||
# 拟合旋转矩形
|
|
||||||
cnt = contours[index]
|
|
||||||
rect = cv2.minAreaRect(cnt)
|
|
||||||
angle = rect[2]
|
|
||||||
# box = cv2.boxPoints(rect)
|
|
||||||
# box = np.int0(box)
|
|
||||||
# rect_img = cv2.drawContours(canny_img, [box], 0, (0, 0, 255), 2)
|
|
||||||
|
|
||||||
# 根据角度差计算仿射矩阵
|
|
||||||
height, width, _ = np.array(input).shape
|
|
||||||
center = (width // 2, height // 2)
|
|
||||||
if angle != 0.0:
|
|
||||||
if angle > 45.0:
|
|
||||||
angle = angle - 90
|
|
||||||
# angle = angle - 90
|
|
||||||
# rotate page if not straight relative to QR code
|
|
||||||
M = cv2.getRotationMatrix2D(center, angle, 1.0)
|
|
||||||
output = cv2.warpAffine(np.array(input), M, (width, height), flags=cv2.INTER_CUBIC,
|
|
||||||
borderMode=cv2.BORDER_REPLICATE)
|
|
||||||
output = Image.fromarray(output)
|
|
||||||
|
|
||||||
output = output.convert("RGB")
|
|
||||||
else:
|
|
||||||
output = input
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def rotation_region_crop(input,points):
|
|
||||||
input = np.array(input)
|
|
||||||
x1 = int(points[0, 0])
|
|
||||||
y1 = int(points[0, 1])
|
|
||||||
x2 = int(points[1, 0])
|
|
||||||
y2 = int(points[1, 1])
|
|
||||||
x_a = x1 - int((x2 - x1) * 0.1)
|
|
||||||
x_b = x2 + int((x2 - x1) * 0.1)
|
|
||||||
y_a = y1 - int((y2 - y1) * 0.1)
|
|
||||||
y_b = y2 + int((y2 - y1) * 0.1)
|
|
||||||
if x_a < 0:
|
|
||||||
x_a = 0
|
|
||||||
if y_a < 0:
|
|
||||||
y_a = 0
|
|
||||||
if x_b >= input.shape[1]:
|
|
||||||
x_b = input.shape[1]
|
|
||||||
if y_b >= input.shape[0]:
|
|
||||||
y_b = input.shape[0]
|
|
||||||
|
|
||||||
top_size, bottom_size, left_size, right_size = (50, 50, 50, 50)
|
|
||||||
rotation_region = cv2.copyMakeBorder(input[y_a:y_b, x_a:x_b], top_size, bottom_size, left_size, right_size,
|
|
||||||
borderType=cv2.BORDER_REPLICATE, value=0)
|
|
||||||
return rotation_region
|
|
||||||
|
|
||||||
def qr_detetor(input, detector):
|
|
||||||
success = False
|
|
||||||
if not success:
|
|
||||||
res, points = solution_1_1(input, detector)
|
|
||||||
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input,points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region),input)
|
|
||||||
# img = input
|
|
||||||
# print('solution_1_1')
|
|
||||||
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_1_2(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
# print('solution_1_2')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_5(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_5')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_6(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_6')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_2_1(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_2_1')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_2_2(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_2_2')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_3_1(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_3_1')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_3_2(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_3_2')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_4_1(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_4_1')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
res, points = solution_4_2(input, detector)
|
|
||||||
if res == None:
|
|
||||||
success = False
|
|
||||||
else:
|
|
||||||
success = True
|
|
||||||
rotation_region = rotation_region_crop(input, points)
|
|
||||||
img = rotation_waffine(Image.fromarray(rotation_region), input)
|
|
||||||
# img = input
|
|
||||||
print('solution_4_2')
|
|
||||||
return res, points, img
|
|
||||||
|
|
||||||
if success is False:
|
|
||||||
return None, None, None
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : __init__.py.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2022/3/3 11:41 AM Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
@ -1,59 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : affinity.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2022/3/12 9:18 PM Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
def eightway_total_diff(arr):
|
|
||||||
# 对输入数组进行填充,每边填充1个元素
|
|
||||||
padded_arr = np.pad(arr, pad_width=1, mode='edge')
|
|
||||||
|
|
||||||
# 获取填充后数组的尺寸
|
|
||||||
h, w = padded_arr.shape
|
|
||||||
|
|
||||||
# 初始化输出数组,用于存储总差值
|
|
||||||
diff_array = np.zeros((h-2, w-2))
|
|
||||||
|
|
||||||
# 遍历每个元素(不包括填充的边界)
|
|
||||||
for y in range(1, h-1):
|
|
||||||
for x in range(1, w-1):
|
|
||||||
total_diff = 0 # 初始化当前元素的总差值
|
|
||||||
# 计算每个方向的差值
|
|
||||||
for dy in range(-1, 2):
|
|
||||||
for dx in range(-1, 2):
|
|
||||||
if dy == 0 and dx == 0:
|
|
||||||
# 排除中心点自身
|
|
||||||
continue
|
|
||||||
# 直接使用填充后的坐标计算差值
|
|
||||||
diff = abs(padded_arr[y, x] - padded_arr[y+dy, x+dx])
|
|
||||||
# 如果差值小于0,则置为0
|
|
||||||
total_diff += max(diff, 0)
|
|
||||||
# 将总差值存储在输出数组中
|
|
||||||
diff_array[y-1, x-1] = total_diff
|
|
||||||
return diff_array.astype(int)
|
|
||||||
|
|
||||||
def roi_affinity_siml(X_ter,X_str):
|
|
||||||
X_ter = np.array(X_ter)
|
|
||||||
X_str = np.array(X_str)
|
|
||||||
X_ter_diffs = eightway_total_diff(X_ter / X_ter.max())
|
|
||||||
X_str_diffs = eightway_total_diff(X_str / X_str.max())
|
|
||||||
# 计算差异
|
|
||||||
difference = cv2.absdiff(X_ter_diffs, X_str_diffs)
|
|
||||||
mean_diff = np.mean(difference)
|
|
||||||
# 计算相似度
|
|
||||||
similarity = max((1 - mean_diff) * 100, 0)
|
|
||||||
return similarity
|
|
||||||
@ -1,145 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : image_transform.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2024/3/5 11:33 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
from typing import Union, Optional, List, Tuple, Text, BinaryIO
|
|
||||||
import io
|
|
||||||
import pathlib
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
irange = range
|
|
||||||
|
|
||||||
|
|
||||||
def make_grid(
|
|
||||||
tensor: Union[torch.Tensor, List[torch.Tensor]],
|
|
||||||
nrow: int = 8,
|
|
||||||
padding: int = 2,
|
|
||||||
normalize: bool = False,
|
|
||||||
range: Optional[Tuple[int, int]] = None,
|
|
||||||
scale_each: bool = False,
|
|
||||||
pad_value: int = 0,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Make a grid of images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
|
|
||||||
or a list of images all of the same size.
|
|
||||||
nrow (int, optional): Number of images displayed in each row of the grid.
|
|
||||||
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
|
|
||||||
padding (int, optional): amount of padding. Default: ``2``.
|
|
||||||
normalize (bool, optional): If True, shift the image to the range (0, 1),
|
|
||||||
by the min and max values specified by :attr:`range`. Default: ``False``.
|
|
||||||
range (tuple, optional): tuple (min, max) where min and max are numbers,
|
|
||||||
then these numbers are used to normalize the image. By default, min and max
|
|
||||||
are computed from the tensor.
|
|
||||||
scale_each (bool, optional): If ``True``, scale each image in the batch of
|
|
||||||
images separately rather than the (min, max) over all images. Default: ``False``.
|
|
||||||
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not (torch.is_tensor(tensor) or
|
|
||||||
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
|
|
||||||
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
|
|
||||||
|
|
||||||
# if list of tensors, convert to a 4D mini-batch Tensor
|
|
||||||
if isinstance(tensor, list):
|
|
||||||
tensor = torch.stack(tensor, dim=0)
|
|
||||||
|
|
||||||
if tensor.dim() == 2: # single image H x W
|
|
||||||
tensor = tensor.unsqueeze(0)
|
|
||||||
if tensor.dim() == 3: # single image
|
|
||||||
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
|
|
||||||
tensor = torch.cat((tensor, tensor, tensor), 0)
|
|
||||||
tensor = tensor.unsqueeze(0)
|
|
||||||
|
|
||||||
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
|
|
||||||
tensor = torch.cat((tensor, tensor, tensor), 1)
|
|
||||||
|
|
||||||
if normalize is True:
|
|
||||||
tensor = tensor.clone() # avoid modifying tensor in-place
|
|
||||||
if range is not None:
|
|
||||||
assert isinstance(range, tuple), \
|
|
||||||
"range has to be a tuple (min, max) if specified. min and max are numbers"
|
|
||||||
|
|
||||||
def norm_ip(img, min, max):
|
|
||||||
img.clamp_(min=min, max=max)
|
|
||||||
img.add_(-min).div_(max - min + 1e-5)
|
|
||||||
|
|
||||||
def norm_range(t, range):
|
|
||||||
if range is not None:
|
|
||||||
norm_ip(t, range[0], range[1])
|
|
||||||
else:
|
|
||||||
norm_ip(t, float(t.min()), float(t.max()))
|
|
||||||
|
|
||||||
if scale_each is True:
|
|
||||||
for t in tensor: # loop over mini-batch dimension
|
|
||||||
norm_range(t, range)
|
|
||||||
else:
|
|
||||||
norm_range(tensor, range)
|
|
||||||
|
|
||||||
if tensor.size(0) == 1:
|
|
||||||
return tensor.squeeze(0)
|
|
||||||
|
|
||||||
# make the mini-batch of images into a grid
|
|
||||||
nmaps = tensor.size(0)
|
|
||||||
xmaps = min(nrow, nmaps)
|
|
||||||
ymaps = int(math.ceil(float(nmaps) / xmaps))
|
|
||||||
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
|
|
||||||
num_channels = tensor.size(1)
|
|
||||||
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
|
|
||||||
k = 0
|
|
||||||
for y in irange(ymaps):
|
|
||||||
for x in irange(xmaps):
|
|
||||||
if k >= nmaps:
|
|
||||||
break
|
|
||||||
# Tensor.copy_() is a valid method but seems to be missing from the stubs
|
|
||||||
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
|
|
||||||
grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
|
|
||||||
2, x * width + padding, width - padding
|
|
||||||
).copy_(tensor[k])
|
|
||||||
k = k + 1
|
|
||||||
return grid
|
|
||||||
|
|
||||||
|
|
||||||
def image_tf(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    format: Optional[str] = None,
) -> "Image.Image":
    """Render a (mini-batch of) tensor(s) as a single PIL image.

    Adapted from ``torchvision.utils.save_image``, but nothing is written
    to disk: the tiled grid is returned as a ``PIL.Image`` instead.
    (The original annotation claimed ``-> None`` and the docstring claimed
    the tensor was saved to a file; both were wrong.)

    Args:
        tensor (Tensor or list): Image(s); a mini-batch is tiled into a
            grid by calling ``make_grid``.
        nrow, padding, normalize, range, scale_each, pad_value: Forwarded
            to ``make_grid`` (see its documentation).
        format: Unused; kept only for signature compatibility with
            ``save_image``.

    Returns:
        PIL.Image.Image: The rendered image grid.
    """
    from PIL import Image
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Unnormalize to [0, 255]; add 0.5 so truncation in the uint8
    # conversion rounds to the nearest integer.
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    return im
|
|
||||||
@ -1,111 +0,0 @@
|
|||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.utils import save_image
|
|
||||||
from qr_verify_Tool.image_transform import *
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
def weights_init_normal(m):
    """DCGAN-style initialiser for ``Module.apply``.

    Conv layers get weights ~ N(0, 0.02); BatchNorm2d layers get weights
    ~ N(1, 0.02) and zero bias.  Other modules are left untouched.
    """
    layer_name = m.__class__.__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
##############################
|
|
||||||
# U-NET
|
|
||||||
##############################
|
|
||||||
|
|
||||||
|
|
||||||
class UNetDown(nn.Module):
    """One U-Net encoder step.

    Stride-2 4x4 conv (halves H and W), optional InstanceNorm,
    LeakyReLU(0.2), optional Dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        modules = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            modules += [nn.InstanceNorm2d(out_size)]
        modules += [nn.LeakyReLU(0.2)]
        if dropout:
            modules += [nn.Dropout(dropout)]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Downsample the feature map by a factor of 2."""
        return self.model(x)
|
|
||||||
|
|
||||||
|
|
||||||
class UNetUp(nn.Module):
    """One U-Net decoder step.

    Stride-2 transposed conv (doubles H and W), InstanceNorm, ReLU,
    optional Dropout; the result is concatenated with the matching
    encoder feature map (skip connection) along the channel axis.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        modules = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            modules += [nn.Dropout(dropout)]

        self.model = nn.Sequential(*modules)

    def forward(self, x, skip_input):
        """Upsample x and concatenate the skip connection (dim 1)."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), 1)
|
|
||||||
|
|
||||||
|
|
||||||
class GeneratorUNet(nn.Module):
    """Pix2pix-style U-Net generator (8 encoder / 7 decoder steps).

    NOTE(review): unlike a conventional nn.Module, forward() expects a
    single PIL image — it resizes/normalises it internally, adds a batch
    dimension, and converts the network output back to a PIL image via
    the sibling helper image_tf().  As written it is inference-only;
    confirm no caller passes a tensor batch.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        # Encoder: each step halves the spatial resolution (256 -> 1).
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        # Decoder: input channels double (except up1) because each step's
        # output is concatenated with an encoder skip connection.
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # Final upsample + conv back to image space; Tanh -> [-1, 1].
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate one PIL image and return the result as a PIL image."""
        # Resize to 256x256, convert to tensor, normalise each channel to
        # roughly [-1, 1].
        transform_init = transforms.Compose([transforms.Resize((256, 256), Image.NEAREST),transforms.ToTensor(),
                                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
        x_init = transform_init(x)
        x_init = x_init.unsqueeze(0)  # add a batch dimension
        # Encoder pass; intermediate maps are kept for the skips below.
        d1 = self.down1(x_init)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        # Decoder pass with skip connections d7..d1.
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        x_final = self.final(u7)

        # Convert the [-1, 1] tensor grid back to a PIL image.
        x_final = image_tf(x_final,normalize=True)
        return x_final
|
|
||||||
@ -1,205 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torchvision.models import resnet18
|
|
||||||
|
|
||||||
class DIQA(nn.Module):
    """Siamese CNN for document-image quality comparison.

    Both inputs are passed through the same convolutional branch
    (``forward_once``) and come out as L2-normalised 64-d embeddings;
    similarity is computed by the caller from the two embeddings.
    """

    def __init__(self):
        super(DIQA, self).__init__()

        # VGG-style stack; conv2 and conv4 use stride 2, so the spatial
        # resolution is reduced by a factor of 4 overall.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        # self.conv9 = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)

        # Embedding head: 128 (from global average pooling) -> 128 -> 64.
        self.fc1 = nn.Linear(128, 128)
        # self.fc1 = nn.Linear(100352, 128)
        self.fc2 = nn.Linear(128, 64)

    # Layers are declared above; the actual wiring is below.
    def forward_once(self, input):
        """Embed one input batch as L2-normalised 64-d vectors."""
        # Collapse any leading batch-of-patches structure into one batch.
        x = input.view(-1, input[0].size(-3), input[0].size(-2), input[0].size(-1))

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        conv8 = F.relu(self.conv8(x))  # e.g. (12, 128, 28, 28)

        # Global average pooling: one value per channel.
        q = torch.nn.functional.adaptive_avg_pool2d(conv8, (1, 1))  # (N, 128, 1, 1)
        q = q.squeeze(3).squeeze(2)
        # q = conv8.view(conv8.size(0), -1)
        q = self.fc1(q)
        q = self.fc2(q)
        # Project onto the unit hypersphere so distances are cosine-like.
        q = F.normalize(q, p=2, dim=1)

        return q

    def forward(self, input1, input2):
        """Return the embeddings of the two inputs as a pair."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
|
|
||||||
|
|
||||||
class SiameseNetwork(nn.Module):
    """Siamese CNN producing 256-d L2-normalised embeddings.

    NOTE(review): cnn1 takes 1-channel input and the flattened feature
    size is hard-wired to 20736, which fixes the expected input spatial
    size — confirm against the training pipeline.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Convolutional trunk shared by both branches.
        self.cnn1 = nn.Sequential(
            nn.Conv2d(1,32,5),
            nn.ReLU(True),
            nn.MaxPool2d(2,2),
            # Currently 45x55x32
            nn.Conv2d(32,64,3),
            nn.ReLU(True),
            nn.MaxPool2d(2,2),
            #Currently 21x26x64
            nn.Conv2d(64,64,3),
            nn.ReLU(True),
            nn.MaxPool2d(2,2),
            # #Currently 9x12x64
            # nn.Conv2d(64, 64, 3),
            # nn.BatchNorm2d(64),
            # nn.ReLU(True),
            # nn.MaxPool2d(2, 2),
            # nn.Conv2d(64, 64, 3),
            # nn.BatchNorm2d(64),
            # nn.ReLU(True)
            # nn.MaxPool2d(2, 2),
        )

        # Embedding head: flattened conv features -> 4096 -> 256.
        self.fc1 = nn.Sequential(
            nn.Linear(20736 , 4096),
            # nn.Sigmoid(),
            # nn.Dropout(0.5,False),
            nn.Linear(4096,256))
        # self.out = nn.Linear(4096,1)

    def forward_once(self, x):
        """Embed one batch as L2-normalised 256-d vectors."""
        output = self.cnn1(x)
        output = output.view(-1,20736)
        output = self.fc1(output)
        output = F.normalize(output, p=2, dim=1)
        return output

    def forward(self, input1,input2):
        """Return the embeddings of the two inputs as a pair."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        # out = self.out(torch.abs(output1-output2))
        # return out.view(out.size())
        return output1,output2
|
|
||||||
|
|
||||||
class FaceModel(nn.Module):
    """ResNet-18 backbone with a custom embedding + classifier head.

    forward() returns the alpha-scaled, L2-normalised embedding; use
    forward_classifier() to get class logits on top of it.

    NOTE(review): model.fc expects 512*8*8 flattened features, which
    fixes the expected input spatial size (256x256 given ResNet-18's
    32x downsampling) — confirm against the data pipeline.
    """

    def __init__(self,embedding_size,num_classes,pretrained=False):
        """
        Args:
            embedding_size: Dimension of the output embedding.
            num_classes: Number of identity classes for the classifier.
            pretrained: Whether to load ImageNet weights for the backbone.
        """
        super(FaceModel, self).__init__()

        self.model = resnet18(pretrained)

        self.embedding_size = embedding_size

        # Replace the stock fc so the flattened layer4 output maps to the
        # embedding dimension.
        self.model.fc = nn.Linear(512*8*8, self.embedding_size)

        self.model.classifier = nn.Linear(self.embedding_size, num_classes)


    def l2_norm(self,input):
        """L2-normalise along dim 1, preserving the input's shape."""
        input_size = input.size()
        buffer = torch.pow(input, 2)

        # Small epsilon keeps the division safe for all-zero rows.
        normp = torch.sum(buffer, 1).add_(1e-10)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output

    def forward_once(self, x):
        """Run the backbone manually (skipping avgpool) and embed."""
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.model.fc(x)
        # Side effect: the scaled embedding is cached on self.features.
        self.features = self.l2_norm(x)
        # Multiply by alpha = 10 as suggested in https://arxiv.org/pdf/1703.09507.pdf
        alpha=10
        self.features = self.features*alpha

        #x = self.model.classifier(self.features)
        return self.features
    def forward(self, input):
        """Return the scaled, normalised embedding of the input batch."""
        output = self.forward_once(input)
        # output2 = self.forward_once(input2)
        # out = self.out(torch.abs(output1-output2))
        # return out.view(out.size())
        return output

    def forward_classifier(self, x):
        """Return class logits computed from the embedding."""
        features = self.forward(x)
        res = self.model.classifier(features)
        return res
|
|
||||||
|
|
||||||
class SignaturesNetwork(nn.Module):
    """AlexNet-style siamese network for signature comparison.

    Both inputs share one CNN branch and come out as 128-d embeddings
    (not normalised here).

    NOTE(review): fc1 expects 952576 flattened features, which fixes the
    expected input spatial size — confirm against the data pipeline.
    """
    def __init__(self):
        super(SignaturesNetwork, self).__init__()
        # Setting up the Sequential of CNN Layers
        self.cnn1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11,stride=1),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5,alpha=0.0001,beta=0.75,k=2),
            nn.MaxPool2d(3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5,stride=1,padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5,alpha=0.0001,beta=0.75,k=2),
            nn.MaxPool2d(3, stride=2),
            nn.Dropout2d(p=0.3),

            nn.Conv2d(256,384 , kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,256, kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
            nn.Dropout2d(p=0.3),
        )

        # Defining the fully connected layers
        self.fc1 = nn.Sequential(
            # First Dense Layer
            nn.Linear(952576, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            # Second Dense Layer
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            # # Final Dense Layer
            nn.Linear(128,128))

    def forward_once(self, x):
        """Embed one batch as 128-d vectors."""
        # Forward pass
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Return the embeddings of the two inputs as a pair."""
        # forward pass of input 1
        output1 = self.forward_once(input1)
        # forward pass of input 2
        output2 = self.forward_once(input2)
        # returning the feature vectors of two inputs
        return output1, output2
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : qr_orb.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time      @Author    @Version    @Description
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2023/8/26 18:58 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def preprocess_image(image):
    """Convert a BGR image to grayscale and resize it to 275x275.

    Args:
        image: BGR image as produced by cv2.imread.

    Returns:
        np.ndarray: 275x275 single-channel image.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return cv2.resize(gray, (275, 275), interpolation=cv2.INTER_NEAREST)
|
|
||||||
|
|
||||||
def FAST_corner_detection(image, threshold):
    """Simplified FAST corner detector.

    A pixel is a corner when all four axis-aligned neighbours at radius 3
    differ from it by more than ``threshold`` grey levels.

    Args:
        image: 2-D grayscale image (typically uint8).
        threshold: Minimum absolute intensity difference.

    Returns:
        list[cv2.KeyPoint]: Detected corners (keypoint size 7).
    """
    keypoints = []
    for i in range(3, image.shape[0] - 3):
        for j in range(3, image.shape[1] - 3):
            # Cast to int before subtracting: uint8 arithmetic wraps
            # around (e.g. 10 - 200 == 66), which silently corrupted the
            # difference in the original code.
            center_pixel = int(image[i, j])
            pixel_difference = [abs(int(image[i + dx, j + dy]) - center_pixel) for dx, dy in
                                [(0, -3), (0, 3), (-3, 0), (3, 0)]]
            if all(diff > threshold for diff in pixel_difference):
                keypoints.append(cv2.KeyPoint(j, i, 7))
    return keypoints
|
|
||||||
|
|
||||||
|
|
||||||
def BRIEF_descriptor(keypoints, image):
    """Compute random-pair BRIEF-style binary descriptors as bit strings.

    For each keypoint, patch_size*patch_size random pixel pairs inside a
    31x31 patch are compared; each comparison contributes one bit.

    NOTE(review): np.random is not seeded here, so descriptors are not
    reproducible across runs; keypoints closer than 15px to the image
    border also produce truncated patches — confirm both are acceptable.

    Args:
        keypoints: Iterable of cv2.KeyPoint.
        image: Grayscale image the keypoints refer to.

    Returns:
        list[str]: One '0'/'1' bit string per keypoint.
    """
    patch_size = 31
    descriptors = []
    for kp in keypoints:
        x, y = int(kp.pt[0]), int(kp.pt[1])
        # 31x31 patch centred on the keypoint.
        patch = image[y - patch_size // 2:y + patch_size // 2 + 1, x - patch_size // 2:x + patch_size // 2 + 1]
        descriptor = ""
        for i in range(patch_size * patch_size):
            x1, y1 = np.random.randint(0, patch_size, size=2)
            x2, y2 = np.random.randint(0, patch_size, size=2)
            if patch[y1, x1] < patch[y2, x2]:
                descriptor += "0"
            else:
                descriptor += "1"
        descriptors.append(descriptor)
    return descriptors
|
|
||||||
|
|
||||||
|
|
||||||
def ORB(image, nfeatures=500, fastThreshold=20):
    """Hand-rolled ORB: FAST corners plus random-pair BRIEF descriptors.

    Args:
        image: BGR input image.
        nfeatures: Maximum number of keypoints kept (ranked by response).
        fastThreshold: Intensity threshold for the FAST detector.

    Returns:
        tuple: (keypoints, descriptors).
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    corners = FAST_corner_detection(gray, fastThreshold)
    corners = sorted(corners, key=lambda kp: -kp.response)[:nfeatures]

    descriptors = BRIEF_descriptor(corners, gray)

    return corners, descriptors
|
|
||||||
|
|
||||||
def ORB_detect(img, nfeatures=800, fastThreshold=20):
    """Extract ORB descriptors from an image with OpenCV's native ORB.

    Args:
        img: BGR input image.
        nfeatures: Maximum number of features for the detector.
        fastThreshold: FAST threshold passed to the detector.

    Returns:
        Descriptor matrix (np.ndarray), or None when nothing is detected.
    """
    prepared = preprocess_image(img)
    # Native OpenCV ORB detector.
    detector = cv2.ORB_create(nfeatures=nfeatures, scaleFactor=1.2, edgeThreshold=31, patchSize=31, fastThreshold=fastThreshold)

    _, descriptors = detector.detectAndCompute(prepared, None)
    return descriptors
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
必要传参:
|
|
||||||
std_roi_feature: 服务器备案特征数据
|
|
||||||
ter_roi_feature: 手机终端特征数据
|
|
||||||
threshold: 相似度对比阈值
|
|
||||||
预留传参:
|
|
||||||
distance: 距离筛选度量,默认值100
|
|
||||||
'''
|
|
||||||
|
|
||||||
def roi_siml(std_roi_feature, ter_roi_feature, distance=100, threshold=None):
    """Compare two ORB descriptor sets and decide pass/fail.

    Args:
        std_roi_feature: Reference descriptors registered on the server.
        ter_roi_feature: Descriptors extracted on the handset.
        distance: Absolute Hamming-distance cut-off for a good match.
        threshold: Similarity percentage (0-100) required to pass.

    Returns:
        tuple[str, float]: ('passed', similarity) or ('failed ', similarity).
    """
    std_roi_feature = std_roi_feature.astype('uint8')
    ter_roi_feature = ter_roi_feature.astype('uint8')
    threshold = float(threshold)

    # Brute-force matcher with Hamming distance (binary ORB descriptors).
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    # k-NN matching: each query descriptor gets its 2 nearest neighbours.
    matches = bf.knnMatch(std_roi_feature, trainDescriptors=ter_roi_feature, k=2)

    # Keep only unambiguous matches (Lowe-style ratio test) within the
    # absolute distance cut-off.  The original rebuilt the whole filtered
    # list once per loop iteration (O(n^2)) and its comprehension unpacked
    # (m, n) over ALL matches, crashing on any match with fewer than two
    # neighbours despite the length guard.  Filter each pair exactly once.
    good = []
    for match in matches:
        if len(match) >= 2:
            m, n = match[0], match[1]
            if m.distance < 0.95 * n.distance and m.distance < distance:
                good.append((m, n))

    # Guard against an empty match list to avoid ZeroDivisionError.
    similarity = (len(good) / len(matches) * 100) if matches else 0.0
    if similarity >= threshold:
        return 'passed', similarity
    else:
        return 'failed ', similarity
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,42 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : qr_verify.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time      @Author    @Version    @Description
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2022/3/3 11:38 AM Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from qr_verify_Tool.network import FaceModel
|
|
||||||
|
|
||||||
|
|
||||||
class QR_verify_dev(nn.Module):
    """Inference wrapper that loads a trained FaceModel checkpoint.

    The checkpoint path is hard-coded and resolved relative to the
    current working directory; loading happens once in __init__ and the
    model is switched to eval mode.
    """

    def __init__(self):
        super(QR_verify_dev, self).__init__()
        if torch.cuda.is_available():
            print('cuda available:{}'.format(torch.cuda.is_available()))
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # map_location keeps CPU-only hosts working with GPU-trained weights.
        self.checkpoint = torch.load('./model/train_qr_varification_FaceModel_3qp_256_Affine_lr1e-3_far-1_80.pth',
                                     map_location=self.device)
        self.model = FaceModel(256, num_classes=640, pretrained=True)
        self.model.load_state_dict(self.checkpoint['model'])
        self.model.eval()

    def forward(self,input):
        """Return the FaceModel embedding for an already-prepared batch."""
        # input = torch.from_numpy(np.array(input).transpose((2, 0, 1))).unsqueeze(0)
        # input = input.cuda().float()
        output = self.model(input)

        return output
|
|
||||||
@ -1,207 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : roi_img_process.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time      @Author    @Version    @Description
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2024/1/7 22:57 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
def preprocess_image(image,size,blur):
    """Optionally resize, grayscale-convert, and blur an image.

    Args:
        image: Input image.
        size: Target square side length, or -1 to skip resizing and
            instead convert BGR -> grayscale.
        blur: When truthy, apply a 3x3 Gaussian blur (sigmaX=1, sigmaY=0).

    Returns:
        The processed image.
    """
    # Resize to a fixed square when a size is given.
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    if size !=-1:
        image = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
    if size == -1:
        # NOTE(review): this branch assumes the input is BGR; confirm
        # callers never pass a grayscale image with size == -1.
        image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2GRAY)
    if blur:
        kernel_size = 3
        image = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigmaX=1, sigmaY=0)
    return image
|
|
||||||
def clos_open(input,B_p):
    """Binarise, then apply morphological opening followed by closing.

    Args:
        input: Grayscale image.
        B_p: Binarisation threshold (0-255).

    Returns:
        np.ndarray: Denoised binary image.
    """
    _, binary_image = cv2.threshold(input, B_p, 255, cv2.THRESH_BINARY)
    # Structuring element (kernel).
    kernel = np.ones((3, 3), np.uint8)

    # Opening removes small white specks.
    opening = cv2.morphologyEx(binary_image, cv2.MORPH_OPEN, kernel)

    # Structuring element (kernel).
    kernel = np.ones((3, 3), np.uint8)
    # Closing fills small black holes.
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)

    return closing
|
|
||||||
|
|
||||||
|
|
||||||
def img_angle(input):
    """Find the dominant rectangular contour and its rotation angle.

    Args:
        input: Binarised grayscale image.

    Returns:
        tuple: (angle, ordered_box) — the cv2.minAreaRect angle of the
        largest contour and its 4 integer corner points ordered top-left,
        top-right, bottom-right, bottom-left.

    Raises:
        ValueError: If no contour is found in the image.
    """
    # Canny edge detection feeding external-contour extraction.
    eroded_canny = cv2.Canny(input, 50, 255, apertureSize=3)
    contours, hierarchy = cv2.findContours(eroded_canny, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Fail loudly: the original fell through to contours[0] on an empty
    # sequence and raised a bare IndexError.
    if len(contours) == 0:
        raise ValueError("img_angle: no contours found in input image")

    # Pick the contour with the most points as the dominant one.
    index = max(range(len(contours)), key=lambda i: contours[i].size)

    # Fit a rotated rectangle around it.
    cnt = contours[index]
    rect = cv2.minAreaRect(cnt)
    angle = rect[2]
    box = cv2.boxPoints(rect)

    # np.int0 was removed in NumPy 2.0; astype(np.intp) is the documented
    # equivalent.
    box = box.astype(np.intp)

    # Sort corners by y (top pair first, bottom pair last).
    box = box[box[:, 1].argsort()]

    # Sort the top pair by x to separate top-left from top-right.
    top_two = box[:2]
    top_two = top_two[top_two[:, 0].argsort()]

    # Sort the bottom pair by x to separate bottom-left from bottom-right.
    bottom_two = box[2:]
    bottom_two = bottom_two[bottom_two[:, 0].argsort()]

    # Reassemble as top-left, top-right, bottom-right, bottom-left.
    ordered_box = np.array([top_two[0], top_two[1], bottom_two[1], bottom_two[0]])

    return angle,ordered_box
|
|
||||||
|
|
||||||
def angel_affine(input, angle):
    """Rotate an image around its centre to undo skew.

    Args:
        input: 2-D grayscale image (array-like).
        angle: Rotation angle in degrees (as reported by
            cv2.minAreaRect), or None.

    Returns:
        np.ndarray: The rotated image — or the input unchanged (as an
        ndarray) when angle is None or 0.0.
    """
    input = np.array(input)
    height, width = input.shape
    center = (width // 2, height // 2)
    # Fix: the original left `output` unbound when no rotation was needed,
    # raising UnboundLocalError; fall back to the unrotated image instead.
    output = input
    if angle is not None and angle != 0.0:
        # minAreaRect reports angles in (0, 90]; fold anything past 45°
        # back so we rotate by the smaller equivalent amount.
        if angle > 45.0:
            angle = angle - 90
        # rotate page if not straight relative to QR code
        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        output = cv2.warpAffine(input, M, (width, height), flags=cv2.INTER_CUBIC,borderMode=cv2.BORDER_REPLICATE)

    return output
|
|
||||||
|
|
||||||
|
|
||||||
def merge_overlapping_rectangles(rectangles):
    """Greedily keep the largest non-overlapping rectangles.

    Rectangles are (x, y, w, h) tuples, visited in decreasing area
    order; a rectangle is kept only when it overlaps none of the
    already-kept ones (edge-touching does not count as overlap).

    Args:
        rectangles: Sequence of (x, y, w, h) tuples.

    Returns:
        list: The kept rectangles, largest-area first (or the input
        unchanged when it has at most one element).
    """
    if len(rectangles) <= 1:
        return rectangles

    def _overlaps(a, b):
        ax, ay, aw, ah = a
        bx, by, bw, bh = b
        return ax < bx + bw and ax + aw > bx and ay < by + bh and ay + ah > by

    kept = []
    for candidate in sorted(rectangles, key=lambda r: r[2] * r[3], reverse=True):
        if all(not _overlaps(candidate, other) for other in kept):
            kept.append(candidate)
    return kept
|
|
||||||
|
|
||||||
|
|
||||||
def roi2_detector(input,min_area_threshold):
    """Find the largest square-ish rectangle in a binary image.

    Contours of the Canny edges are filtered by aspect ratio (~square),
    contour area, and how well the contour fills its bounding box; the
    overlapping survivors are merged and the largest one is returned.

    Args:
        input: Binary (thresholded) grayscale image.
        min_area_threshold: Minimum contour area for a candidate.

    Returns:
        (x, y, w, h) of the largest surviving rectangle, or None when no
        candidate passes the filters.
    """
    # Canny edge detection.
    edges = cv2.Canny(input, 50, 255, apertureSize=3)

    # Find contours.
    contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Candidate rectangles.
    temp_rectangles = []

    # Identify rectangular contours.
    for contour in contours:
        # Polygon approximation (very tight epsilon keeps most points).
        epsilon = 0.001 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)

        if len(approx) >= 4:
            x, y, w, h = cv2.boundingRect(approx)
            aspect_ratio = w / float(h)
            area_contour = cv2.contourArea(contour)
            # Keep near-square contours within the allowed area band.
            if aspect_ratio >= 0.75 and aspect_ratio < 1.2 and min_area_threshold < area_contour < 15000:
                area_bounding_rect = w * h
                area_ratio = area_contour / area_bounding_rect
                # The contour must fill most of its bounding box.
                if 0.75 < area_ratio < 1.05:
                    temp_rectangles.append((x, y, w, h))

    # Merge overlapping boxes.
    merged_rectangles = merge_overlapping_rectangles(temp_rectangles)

    # Pick and return the largest surviving box.
    if merged_rectangles:
        outermost_rectangle = max(merged_rectangles, key=lambda x: x[2]*x[3])
        rectangle_count = 1  # NOTE(review): unused local, kept as-is
        return outermost_rectangle
    else:
        return None
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
def roi_img_processing(input):
    """Locate the ROI square in an image and return a 32x32 crop.

    Binarisation threshold 50 is tried first, then 100 as a fallback.

    Args:
        input: BGR image (array-like or PIL image).

    Returns:
        Blurred 32x32 grayscale crop (np.ndarray), or None when no ROI
        rectangle is found at either threshold.
    """
    input_ray = cv2.cvtColor(np.array(input), cv2.COLOR_BGR2GRAY)
    img_resize = preprocess_image(input_ray, 128, blur=False)
    img_closing = clos_open(img_resize,B_p=50)  # open/close morphology
    min_area_threshold=6500
    roi_coordinate = roi2_detector(img_closing,min_area_threshold)
    if roi_coordinate is None:
        # Retry with a higher binarisation threshold.
        img_closing = clos_open(img_resize, B_p=100)  # open/close morphology
        roi_coordinate = roi2_detector(img_closing,min_area_threshold)
        if roi_coordinate is None:
            return None
    x, y, w, h = roi_coordinate
    # Crop 1px inside the detected rectangle.
    input_fix = img_resize[y+1:y-1 + h, x+1:x-1 + w]
    output = preprocess_image(input_fix, 32, blur=True)

    return output
|
|
||||||
|
|
||||||
def roi_detect(input):
    """Locate the ROI square in a BGR image and return it as a PIL crop.

    The image is resized to 128x128 and detection is retried over a
    cascade of binarisation thresholds (50, 100, 150, 120, 75).

    NOTE(review): every fallback passes the *colour* img_resize (not the
    grayscale image_gay) to clos_open, unlike the first attempt —
    confirm this is intentional.

    Args:
        input: BGR image as an HxWx3 array.

    Returns:
        PIL.Image.Image with the ROI crop, or None when nothing is found.
    """
    size = 128
    height, width, _ = input.shape

    # AREA interpolation when upscaling small inputs, LINEAR otherwise.
    if height < size:
        img_resize = cv2.resize(np.array(input), (size, size), interpolation=cv2.INTER_AREA)
    else:
        img_resize = cv2.resize(np.array(input), (size, size), interpolation=cv2.INTER_LINEAR)
    image_gay = cv2.cvtColor(img_resize, cv2.COLOR_BGR2GRAY)
    min_area_threshold = 1000
    img_closing = clos_open(image_gay, B_p=50)  # open/close morphology
    roi_coordinate = roi2_detector(img_closing,min_area_threshold)

    if roi_coordinate is None :
        img_closing = clos_open(img_resize, B_p=100)  # open/close morphology
        roi_coordinate = roi2_detector(img_closing,min_area_threshold)
        # print('150')
    if roi_coordinate is None :
        img_closing = clos_open(img_resize, B_p=150)  # open/close morphology
        roi_coordinate = roi2_detector(img_closing,min_area_threshold)
        # print('100')
    if roi_coordinate is None :
        img_closing = clos_open(img_resize, B_p=120)  # open/close morphology
        roi_coordinate = roi2_detector(img_closing,min_area_threshold)
    if roi_coordinate is None:
        img_closing = clos_open(img_resize, B_p=75)  # open/close morphology
        roi_coordinate = roi2_detector(img_closing,min_area_threshold)
    if roi_coordinate is None:
        return None
    x, y, w, h = roi_coordinate
    # Crop 1px inside the detected rectangle.
    input_fix = img_resize[y + 1:y - 1 + h, x + 1:x - 1 + w]
    input_fix = Image.fromarray(np.uint8(input_fix))
    return input_fix
|
|
||||||
def roi_minAreaRect(input):
    """Return the 4 ordered corner points of the dominant rectangle.

    Args:
        input: Grayscale image.

    Returns:
        np.ndarray: Corner points ordered TL, TR, BR, BL (from img_angle).
    """
    # Fix: clos_open() requires a binarisation threshold; the original
    # call omitted it and raised TypeError.  B_p=50 matches the
    # first-pass threshold used by the other callers in this module.
    img_closing = clos_open(input, B_p=50)  # open/close morphology
    _, box = img_angle(img_closing)  # orientation of the dominant contour
    return box
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,9 +0,0 @@
|
|||||||
Flask==3.0.2
|
|
||||||
basicsr==1.3.5
|
|
||||||
matplotlib==3.5.1
|
|
||||||
opencv-contrib-python==4.5.5.62
|
|
||||||
opencv-python==4.5.5.62
|
|
||||||
timm==0.9.2
|
|
||||||
pandas
|
|
||||||
seaborn
|
|
||||||
oss2
|
|
||||||
@ -1,145 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : roi_fix.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time      @Author    @Version    @Description
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2024/2/27 00:02 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
'''批量固定尺寸'''
|
|
||||||
import os
|
|
||||||
import cv2
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def merge_overlapping_rectangles(rectangles):
    """Greedily keep the largest non-overlapping rectangles.

    Rectangles are (x, y, w, h) tuples, visited in decreasing area
    order; a rectangle is kept only when it overlaps none of the
    already-kept ones (edge-touching does not count as overlap).

    Args:
        rectangles: Sequence of (x, y, w, h) tuples.

    Returns:
        list: The kept rectangles, largest-area first (or the input
        unchanged when it has at most one element).
    """
    if len(rectangles) <= 1:
        return rectangles

    def _overlaps(a, b):
        ax, ay, aw, ah = a
        bx, by, bw, bh = b
        return ax < bx + bw and ax + aw > bx and ay < by + bh and ay + ah > by

    kept = []
    for candidate in sorted(rectangles, key=lambda r: r[2] * r[3], reverse=True):
        if all(not _overlaps(candidate, other) for other in kept):
            kept.append(candidate)
    return kept
|
|
||||||
|
|
||||||
def roi_Inner_detector(input,):
    """Find candidate inner-ROI rectangles in a binarised image.

    Opening + closing denoise the binary image; contours of the Canny
    edges are filtered to near-square shapes inside a fixed area band,
    then overlapping candidates are merged.

    Args:
        input: Binary (already thresholded) grayscale image.

    Returns:
        list: Merged (x, y, w, h) candidate rectangles (may be empty).
    """
    # Structuring element (kernel).
    kernel = np.ones((3, 3), np.uint8)

    # Opening removes small white specks.
    opening = cv2.morphologyEx(input, cv2.MORPH_OPEN, kernel)

    # Structuring element (kernel).
    kernel = np.ones((3, 3), np.uint8)
    # Closing fills small black holes.
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)

    # Canny edge detection.
    edges = cv2.Canny(closing, 50, 255, apertureSize=3)

    # Find contours.
    contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Candidate rectangles.
    temp_rectangles = []

    # Identify rectangular contours.
    for contour in contours:
        # Polygon approximation (very tight epsilon keeps most points).
        epsilon = 0.001 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)

        if len(approx) >= 4:
            x, y, w, h = cv2.boundingRect(approx)
            aspect_ratio = w / float(h)
            area_contour = cv2.contourArea(contour)
            min_area_threshold = 0  # NOTE(review): unused local, kept as-is
            # print('aspect_ratio',aspect_ratio)
            # print('area_contour',area_contour)
            # Keep near-square contours within the fixed area band.
            if aspect_ratio >= 0.80 and aspect_ratio < 1.2 and 6500 < area_contour < 13000:
                area_bounding_rect = w * h
                area_ratio = area_contour / area_bounding_rect
                # print('aspect_ratio', area_ratio)
                # print('area_ratio', area_ratio)
                # print('area_contour', area_contour)
                # The contour must fill most of its bounding box.
                if 0.80 < area_ratio < 1.05:
                    temp_rectangles.append((x, y, w, h))

    # Merge overlapping boxes.
    merged_rectangles = merge_overlapping_rectangles(temp_rectangles)
    return merged_rectangles
|
|
||||||
|
|
||||||
input_dir = '/project/dataset/QR2024/org100/orange-0-roi-fix'
output_dir = '/project/dataset/QR2024/org100/orange-0-roi-fix2fix'

# Target side length for the intermediate resize.
size = 128

# Create the output directory.
os.makedirs(output_dir, exist_ok=True)

# Binarisation thresholds tried in order until a rectangle is found —
# the same sequence the original code implemented as seven copy-pasted
# retry blocks.
THRESHOLDS = (50, 150, 100, 125, 75, 85, 115)

# Iterate over the images in the input directory.
for filename in tqdm(os.listdir(input_dir), desc='Processing'):
    if filename.endswith('.jpg') or filename.endswith('.png'):
        input_path = os.path.join(input_dir, filename)
        image = cv2.imread(input_path)
        height, width, _ = image.shape

        # AREA interpolation when upscaling small inputs, LINEAR otherwise
        # (preserved from the original).
        if height < size:
            image = cv2.resize(np.array(image), (size, size), interpolation=cv2.INTER_AREA)
        else:
            image = cv2.resize(np.array(image), (size, size), interpolation=cv2.INTER_LINEAR)

        image_gay = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Try each threshold until the inner-ROI detector finds candidates.
        merged_rectangles = []
        for b_p in THRESHOLDS:
            _, binary_image_A = cv2.threshold(image_gay, b_p, 255, cv2.THRESH_BINARY)
            merged_rectangles = roi_Inner_detector(binary_image_A)
            if merged_rectangles:
                break

        if merged_rectangles:
            # Keep the largest rectangle, crop 1px inside it, and save a
            # 32x32 thumbnail under the same filename.
            outermost_rectangle = max(merged_rectangles, key=lambda x: x[2] * x[3])
            x, y, w, h = outermost_rectangle
            image_crope = image[y + 1:y - 1 + h, x + 1:x - 1 + w]
            image_crope = cv2.resize(image_crope, (32, 32), interpolation=cv2.INTER_AREA)
            output_path = os.path.join(output_dir, filename)
            cv2.imwrite(output_path, image_crope)
        else:
            print("failed",filename)

print("裁剪完成!")
|
|
||||||
@ -1,53 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
import subprocess
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
BASE_DIR = os.path.abspath(os.path.dirname(__file__) + '/..')
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("action", help="CI action: build, push")
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
NAME = subprocess.check_output('git rev-parse --short HEAD', shell=True, encoding='utf-8').strip()
|
|
||||||
IMG_TAG = "registry.cn-shenzhen.aliyuncs.com/euphon-private/emblem-detection:" + NAME
|
|
||||||
|
|
||||||
print(IMG_TAG)
|
|
||||||
|
|
||||||
def do_test():
|
|
||||||
subprocess.check_call("docker run --rm --network=host %s ./tests/run.py" % IMG_TAG, shell=True)
|
|
||||||
|
|
||||||
def do_build():
|
|
||||||
subprocess.check_call("docker build --network=host -t %s ." % IMG_TAG, shell=True)
|
|
||||||
|
|
||||||
def do_push():
|
|
||||||
subprocess.check_call("docker push %s" % IMG_TAG, shell=True)
|
|
||||||
|
|
||||||
def do_deploy(env=""):
|
|
||||||
kubeconfig = os.path.join(BASE_DIR, 'deploy/kubeconfig.' + env)
|
|
||||||
src = os.path.join(BASE_DIR, 'deploy/detection.yaml')
|
|
||||||
tf = tempfile.NamedTemporaryFile(mode='w')
|
|
||||||
with open(src, 'r') as f:
|
|
||||||
tmpl = f.read()
|
|
||||||
tf.write(tmpl.format(image=IMG_TAG))
|
|
||||||
tf.flush()
|
|
||||||
cmd = ['kubectl', '--kubeconfig', kubeconfig, 'apply', '-f', tf.name]
|
|
||||||
subprocess.check_call(cmd)
|
|
||||||
def main():
|
|
||||||
args = parse_args()
|
|
||||||
if args.action == "build":
|
|
||||||
return do_build()
|
|
||||||
elif args.action == "test":
|
|
||||||
return do_test()
|
|
||||||
elif args.action == "push":
|
|
||||||
return do_push()
|
|
||||||
elif args.action == "deploy-prod":
|
|
||||||
return do_deploy('prod')
|
|
||||||
elif args.action == "deploy-dev":
|
|
||||||
return do_deploy('dev')
|
|
||||||
raise Exception("Unrecognized command")
|
|
||||||
|
|
||||||
main()
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
pip3 install opencv-python==4.5.5.62
|
|
||||||
pip3 install opencv-contrib-python==4.5.5.62
|
|
||||||
pip3 install -r requirements.txt
|
|
||||||
@ -1,4 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
python3 app.py
|
|
||||||
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 14 KiB |
|
Before Width: | Height: | Size: 189 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 20 KiB |
|
Before Width: | Height: | Size: 2.3 KiB |
|
Before Width: | Height: | Size: 7.3 KiB |
|
Before Width: | Height: | Size: 8.7 KiB |
|
Before Width: | Height: | Size: 5.7 KiB |
|
Before Width: | Height: | Size: 5.5 KiB |
|
Before Width: | Height: | Size: 73 KiB |
|
Before Width: | Height: | Size: 82 KiB |
|
Before Width: | Height: | Size: 111 KiB |
|
Before Width: | Height: | Size: 110 KiB |
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 219 KiB |
|
Before Width: | Height: | Size: 195 KiB |
|
Before Width: | Height: | Size: 103 KiB |
|
Before Width: | Height: | Size: 969 B |
|
Before Width: | Height: | Size: 5.2 KiB |
|
Before Width: | Height: | Size: 2.6 KiB |
|
Before Width: | Height: | Size: 266 KiB |
|
Before Width: | Height: | Size: 16 KiB |
@ -1,69 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import time
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import atexit
|
|
||||||
import cv2
|
|
||||||
import requests
|
|
||||||
import subprocess
|
|
||||||
import random
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
BASE_DIR = os.path.abspath(os.path.dirname(__file__) + "/..")
|
|
||||||
|
|
||||||
class Testing(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = self.start_server()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.server.terminate()
|
|
||||||
self.server.wait()
|
|
||||||
|
|
||||||
def start_server(self):
|
|
||||||
port = random.randint(40000, 60000)
|
|
||||||
cmd = ['python3', 'app.py', '-l', '127.0.0.1', '-p', str(port)]
|
|
||||||
p = subprocess.Popen(cmd, cwd=BASE_DIR)
|
|
||||||
start = time.time()
|
|
||||||
while p.poll() == None and time.time() - start < 1200:
|
|
||||||
try:
|
|
||||||
url = 'http://localhost:%d' % port
|
|
||||||
if 'emblem' in requests.get(url, timeout=1).text:
|
|
||||||
self.base_url = url
|
|
||||||
return p
|
|
||||||
except:
|
|
||||||
time.sleep(1)
|
|
||||||
raise Exception("Failed to start server")
|
|
||||||
|
|
||||||
def do_test_roi_cloud_comparison(self, std, ter):
|
|
||||||
std_file_Path = os.path.join(BASE_DIR, 'tests/data/', std)
|
|
||||||
ter_file_Path = os.path.join(BASE_DIR, 'tests/data/', ter)
|
|
||||||
|
|
||||||
std_img = cv2.imread(std_file_Path)
|
|
||||||
ter_img = cv2.imread(ter_file_Path)
|
|
||||||
|
|
||||||
std_img_encode = cv2.imencode('.jpg', std_img)[1]
|
|
||||||
ter_img_encode = cv2.imencode('.jpg', ter_img)[1]
|
|
||||||
|
|
||||||
files = {"std_file": ('std_file.jpg', std_img_encode, 'image/jpg'),
|
|
||||||
"ter_file": ('ter_file.jpg', ter_img_encode, 'image/jpg')}
|
|
||||||
# form = {"threshold":'50',"angle":'45'}
|
|
||||||
form = {"threshold": '50'}
|
|
||||||
begin = time.time()
|
|
||||||
api = '/qr_roi_cloud_comparison'
|
|
||||||
r = requests.post(self.base_url + api, files=files,
|
|
||||||
data=form).json()
|
|
||||||
print("std file", std_file_Path)
|
|
||||||
print("ter file", ter_file_Path)
|
|
||||||
print(r)
|
|
||||||
print("%.3fs processed %s" % (time.time() - begin, r))
|
|
||||||
self.assertEqual(r['data']['status'], 'OK')
|
|
||||||
|
|
||||||
def test_roi_cmp(self):
|
|
||||||
self.do_test_roi_cloud_comparison('11173059793161.jpg', '11173059793161_pos_M2012K11AC_fix.jpg')
|
|
||||||
#std_file_Path =os.path.join(BASE_DIR, 'tests/data/','11173059793161.jpg')
|
|
||||||
#ter_file_Path =os.path.join(BASE_DIR, 'tests/data/','11173059793161_pos_M2012K11AC_fix.jpg')
|
|
||||||
#self.do_test_roi_cloud_comparison('0079080983780.jpg', '0079080983780_pos_iPhone14Pro_1706842853.5876188.jpg')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
@ -1,56 +0,0 @@
|
|||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from .yolo5x_qr.models.common import DetectMultiBackend
|
|
||||||
from .yolo5x_qr.utils.augmentations import (letterbox)
|
|
||||||
from .yolo5x_qr.utils.torch_utils import select_device, smart_inference_mode
|
|
||||||
from .yolo5x_qr.utils.general import (LOGGER, Profile, check_img_size, cv2,
|
|
||||||
non_max_suppression, scale_boxes)
|
|
||||||
|
|
||||||
class QR_Box_detect(nn.Module):
|
|
||||||
def __init__(self, model_path=None, device='cpu'):
|
|
||||||
super(QR_Box_detect, self).__init__()
|
|
||||||
self.conf_thres=0.80
|
|
||||||
self.iou_thres=0.45
|
|
||||||
self.classes=None
|
|
||||||
self.max_det=1
|
|
||||||
self.agnostic_nms=False
|
|
||||||
self.model_path = model_path
|
|
||||||
self.device = select_device(device,model_path=self.model_path)
|
|
||||||
self.model = DetectMultiBackend(weights=self.model_path, device=self.device)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self,input,imgsz=512):
|
|
||||||
|
|
||||||
#图像按比例缩放
|
|
||||||
stride, names, pt = self.model.stride, self.model.names, self.model.pt
|
|
||||||
imgsz = check_img_size(imgsz, s=stride) # check image size
|
|
||||||
im0 = cv2.cvtColor(np.array(input), cv2.COLOR_RGB2BGR) # BGR
|
|
||||||
im = letterbox(im0, imgsz, stride=32, auto=True)[0] # padded resize
|
|
||||||
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
|
||||||
im = np.ascontiguousarray(im) # contiguous
|
|
||||||
|
|
||||||
|
|
||||||
im = torch.from_numpy(im).to(self.model.device)
|
|
||||||
im = im.half() if self.model.fp16 else im.float() #
|
|
||||||
im /= 255
|
|
||||||
if len(im.shape) == 3:
|
|
||||||
im = im[None] # expand for batch dim
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
pred = self.model.model(im, augment=False,visualize=False)
|
|
||||||
pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms, max_det=self.max_det)
|
|
||||||
det = pred[0]
|
|
||||||
if len(det):
|
|
||||||
# Rescale boxes from img_size to im0 size
|
|
||||||
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
|
|
||||||
return det[:, :4].view(2, 2).cpu().numpy(),det[:,4:5].cpu().numpy()
|
|
||||||
return None,None
|
|
||||||
except RuntimeError as error:
|
|
||||||
print('Error', error)
|
|
||||||
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,46 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import cv2
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
import torch.nn as nn
|
|
||||||
from .realesrgan import Real_ESRGANer
|
|
||||||
# from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
||||||
|
|
||||||
class RealsrGan(nn.Module):
|
|
||||||
def __init__(self, num_in_ch=3, scale=4,model_path=None,device='0'):
|
|
||||||
super(RealsrGan, self).__init__()
|
|
||||||
self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
|
||||||
self.netscale = 4
|
|
||||||
self.model_path = model_path
|
|
||||||
self.device = device
|
|
||||||
# restorer
|
|
||||||
self.upsampler = Real_ESRGANer(
|
|
||||||
scale=self.netscale,
|
|
||||||
device=self.device,
|
|
||||||
model_path=self.model_path,
|
|
||||||
model=self.model,
|
|
||||||
tile=0,
|
|
||||||
tile_pad=10,
|
|
||||||
pre_pad=0,
|
|
||||||
half=False)
|
|
||||||
|
|
||||||
def forward(self,input):
|
|
||||||
|
|
||||||
# img = cv2.imread(input, cv2.IMREAD_UNCHANGED)
|
|
||||||
# if len(img.shape) == 3 and img.shape[2] == 4:
|
|
||||||
# img_mode = 'RGBA'
|
|
||||||
# else:
|
|
||||||
# img_mode = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
output, _ = self.upsampler.enhance(input, outscale=4)
|
|
||||||
except RuntimeError as error:
|
|
||||||
print('Error', error)
|
|
||||||
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
|
|
||||||
# else:
|
|
||||||
# if img_mode == 'RGBA': # RGBA images should be saved in png format
|
|
||||||
# extension = 'png'
|
|
||||||
# save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
|
|
||||||
# cv2.imwrite(save_path, output)
|
|
||||||
return output
|
|
||||||
@ -1,6 +0,0 @@
|
|||||||
# flake8: noqa
|
|
||||||
from .archs import *
|
|
||||||
from .data import *
|
|
||||||
from .models import *
|
|
||||||
from .utils import *
|
|
||||||
from .version import *
|
|
||||||
@ -1,10 +0,0 @@
|
|||||||
import importlib
|
|
||||||
from basicsr.utils import scandir
|
|
||||||
from os import path as osp
|
|
||||||
|
|
||||||
# automatically scan and import arch modules for registry
|
|
||||||
# scan all the files that end with '_arch.py' under the archs folder
|
|
||||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
|
||||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
|
||||||
# import all the arch modules
|
|
||||||
_arch_modules = [importlib.import_module(f'thirdTool.realesrgan.archs.{file_name}') for file_name in arch_filenames]
|
|
||||||
@ -1,67 +0,0 @@
|
|||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
from torch import nn as nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.nn.utils import spectral_norm
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class UNetDiscriminatorSN(nn.Module):
|
|
||||||
"""Defines a U-Net discriminator with spectral normalization (SN)
|
|
||||||
|
|
||||||
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
|
||||||
|
|
||||||
Arg:
|
|
||||||
num_in_ch (int): Channel number of inputs. Default: 3.
|
|
||||||
num_feat (int): Channel number of base intermediate features. Default: 64.
|
|
||||||
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
|
||||||
super(UNetDiscriminatorSN, self).__init__()
|
|
||||||
self.skip_connection = skip_connection
|
|
||||||
norm = spectral_norm
|
|
||||||
# the first convolution
|
|
||||||
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
|
||||||
# downsample
|
|
||||||
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
|
||||||
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
|
||||||
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
|
||||||
# upsample
|
|
||||||
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
|
||||||
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
|
||||||
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
|
||||||
# extra convolutions
|
|
||||||
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
|
||||||
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
|
||||||
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# downsample
|
|
||||||
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
|
||||||
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
|
||||||
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
|
||||||
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
# upsample
|
|
||||||
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
|
|
||||||
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
if self.skip_connection:
|
|
||||||
x4 = x4 + x2
|
|
||||||
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
|
|
||||||
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
if self.skip_connection:
|
|
||||||
x5 = x5 + x1
|
|
||||||
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
|
|
||||||
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
if self.skip_connection:
|
|
||||||
x6 = x6 + x0
|
|
||||||
|
|
||||||
# extra convolutions
|
|
||||||
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
|
||||||
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
|
||||||
out = self.conv9(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
@ -1,69 +0,0 @@
|
|||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
from torch import nn as nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class SRVGGNetCompact(nn.Module):
|
|
||||||
"""A compact VGG-style network structure for super-resolution.
|
|
||||||
|
|
||||||
It is a compact network structure, which performs upsampling in the last layer and no convolution is
|
|
||||||
conducted on the HR feature space.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_in_ch (int): Channel number of inputs. Default: 3.
|
|
||||||
num_out_ch (int): Channel number of outputs. Default: 3.
|
|
||||||
num_feat (int): Channel number of intermediate features. Default: 64.
|
|
||||||
num_conv (int): Number of convolution layers in the body network. Default: 16.
|
|
||||||
upscale (int): Upsampling factor. Default: 4.
|
|
||||||
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
|
||||||
super(SRVGGNetCompact, self).__init__()
|
|
||||||
self.num_in_ch = num_in_ch
|
|
||||||
self.num_out_ch = num_out_ch
|
|
||||||
self.num_feat = num_feat
|
|
||||||
self.num_conv = num_conv
|
|
||||||
self.upscale = upscale
|
|
||||||
self.act_type = act_type
|
|
||||||
|
|
||||||
self.body = nn.ModuleList()
|
|
||||||
# the first conv
|
|
||||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
|
||||||
# the first activation
|
|
||||||
if act_type == 'relu':
|
|
||||||
activation = nn.ReLU(inplace=True)
|
|
||||||
elif act_type == 'prelu':
|
|
||||||
activation = nn.PReLU(num_parameters=num_feat)
|
|
||||||
elif act_type == 'leakyrelu':
|
|
||||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
|
||||||
self.body.append(activation)
|
|
||||||
|
|
||||||
# the body structure
|
|
||||||
for _ in range(num_conv):
|
|
||||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
|
||||||
# activation
|
|
||||||
if act_type == 'relu':
|
|
||||||
activation = nn.ReLU(inplace=True)
|
|
||||||
elif act_type == 'prelu':
|
|
||||||
activation = nn.PReLU(num_parameters=num_feat)
|
|
||||||
elif act_type == 'leakyrelu':
|
|
||||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
|
||||||
self.body.append(activation)
|
|
||||||
|
|
||||||
# the last conv
|
|
||||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
|
||||||
# upsample
|
|
||||||
self.upsampler = nn.PixelShuffle(upscale)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = x
|
|
||||||
for i in range(0, len(self.body)):
|
|
||||||
out = self.body[i](out)
|
|
||||||
|
|
||||||
out = self.upsampler(out)
|
|
||||||
# add the nearest upsampled image, so that the network learns the residual
|
|
||||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
|
||||||
out += base
|
|
||||||
return out
|
|
||||||
@ -1,10 +0,0 @@
|
|||||||
import importlib
|
|
||||||
from basicsr.utils import scandir
|
|
||||||
from os import path as osp
|
|
||||||
|
|
||||||
# automatically scan and import dataset modules for registry
|
|
||||||
# scan all the files that end with '_dataset.py' under the data folder
|
|
||||||
data_folder = osp.dirname(osp.abspath(__file__))
|
|
||||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
|
||||||
# import all the dataset modules
|
|
||||||
_dataset_modules = [importlib.import_module(f'thirdTool.realesrgan.data.{file_name}') for file_name in dataset_filenames]
|
|
||||||
@ -1,192 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
|
||||||
from basicsr.data.transforms import augment
|
|
||||||
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
|
||||||
from basicsr.utils.registry import DATASET_REGISTRY
|
|
||||||
from torch.utils import data as data
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET_REGISTRY.register()
|
|
||||||
class RealESRGANDataset(data.Dataset):
|
|
||||||
"""Dataset used for Real-ESRGAN model:
|
|
||||||
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
|
||||||
|
|
||||||
It loads gt (Ground-Truth) images, and augments them.
|
|
||||||
It also generates blur kernels and sinc kernels for generating low-quality images.
|
|
||||||
Note that the low-quality images are processed in tensors on GPUS for faster processing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
opt (dict): Config for train datasets. It contains the following keys:
|
|
||||||
dataroot_gt (str): Data root path for gt.
|
|
||||||
meta_info (str): Path for meta information file.
|
|
||||||
io_backend (dict): IO backend type and other kwarg.
|
|
||||||
use_hflip (bool): Use horizontal flips.
|
|
||||||
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
|
|
||||||
Please see more options in the codes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, opt):
|
|
||||||
super(RealESRGANDataset, self).__init__()
|
|
||||||
self.opt = opt
|
|
||||||
self.file_client = None
|
|
||||||
self.io_backend_opt = opt['io_backend']
|
|
||||||
self.gt_folder = opt['dataroot_gt']
|
|
||||||
|
|
||||||
# file client (lmdb io backend)
|
|
||||||
if self.io_backend_opt['type'] == 'lmdb':
|
|
||||||
self.io_backend_opt['db_paths'] = [self.gt_folder]
|
|
||||||
self.io_backend_opt['client_keys'] = ['gt']
|
|
||||||
if not self.gt_folder.endswith('.lmdb'):
|
|
||||||
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
|
||||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
|
||||||
self.paths = [line.split('.')[0] for line in fin]
|
|
||||||
else:
|
|
||||||
# disk backend with meta_info
|
|
||||||
# Each line in the meta_info describes the relative path to an image
|
|
||||||
with open(self.opt['meta_info']) as fin:
|
|
||||||
paths = [line.strip().split(' ')[0] for line in fin]
|
|
||||||
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
|
|
||||||
|
|
||||||
# blur settings for the first degradation
|
|
||||||
self.blur_kernel_size = opt['blur_kernel_size']
|
|
||||||
self.kernel_list = opt['kernel_list']
|
|
||||||
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
|
||||||
self.blur_sigma = opt['blur_sigma']
|
|
||||||
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
|
||||||
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
|
||||||
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
|
||||||
|
|
||||||
# blur settings for the second degradation
|
|
||||||
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
|
||||||
self.kernel_list2 = opt['kernel_list2']
|
|
||||||
self.kernel_prob2 = opt['kernel_prob2']
|
|
||||||
self.blur_sigma2 = opt['blur_sigma2']
|
|
||||||
self.betag_range2 = opt['betag_range2']
|
|
||||||
self.betap_range2 = opt['betap_range2']
|
|
||||||
self.sinc_prob2 = opt['sinc_prob2']
|
|
||||||
|
|
||||||
# a final sinc filter
|
|
||||||
self.final_sinc_prob = opt['final_sinc_prob']
|
|
||||||
|
|
||||||
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
|
||||||
# TODO: kernel range is now hard-coded, should be in the configure file
|
|
||||||
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
|
||||||
self.pulse_tensor[10, 10] = 1
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
if self.file_client is None:
|
|
||||||
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
|
||||||
|
|
||||||
# -------------------------------- Load gt images -------------------------------- #
|
|
||||||
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
|
||||||
gt_path = self.paths[index]
|
|
||||||
# avoid errors caused by high latency in reading files
|
|
||||||
retry = 3
|
|
||||||
while retry > 0:
|
|
||||||
try:
|
|
||||||
img_bytes = self.file_client.get(gt_path, 'gt')
|
|
||||||
except (IOError, OSError) as e:
|
|
||||||
logger = get_root_logger()
|
|
||||||
logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
|
|
||||||
# change another file to read
|
|
||||||
index = random.randint(0, self.__len__())
|
|
||||||
gt_path = self.paths[index]
|
|
||||||
time.sleep(1) # sleep 1s for occasional server congestion
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
retry -= 1
|
|
||||||
img_gt = imfrombytes(img_bytes, float32=True)
|
|
||||||
|
|
||||||
# -------------------- Do augmentation for training: flip, rotation -------------------- #
|
|
||||||
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
|
|
||||||
|
|
||||||
# crop or pad to 400
|
|
||||||
# TODO: 400 is hard-coded. You may change it accordingly
|
|
||||||
h, w = img_gt.shape[0:2]
|
|
||||||
crop_pad_size = 400
|
|
||||||
# pad
|
|
||||||
if h < crop_pad_size or w < crop_pad_size:
|
|
||||||
pad_h = max(0, crop_pad_size - h)
|
|
||||||
pad_w = max(0, crop_pad_size - w)
|
|
||||||
img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
|
|
||||||
# crop
|
|
||||||
if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
|
|
||||||
h, w = img_gt.shape[0:2]
|
|
||||||
# randomly choose top and left coordinates
|
|
||||||
top = random.randint(0, h - crop_pad_size)
|
|
||||||
left = random.randint(0, w - crop_pad_size)
|
|
||||||
img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
|
|
||||||
|
|
||||||
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
|
||||||
kernel_size = random.choice(self.kernel_range)
|
|
||||||
if np.random.uniform() < self.opt['sinc_prob']:
|
|
||||||
# this sinc filter setting is for kernels ranging from [7, 21]
|
|
||||||
if kernel_size < 13:
|
|
||||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
|
||||||
else:
|
|
||||||
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
|
||||||
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
|
||||||
else:
|
|
||||||
kernel = random_mixed_kernels(
|
|
||||||
self.kernel_list,
|
|
||||||
self.kernel_prob,
|
|
||||||
kernel_size,
|
|
||||||
self.blur_sigma,
|
|
||||||
self.blur_sigma, [-math.pi, math.pi],
|
|
||||||
self.betag_range,
|
|
||||||
self.betap_range,
|
|
||||||
noise_range=None)
|
|
||||||
# pad kernel
|
|
||||||
pad_size = (21 - kernel_size) // 2
|
|
||||||
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
|
||||||
|
|
||||||
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
|
||||||
kernel_size = random.choice(self.kernel_range)
|
|
||||||
if np.random.uniform() < self.opt['sinc_prob2']:
|
|
||||||
if kernel_size < 13:
|
|
||||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
|
||||||
else:
|
|
||||||
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
|
||||||
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
|
||||||
else:
|
|
||||||
kernel2 = random_mixed_kernels(
|
|
||||||
self.kernel_list2,
|
|
||||||
self.kernel_prob2,
|
|
||||||
kernel_size,
|
|
||||||
self.blur_sigma2,
|
|
||||||
self.blur_sigma2, [-math.pi, math.pi],
|
|
||||||
self.betag_range2,
|
|
||||||
self.betap_range2,
|
|
||||||
noise_range=None)
|
|
||||||
|
|
||||||
# pad kernel
|
|
||||||
pad_size = (21 - kernel_size) // 2
|
|
||||||
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
|
||||||
|
|
||||||
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
|
||||||
if np.random.uniform() < self.opt['final_sinc_prob']:
|
|
||||||
kernel_size = random.choice(self.kernel_range)
|
|
||||||
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
|
||||||
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
|
||||||
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
|
||||||
else:
|
|
||||||
sinc_kernel = self.pulse_tensor
|
|
||||||
|
|
||||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
|
||||||
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
|
|
||||||
kernel = torch.FloatTensor(kernel)
|
|
||||||
kernel2 = torch.FloatTensor(kernel2)
|
|
||||||
|
|
||||||
return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
|
|
||||||
return return_d
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.paths)
|
|
||||||
@ -1,108 +0,0 @@
|
|||||||
import os
|
|
||||||
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
|
|
||||||
from basicsr.data.transforms import augment, paired_random_crop
|
|
||||||
from basicsr.utils import FileClient, imfrombytes, img2tensor
|
|
||||||
from basicsr.utils.registry import DATASET_REGISTRY
|
|
||||||
from torch.utils import data as data
|
|
||||||
from torchvision.transforms.functional import normalize
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET_REGISTRY.register()
|
|
||||||
class RealESRGANPairedDataset(data.Dataset):
|
|
||||||
"""Paired image dataset for image restoration.
|
|
||||||
|
|
||||||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
|
|
||||||
|
|
||||||
There are three modes:
|
|
||||||
1. 'lmdb': Use lmdb files.
|
|
||||||
If opt['io_backend'] == lmdb.
|
|
||||||
2. 'meta_info': Use meta information file to generate paths.
|
|
||||||
If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
|
|
||||||
3. 'folder': Scan folders to generate paths.
|
|
||||||
The rest.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
opt (dict): Config for train datasets. It contains the following keys:
|
|
||||||
dataroot_gt (str): Data root path for gt.
|
|
||||||
dataroot_lq (str): Data root path for lq.
|
|
||||||
meta_info (str): Path for meta information file.
|
|
||||||
io_backend (dict): IO backend type and other kwarg.
|
|
||||||
filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
|
|
||||||
Default: '{}'.
|
|
||||||
gt_size (int): Cropped patched size for gt patches.
|
|
||||||
use_hflip (bool): Use horizontal flips.
|
|
||||||
use_rot (bool): Use rotation (use vertical flip and transposing h
|
|
||||||
and w for implementation).
|
|
||||||
|
|
||||||
scale (bool): Scale, which will be added automatically.
|
|
||||||
phase (str): 'train' or 'val'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, opt):
|
|
||||||
super(RealESRGANPairedDataset, self).__init__()
|
|
||||||
self.opt = opt
|
|
||||||
self.file_client = None
|
|
||||||
self.io_backend_opt = opt['io_backend']
|
|
||||||
# mean and std for normalizing the input images
|
|
||||||
self.mean = opt['mean'] if 'mean' in opt else None
|
|
||||||
self.std = opt['std'] if 'std' in opt else None
|
|
||||||
|
|
||||||
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
|
|
||||||
self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
|
|
||||||
|
|
||||||
# file client (lmdb io backend)
|
|
||||||
if self.io_backend_opt['type'] == 'lmdb':
|
|
||||||
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
|
|
||||||
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
|
||||||
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
|
|
||||||
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
|
|
||||||
# disk backend with meta_info
|
|
||||||
# Each line in the meta_info describes the relative path to an image
|
|
||||||
with open(self.opt['meta_info']) as fin:
|
|
||||||
paths = [line.strip() for line in fin]
|
|
||||||
self.paths = []
|
|
||||||
for path in paths:
|
|
||||||
gt_path, lq_path = path.split(', ')
|
|
||||||
gt_path = os.path.join(self.gt_folder, gt_path)
|
|
||||||
lq_path = os.path.join(self.lq_folder, lq_path)
|
|
||||||
self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
|
|
||||||
else:
|
|
||||||
# disk backend
|
|
||||||
# it will scan the whole folder to get meta info
|
|
||||||
# it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
|
|
||||||
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
if self.file_client is None:
|
|
||||||
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
|
||||||
|
|
||||||
scale = self.opt['scale']
|
|
||||||
|
|
||||||
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
|
|
||||||
# image range: [0, 1], float32.
|
|
||||||
gt_path = self.paths[index]['gt_path']
|
|
||||||
img_bytes = self.file_client.get(gt_path, 'gt')
|
|
||||||
img_gt = imfrombytes(img_bytes, float32=True)
|
|
||||||
lq_path = self.paths[index]['lq_path']
|
|
||||||
img_bytes = self.file_client.get(lq_path, 'lq')
|
|
||||||
img_lq = imfrombytes(img_bytes, float32=True)
|
|
||||||
|
|
||||||
# augmentation for training
|
|
||||||
if self.opt['phase'] == 'train':
|
|
||||||
gt_size = self.opt['gt_size']
|
|
||||||
# random crop
|
|
||||||
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
|
|
||||||
# flip, rotation
|
|
||||||
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
|
|
||||||
|
|
||||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
|
||||||
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
|
||||||
# normalize
|
|
||||||
if self.mean is not None or self.std is not None:
|
|
||||||
normalize(img_lq, self.mean, self.std, inplace=True)
|
|
||||||
normalize(img_gt, self.mean, self.std, inplace=True)
|
|
||||||
|
|
||||||
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.paths)
|
|
||||||
@ -1,10 +0,0 @@
|
|||||||
import importlib
|
|
||||||
from basicsr.utils import scandir
|
|
||||||
from os import path as osp
|
|
||||||
|
|
||||||
# automatically scan and import model modules for registry
|
|
||||||
# scan all the files that end with '_model.py' under the model folder
|
|
||||||
model_folder = osp.dirname(osp.abspath(__file__))
|
|
||||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
|
||||||
# import all the model modules
|
|
||||||
_model_modules = [importlib.import_module(f'thirdTool.realesrgan.models.{file_name}') for file_name in model_filenames]
|
|
||||||
@ -1,258 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
|
||||||
from basicsr.data.transforms import paired_random_crop
|
|
||||||
from basicsr.models.srgan_model import SRGANModel
|
|
||||||
from basicsr.utils import DiffJPEG, USMSharp
|
|
||||||
from basicsr.utils.img_process_util import filter2D
|
|
||||||
from basicsr.utils.registry import MODEL_REGISTRY
|
|
||||||
from collections import OrderedDict
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
@MODEL_REGISTRY.register()
|
|
||||||
class RealESRGANModel(SRGANModel):
|
|
||||||
"""RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
|
||||||
|
|
||||||
It mainly performs:
|
|
||||||
1. randomly synthesize LQ images in GPU tensors
|
|
||||||
2. optimize the networks with GAN training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, opt):
|
|
||||||
super(RealESRGANModel, self).__init__(opt)
|
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
|
||||||
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
|
||||||
self.queue_size = opt.get('queue_size', 180)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _dequeue_and_enqueue(self):
|
|
||||||
"""It is the training pair pool for increasing the diversity in a batch.
|
|
||||||
|
|
||||||
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
|
||||||
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
|
||||||
to increase the degradation diversity in a batch.
|
|
||||||
"""
|
|
||||||
# initialize
|
|
||||||
b, c, h, w = self.lq.size()
|
|
||||||
if not hasattr(self, 'queue_lr'):
|
|
||||||
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
|
||||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
|
||||||
_, c, h, w = self.gt.size()
|
|
||||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
|
||||||
self.queue_ptr = 0
|
|
||||||
if self.queue_ptr == self.queue_size: # the pool is full
|
|
||||||
# do dequeue and enqueue
|
|
||||||
# shuffle
|
|
||||||
idx = torch.randperm(self.queue_size)
|
|
||||||
self.queue_lr = self.queue_lr[idx]
|
|
||||||
self.queue_gt = self.queue_gt[idx]
|
|
||||||
# get first b samples
|
|
||||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
|
||||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
|
||||||
# update the queue
|
|
||||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
|
||||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
|
||||||
|
|
||||||
self.lq = lq_dequeue
|
|
||||||
self.gt = gt_dequeue
|
|
||||||
else:
|
|
||||||
# only do enqueue
|
|
||||||
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
|
||||||
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
|
||||||
self.queue_ptr = self.queue_ptr + b
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def feed_data(self, data):
|
|
||||||
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
|
||||||
"""
|
|
||||||
if self.is_train and self.opt.get('high_order_degradation', True):
|
|
||||||
# training data synthesis
|
|
||||||
self.gt = data['gt'].to(self.device)
|
|
||||||
self.gt_usm = self.usm_sharpener(self.gt)
|
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
|
||||||
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
|
||||||
|
|
||||||
ori_h, ori_w = self.gt.size()[2:4]
|
|
||||||
|
|
||||||
# ----------------------- The first degradation process ----------------------- #
|
|
||||||
# blur
|
|
||||||
out = filter2D(self.gt_usm, self.kernel1)
|
|
||||||
# random resize
|
|
||||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
|
||||||
if updown_type == 'up':
|
|
||||||
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
|
||||||
elif updown_type == 'down':
|
|
||||||
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
|
||||||
else:
|
|
||||||
scale = 1
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
|
||||||
# add noise
|
|
||||||
gray_noise_prob = self.opt['gray_noise_prob']
|
|
||||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
|
||||||
out = random_add_gaussian_noise_pt(
|
|
||||||
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
|
||||||
else:
|
|
||||||
out = random_add_poisson_noise_pt(
|
|
||||||
out,
|
|
||||||
scale_range=self.opt['poisson_scale_range'],
|
|
||||||
gray_prob=gray_noise_prob,
|
|
||||||
clip=True,
|
|
||||||
rounds=False)
|
|
||||||
# JPEG compression
|
|
||||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
|
||||||
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
|
||||||
out = self.jpeger(out, quality=jpeg_p)
|
|
||||||
|
|
||||||
# ----------------------- The second degradation process ----------------------- #
|
|
||||||
# blur
|
|
||||||
if np.random.uniform() < self.opt['second_blur_prob']:
|
|
||||||
out = filter2D(out, self.kernel2)
|
|
||||||
# random resize
|
|
||||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
|
||||||
if updown_type == 'up':
|
|
||||||
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
|
||||||
elif updown_type == 'down':
|
|
||||||
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
|
||||||
else:
|
|
||||||
scale = 1
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(
|
|
||||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
|
||||||
# add noise
|
|
||||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
|
||||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
|
||||||
out = random_add_gaussian_noise_pt(
|
|
||||||
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
|
||||||
else:
|
|
||||||
out = random_add_poisson_noise_pt(
|
|
||||||
out,
|
|
||||||
scale_range=self.opt['poisson_scale_range2'],
|
|
||||||
gray_prob=gray_noise_prob,
|
|
||||||
clip=True,
|
|
||||||
rounds=False)
|
|
||||||
|
|
||||||
# JPEG compression + the final sinc filter
|
|
||||||
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
|
||||||
# as one operation.
|
|
||||||
# We consider two orders:
|
|
||||||
# 1. [resize back + sinc filter] + JPEG compression
|
|
||||||
# 2. JPEG compression + [resize back + sinc filter]
|
|
||||||
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
|
||||||
if np.random.uniform() < 0.5:
|
|
||||||
# resize back + the final sinc filter
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
|
||||||
out = filter2D(out, self.sinc_kernel)
|
|
||||||
# JPEG compression
|
|
||||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
|
||||||
out = torch.clamp(out, 0, 1)
|
|
||||||
out = self.jpeger(out, quality=jpeg_p)
|
|
||||||
else:
|
|
||||||
# JPEG compression
|
|
||||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
|
||||||
out = torch.clamp(out, 0, 1)
|
|
||||||
out = self.jpeger(out, quality=jpeg_p)
|
|
||||||
# resize back + the final sinc filter
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
|
||||||
out = filter2D(out, self.sinc_kernel)
|
|
||||||
|
|
||||||
# clamp and round
|
|
||||||
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
gt_size = self.opt['gt_size']
|
|
||||||
(self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
|
|
||||||
self.opt['scale'])
|
|
||||||
|
|
||||||
# training pair pool
|
|
||||||
self._dequeue_and_enqueue()
|
|
||||||
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
|
||||||
self.gt_usm = self.usm_sharpener(self.gt)
|
|
||||||
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
|
||||||
else:
|
|
||||||
# for paired training or validation
|
|
||||||
self.lq = data['lq'].to(self.device)
|
|
||||||
if 'gt' in data:
|
|
||||||
self.gt = data['gt'].to(self.device)
|
|
||||||
self.gt_usm = self.usm_sharpener(self.gt)
|
|
||||||
|
|
||||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
|
||||||
# do not use the synthetic process during validation
|
|
||||||
self.is_train = False
|
|
||||||
super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
|
||||||
self.is_train = True
|
|
||||||
|
|
||||||
def optimize_parameters(self, current_iter):
|
|
||||||
# usm sharpening
|
|
||||||
l1_gt = self.gt_usm
|
|
||||||
percep_gt = self.gt_usm
|
|
||||||
gan_gt = self.gt_usm
|
|
||||||
if self.opt['l1_gt_usm'] is False:
|
|
||||||
l1_gt = self.gt
|
|
||||||
if self.opt['percep_gt_usm'] is False:
|
|
||||||
percep_gt = self.gt
|
|
||||||
if self.opt['gan_gt_usm'] is False:
|
|
||||||
gan_gt = self.gt
|
|
||||||
|
|
||||||
# optimize net_g
|
|
||||||
for p in self.net_d.parameters():
|
|
||||||
p.requires_grad = False
|
|
||||||
|
|
||||||
self.optimizer_g.zero_grad()
|
|
||||||
self.output = self.net_g(self.lq)
|
|
||||||
|
|
||||||
l_g_total = 0
|
|
||||||
loss_dict = OrderedDict()
|
|
||||||
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
|
|
||||||
# pixel loss
|
|
||||||
if self.cri_pix:
|
|
||||||
l_g_pix = self.cri_pix(self.output, l1_gt)
|
|
||||||
l_g_total += l_g_pix
|
|
||||||
loss_dict['l_g_pix'] = l_g_pix
|
|
||||||
# perceptual loss
|
|
||||||
if self.cri_perceptual:
|
|
||||||
l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
|
|
||||||
if l_g_percep is not None:
|
|
||||||
l_g_total += l_g_percep
|
|
||||||
loss_dict['l_g_percep'] = l_g_percep
|
|
||||||
if l_g_style is not None:
|
|
||||||
l_g_total += l_g_style
|
|
||||||
loss_dict['l_g_style'] = l_g_style
|
|
||||||
# gan loss
|
|
||||||
fake_g_pred = self.net_d(self.output)
|
|
||||||
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
|
||||||
l_g_total += l_g_gan
|
|
||||||
loss_dict['l_g_gan'] = l_g_gan
|
|
||||||
|
|
||||||
l_g_total.backward()
|
|
||||||
self.optimizer_g.step()
|
|
||||||
|
|
||||||
# optimize net_d
|
|
||||||
for p in self.net_d.parameters():
|
|
||||||
p.requires_grad = True
|
|
||||||
|
|
||||||
self.optimizer_d.zero_grad()
|
|
||||||
# real
|
|
||||||
real_d_pred = self.net_d(gan_gt)
|
|
||||||
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
|
|
||||||
loss_dict['l_d_real'] = l_d_real
|
|
||||||
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
|
|
||||||
l_d_real.backward()
|
|
||||||
# fake
|
|
||||||
fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
|
|
||||||
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
|
|
||||||
loss_dict['l_d_fake'] = l_d_fake
|
|
||||||
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
|
|
||||||
l_d_fake.backward()
|
|
||||||
self.optimizer_d.step()
|
|
||||||
|
|
||||||
if self.ema_decay > 0:
|
|
||||||
self.model_ema(decay=self.ema_decay)
|
|
||||||
|
|
||||||
self.log_dict = self.reduce_loss_dict(loss_dict)
|
|
||||||
@ -1,188 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
|
||||||
from basicsr.data.transforms import paired_random_crop
|
|
||||||
from basicsr.models.sr_model import SRModel
|
|
||||||
from basicsr.utils import DiffJPEG, USMSharp
|
|
||||||
from basicsr.utils.img_process_util import filter2D
|
|
||||||
from basicsr.utils.registry import MODEL_REGISTRY
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
@MODEL_REGISTRY.register()
|
|
||||||
class RealESRNetModel(SRModel):
|
|
||||||
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
|
||||||
|
|
||||||
It is trained without GAN losses.
|
|
||||||
It mainly performs:
|
|
||||||
1. randomly synthesize LQ images in GPU tensors
|
|
||||||
2. optimize the networks with GAN training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, opt):
|
|
||||||
super(RealESRNetModel, self).__init__(opt)
|
|
||||||
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
|
||||||
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
|
||||||
self.queue_size = opt.get('queue_size', 180)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _dequeue_and_enqueue(self):
|
|
||||||
"""It is the training pair pool for increasing the diversity in a batch.
|
|
||||||
|
|
||||||
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
|
||||||
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
|
||||||
to increase the degradation diversity in a batch.
|
|
||||||
"""
|
|
||||||
# initialize
|
|
||||||
b, c, h, w = self.lq.size()
|
|
||||||
if not hasattr(self, 'queue_lr'):
|
|
||||||
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
|
||||||
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
|
||||||
_, c, h, w = self.gt.size()
|
|
||||||
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
|
||||||
self.queue_ptr = 0
|
|
||||||
if self.queue_ptr == self.queue_size: # the pool is full
|
|
||||||
# do dequeue and enqueue
|
|
||||||
# shuffle
|
|
||||||
idx = torch.randperm(self.queue_size)
|
|
||||||
self.queue_lr = self.queue_lr[idx]
|
|
||||||
self.queue_gt = self.queue_gt[idx]
|
|
||||||
# get first b samples
|
|
||||||
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
|
||||||
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
|
||||||
# update the queue
|
|
||||||
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
|
||||||
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
|
||||||
|
|
||||||
self.lq = lq_dequeue
|
|
||||||
self.gt = gt_dequeue
|
|
||||||
else:
|
|
||||||
# only do enqueue
|
|
||||||
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
|
||||||
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
|
||||||
self.queue_ptr = self.queue_ptr + b
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def feed_data(self, data):
|
|
||||||
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
|
||||||
"""
|
|
||||||
if self.is_train and self.opt.get('high_order_degradation', True):
|
|
||||||
# training data synthesis
|
|
||||||
self.gt = data['gt'].to(self.device)
|
|
||||||
# USM sharpen the GT images
|
|
||||||
if self.opt['gt_usm'] is True:
|
|
||||||
self.gt = self.usm_sharpener(self.gt)
|
|
||||||
|
|
||||||
self.kernel1 = data['kernel1'].to(self.device)
|
|
||||||
self.kernel2 = data['kernel2'].to(self.device)
|
|
||||||
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
|
||||||
|
|
||||||
ori_h, ori_w = self.gt.size()[2:4]
|
|
||||||
|
|
||||||
# ----------------------- The first degradation process ----------------------- #
|
|
||||||
# blur
|
|
||||||
out = filter2D(self.gt, self.kernel1)
|
|
||||||
# random resize
|
|
||||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
|
||||||
if updown_type == 'up':
|
|
||||||
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
|
||||||
elif updown_type == 'down':
|
|
||||||
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
|
||||||
else:
|
|
||||||
scale = 1
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
|
||||||
# add noise
|
|
||||||
gray_noise_prob = self.opt['gray_noise_prob']
|
|
||||||
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
|
||||||
out = random_add_gaussian_noise_pt(
|
|
||||||
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
|
||||||
else:
|
|
||||||
out = random_add_poisson_noise_pt(
|
|
||||||
out,
|
|
||||||
scale_range=self.opt['poisson_scale_range'],
|
|
||||||
gray_prob=gray_noise_prob,
|
|
||||||
clip=True,
|
|
||||||
rounds=False)
|
|
||||||
# JPEG compression
|
|
||||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
|
||||||
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
|
||||||
out = self.jpeger(out, quality=jpeg_p)
|
|
||||||
|
|
||||||
# ----------------------- The second degradation process ----------------------- #
|
|
||||||
# blur
|
|
||||||
if np.random.uniform() < self.opt['second_blur_prob']:
|
|
||||||
out = filter2D(out, self.kernel2)
|
|
||||||
# random resize
|
|
||||||
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
|
||||||
if updown_type == 'up':
|
|
||||||
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
|
||||||
elif updown_type == 'down':
|
|
||||||
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
|
||||||
else:
|
|
||||||
scale = 1
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(
|
|
||||||
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
|
||||||
# add noise
|
|
||||||
gray_noise_prob = self.opt['gray_noise_prob2']
|
|
||||||
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
|
||||||
out = random_add_gaussian_noise_pt(
|
|
||||||
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
|
||||||
else:
|
|
||||||
out = random_add_poisson_noise_pt(
|
|
||||||
out,
|
|
||||||
scale_range=self.opt['poisson_scale_range2'],
|
|
||||||
gray_prob=gray_noise_prob,
|
|
||||||
clip=True,
|
|
||||||
rounds=False)
|
|
||||||
|
|
||||||
# JPEG compression + the final sinc filter
|
|
||||||
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
|
||||||
# as one operation.
|
|
||||||
# We consider two orders:
|
|
||||||
# 1. [resize back + sinc filter] + JPEG compression
|
|
||||||
# 2. JPEG compression + [resize back + sinc filter]
|
|
||||||
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
|
||||||
if np.random.uniform() < 0.5:
|
|
||||||
# resize back + the final sinc filter
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
|
||||||
out = filter2D(out, self.sinc_kernel)
|
|
||||||
# JPEG compression
|
|
||||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
|
||||||
out = torch.clamp(out, 0, 1)
|
|
||||||
out = self.jpeger(out, quality=jpeg_p)
|
|
||||||
else:
|
|
||||||
# JPEG compression
|
|
||||||
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
|
||||||
out = torch.clamp(out, 0, 1)
|
|
||||||
out = self.jpeger(out, quality=jpeg_p)
|
|
||||||
# resize back + the final sinc filter
|
|
||||||
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
|
||||||
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
|
||||||
out = filter2D(out, self.sinc_kernel)
|
|
||||||
|
|
||||||
# clamp and round
|
|
||||||
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
gt_size = self.opt['gt_size']
|
|
||||||
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
|
|
||||||
|
|
||||||
# training pair pool
|
|
||||||
self._dequeue_and_enqueue()
|
|
||||||
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
|
||||||
else:
|
|
||||||
# for paired training or validation
|
|
||||||
self.lq = data['lq'].to(self.device)
|
|
||||||
if 'gt' in data:
|
|
||||||
self.gt = data['gt'].to(self.device)
|
|
||||||
self.gt_usm = self.usm_sharpener(self.gt)
|
|
||||||
|
|
||||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
|
||||||
# do not use the synthetic process during validation
|
|
||||||
self.is_train = False
|
|
||||||
super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
|
||||||
self.is_train = True
|
|
||||||
@ -1,11 +0,0 @@
|
|||||||
# flake8: noqa
|
|
||||||
import os.path as osp
|
|
||||||
from basicsr.train import train_pipeline
|
|
||||||
|
|
||||||
import realesrgan.archs
|
|
||||||
import realesrgan.data
|
|
||||||
import realesrgan.models
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
|
||||||
train_pipeline(root_path)
|
|
||||||
@ -1,287 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
import torch
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from thirdTool.yolo5x_qr.utils.torch_utils import select_device
|
|
||||||
|
|
||||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
|
|
||||||
class Real_ESRGANer():
|
|
||||||
"""A helper class for upsampling images with RealESRGAN.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
|
||||||
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
|
||||||
model (nn.Module): The defined network. Default: None.
|
|
||||||
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
|
||||||
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
|
||||||
0 denotes for do not use tile. Default: 0.
|
|
||||||
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
|
||||||
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
|
||||||
half (float): Whether to use half precision during inference. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, scale, device, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
|
|
||||||
self.scale = scale
|
|
||||||
self.tile_size = tile
|
|
||||||
self.tile_pad = tile_pad
|
|
||||||
self.pre_pad = pre_pad
|
|
||||||
self.mod_scale = None
|
|
||||||
self.half = half
|
|
||||||
|
|
||||||
# initialize model
|
|
||||||
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
self.device = select_device(device, model_path=model_path)
|
|
||||||
|
|
||||||
# self.device = torch.device('cpu')
|
|
||||||
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
|
||||||
if model_path.startswith('https://'):
|
|
||||||
model_path = load_file_from_url(
|
|
||||||
url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
|
||||||
|
|
||||||
# loadnet = torch.load(model_path,map_location=torch.device('cpu'))
|
|
||||||
loadnet = torch.load(model_path, map_location=self.device)
|
|
||||||
|
|
||||||
# prefer to use params_ema
|
|
||||||
if 'params_ema' in loadnet:
|
|
||||||
keyname = 'params_ema'
|
|
||||||
else:
|
|
||||||
keyname = 'params'
|
|
||||||
model.load_state_dict(loadnet[keyname], strict=True)
|
|
||||||
model.eval()
|
|
||||||
self.model = model.to(self.device)
|
|
||||||
if self.half:
|
|
||||||
self.model = self.model.half()
|
|
||||||
|
|
||||||
def pre_process(self, img):
|
|
||||||
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
|
|
||||||
"""
|
|
||||||
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
|
||||||
self.img = img.unsqueeze(0).to(self.device)
|
|
||||||
if self.half:
|
|
||||||
self.img = self.img.half()
|
|
||||||
|
|
||||||
# pre_pad
|
|
||||||
if self.pre_pad != 0:
|
|
||||||
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
|
||||||
# mod pad for divisible borders
|
|
||||||
if self.scale == 2:
|
|
||||||
self.mod_scale = 2
|
|
||||||
elif self.scale == 1:
|
|
||||||
self.mod_scale = 4
|
|
||||||
if self.mod_scale is not None:
|
|
||||||
self.mod_pad_h, self.mod_pad_w = 0, 0
|
|
||||||
_, _, h, w = self.img.size()
|
|
||||||
if (h % self.mod_scale != 0):
|
|
||||||
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
|
||||||
if (w % self.mod_scale != 0):
|
|
||||||
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
|
||||||
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
|
||||||
|
|
||||||
def process(self):
|
|
||||||
# model inference
|
|
||||||
self.output = self.model(self.img)
|
|
||||||
|
|
||||||
def tile_process(self):
|
|
||||||
"""It will first crop input images to tiles, and then process each tile.
|
|
||||||
Finally, all the processed tiles are merged into one images.
|
|
||||||
|
|
||||||
Modified from: https://github.com/ata4/esrgan-launcher
|
|
||||||
"""
|
|
||||||
batch, channel, height, width = self.img.shape
|
|
||||||
output_height = height * self.scale
|
|
||||||
output_width = width * self.scale
|
|
||||||
output_shape = (batch, channel, output_height, output_width)
|
|
||||||
|
|
||||||
# start with black image
|
|
||||||
self.output = self.img.new_zeros(output_shape)
|
|
||||||
tiles_x = math.ceil(width / self.tile_size)
|
|
||||||
tiles_y = math.ceil(height / self.tile_size)
|
|
||||||
|
|
||||||
# loop over all tiles
|
|
||||||
for y in range(tiles_y):
|
|
||||||
for x in range(tiles_x):
|
|
||||||
# extract tile from input image
|
|
||||||
ofs_x = x * self.tile_size
|
|
||||||
ofs_y = y * self.tile_size
|
|
||||||
# input tile area on total image
|
|
||||||
input_start_x = ofs_x
|
|
||||||
input_end_x = min(ofs_x + self.tile_size, width)
|
|
||||||
input_start_y = ofs_y
|
|
||||||
input_end_y = min(ofs_y + self.tile_size, height)
|
|
||||||
|
|
||||||
# input tile area on total image with padding
|
|
||||||
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
|
||||||
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
|
||||||
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
|
||||||
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
|
||||||
|
|
||||||
# input tile dimensions
|
|
||||||
input_tile_width = input_end_x - input_start_x
|
|
||||||
input_tile_height = input_end_y - input_start_y
|
|
||||||
tile_idx = y * tiles_x + x + 1
|
|
||||||
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
|
||||||
|
|
||||||
# upscale tile
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output_tile = self.model(input_tile)
|
|
||||||
except RuntimeError as error:
|
|
||||||
print('Error', error)
|
|
||||||
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
|
||||||
|
|
||||||
# output tile area on total image
|
|
||||||
output_start_x = input_start_x * self.scale
|
|
||||||
output_end_x = input_end_x * self.scale
|
|
||||||
output_start_y = input_start_y * self.scale
|
|
||||||
output_end_y = input_end_y * self.scale
|
|
||||||
|
|
||||||
# output tile area without padding
|
|
||||||
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
|
||||||
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
|
||||||
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
|
||||||
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
|
||||||
|
|
||||||
# put tile into output image
|
|
||||||
self.output[:, :, output_start_y:output_end_y,
|
|
||||||
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
|
||||||
output_start_x_tile:output_end_x_tile]
|
|
||||||
|
|
||||||
def post_process(self):
|
|
||||||
# remove extra pad
|
|
||||||
if self.mod_scale is not None:
|
|
||||||
_, _, h, w = self.output.size()
|
|
||||||
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
|
||||||
# remove prepad
|
|
||||||
if self.pre_pad != 0:
|
|
||||||
_, _, h, w = self.output.size()
|
|
||||||
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
|
||||||
return self.output
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
|
|
||||||
h_input, w_input = img.shape[0:2]
|
|
||||||
# img: numpy
|
|
||||||
img = img.astype(np.float32)
|
|
||||||
if np.max(img) > 256: # 16-bit image
|
|
||||||
max_range = 65535
|
|
||||||
print('\tInput is a 16-bit image')
|
|
||||||
else:
|
|
||||||
max_range = 255
|
|
||||||
img = img / max_range
|
|
||||||
if len(img.shape) == 2: # gray image
|
|
||||||
img_mode = 'L'
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
|
||||||
elif img.shape[2] == 4: # RGBA image with alpha channel
|
|
||||||
img_mode = 'RGBA'
|
|
||||||
alpha = img[:, :, 3]
|
|
||||||
img = img[:, :, 0:3]
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
if alpha_upsampler == 'realesrgan':
|
|
||||||
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
|
||||||
else:
|
|
||||||
img_mode = 'RGB'
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
|
|
||||||
# ------------------- process image (without the alpha channel) ------------------- #
|
|
||||||
self.pre_process(img)
|
|
||||||
if self.tile_size > 0:
|
|
||||||
self.tile_process()
|
|
||||||
else:
|
|
||||||
self.process()
|
|
||||||
output_img = self.post_process()
|
|
||||||
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
if img_mode == 'L':
|
|
||||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
# ------------------- process the alpha channel if necessary ------------------- #
|
|
||||||
if img_mode == 'RGBA':
|
|
||||||
if alpha_upsampler == 'realesrgan':
|
|
||||||
self.pre_process(alpha)
|
|
||||||
if self.tile_size > 0:
|
|
||||||
self.tile_process()
|
|
||||||
else:
|
|
||||||
self.process()
|
|
||||||
output_alpha = self.post_process()
|
|
||||||
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
|
||||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
|
||||||
else: # use the cv2 resize for alpha channel
|
|
||||||
h, w = alpha.shape[0:2]
|
|
||||||
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
# merge the alpha channel
|
|
||||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
|
||||||
output_img[:, :, 3] = output_alpha
|
|
||||||
|
|
||||||
# ------------------------------ return ------------------------------ #
|
|
||||||
if max_range == 65535: # 16-bit image
|
|
||||||
output = (output_img * 65535.0).round().astype(np.uint16)
|
|
||||||
else:
|
|
||||||
output = (output_img * 255.0).round().astype(np.uint8)
|
|
||||||
|
|
||||||
if outscale is not None and outscale != float(self.scale):
|
|
||||||
output = cv2.resize(
|
|
||||||
output, (
|
|
||||||
int(w_input * outscale),
|
|
||||||
int(h_input * outscale),
|
|
||||||
), interpolation=cv2.INTER_LANCZOS4)
|
|
||||||
|
|
||||||
return output, img_mode
|
|
||||||
|
|
||||||
|
|
||||||
class PrefetchReader(threading.Thread):
|
|
||||||
"""Prefetch images.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img_list (list[str]): A image list of image paths to be read.
|
|
||||||
num_prefetch_queue (int): Number of prefetch queue.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, img_list, num_prefetch_queue):
|
|
||||||
super().__init__()
|
|
||||||
self.que = queue.Queue(num_prefetch_queue)
|
|
||||||
self.img_list = img_list
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
for img_path in self.img_list:
|
|
||||||
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
|
||||||
self.que.put(img)
|
|
||||||
|
|
||||||
self.que.put(None)
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
next_item = self.que.get()
|
|
||||||
if next_item is None:
|
|
||||||
raise StopIteration
|
|
||||||
return next_item
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class IOConsumer(threading.Thread):
|
|
||||||
|
|
||||||
def __init__(self, opt, que, qid):
|
|
||||||
super().__init__()
|
|
||||||
self._queue = que
|
|
||||||
self.qid = qid
|
|
||||||
self.opt = opt
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
while True:
|
|
||||||
msg = self._queue.get()
|
|
||||||
if isinstance(msg, str) and msg == 'quit':
|
|
||||||
break
|
|
||||||
|
|
||||||
output = msg['output']
|
|
||||||
save_path = msg['save_path']
|
|
||||||
cv2.imwrite(save_path, output)
|
|
||||||
print(f'IO worker {self.qid} is done.')
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
# GENERATED VERSION FILE
|
|
||||||
# TIME: Thu Dec 23 16:27:38 2021
|
|
||||||
__version__ = '0.2.3.0'
|
|
||||||
__gitsha__ = '3e65d21'
|
|
||||||
version_info = (0, 2, 3, 0)
|
|
||||||
@ -1,3 +0,0 @@
|
|||||||
# Weights
|
|
||||||
|
|
||||||
Put the downloaded weights to this folder.
|
|
||||||
@ -1,19 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : __init__.py.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2023/8/2 15:55 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from .models import *
|
|
||||||
from .utils import *
|
|
||||||
@ -1,221 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.autograd import Function, Variable
|
|
||||||
from torch.nn import Module, parameter
|
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
try:
|
|
||||||
from queue import Queue
|
|
||||||
except ImportError:
|
|
||||||
from Queue import Queue
|
|
||||||
|
|
||||||
# from torch.nn.modules.batchnorm import _BatchNorm
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
|
||||||
from timm.layers import DropPath, trunc_normal_
|
|
||||||
# from timm import register_model
|
|
||||||
# from timm.layers.helpers import to_2tuple
|
|
||||||
# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
||||||
|
|
||||||
|
|
||||||
# LVC
|
|
||||||
class Encoding(nn.Module):
|
|
||||||
def __init__(self, in_channels, num_codes):
|
|
||||||
super(Encoding, self).__init__()
|
|
||||||
# init codewords and smoothing factor
|
|
||||||
self.in_channels, self.num_codes = in_channels, num_codes
|
|
||||||
num_codes = 64
|
|
||||||
std = 1. / ((num_codes * in_channels)**0.5)
|
|
||||||
# [num_codes, channels]
|
|
||||||
self.codewords = nn.Parameter(
|
|
||||||
torch.empty(num_codes, in_channels, dtype=torch.float).uniform_(-std, std), requires_grad=True)
|
|
||||||
# [num_codes]
|
|
||||||
self.scale = nn.Parameter(torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), requires_grad=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def scaled_l2(x, codewords, scale):
|
|
||||||
num_codes, in_channels = codewords.size()
|
|
||||||
b = x.size(0)
|
|
||||||
expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))
|
|
||||||
|
|
||||||
# ---处理codebook (num_code, c1)
|
|
||||||
reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
|
|
||||||
|
|
||||||
# 把scale从1, num_code变成 batch, c2, N, num_codes
|
|
||||||
reshaped_scale = scale.view((1, 1, num_codes)) # N, num_codes
|
|
||||||
|
|
||||||
# ---计算rik = z1 - d # b, N, num_codes
|
|
||||||
scaled_l2_norm = reshaped_scale * (expanded_x - reshaped_codewords).pow(2).sum(dim=3)
|
|
||||||
return scaled_l2_norm
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def aggregate(assignment_weights, x, codewords):
|
|
||||||
num_codes, in_channels = codewords.size()
|
|
||||||
|
|
||||||
# ---处理codebook
|
|
||||||
reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
|
|
||||||
b = x.size(0)
|
|
||||||
|
|
||||||
# ---处理特征向量x b, c1, N
|
|
||||||
expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))
|
|
||||||
|
|
||||||
#变换rei b, N, num_codes,-
|
|
||||||
assignment_weights = assignment_weights.unsqueeze(3) # b, N, num_codes,
|
|
||||||
|
|
||||||
# ---开始计算eik,必须在Rei计算完之后
|
|
||||||
encoded_feat = (assignment_weights * (expanded_x - reshaped_codewords)).sum(1)
|
|
||||||
return encoded_feat
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
assert x.dim() == 4 and x.size(1) == self.in_channels
|
|
||||||
b, in_channels, w, h = x.size()
|
|
||||||
|
|
||||||
# [batch_size, height x width, channels]
|
|
||||||
x = x.view(b, self.in_channels, -1).transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
# assignment_weights: [batch_size, channels, num_codes]
|
|
||||||
assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, self.scale), dim=2)
|
|
||||||
|
|
||||||
# aggregate
|
|
||||||
encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
|
|
||||||
return encoded_feat
|
|
||||||
|
|
||||||
|
|
||||||
# 1*1 3*3 1*1
|
|
||||||
class ConvBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
|
|
||||||
norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
|
|
||||||
super(ConvBlock, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
expansion = 4
|
|
||||||
c = out_channels // expansion
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(in_channels, c, kernel_size=1, stride=1, padding=0, bias=False) # [64, 256, 1, 1]
|
|
||||||
self.bn1 = norm_layer(c)
|
|
||||||
self.act1 = act_layer(inplace=True)
|
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(c, c, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False)
|
|
||||||
self.bn2 = norm_layer(c)
|
|
||||||
self.act2 = act_layer(inplace=True)
|
|
||||||
|
|
||||||
self.conv3 = nn.Conv2d(c, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
|
|
||||||
self.bn3 = norm_layer(out_channels)
|
|
||||||
self.act3 = act_layer(inplace=True)
|
|
||||||
|
|
||||||
if res_conv:
|
|
||||||
self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
|
|
||||||
self.residual_bn = norm_layer(out_channels)
|
|
||||||
|
|
||||||
self.res_conv = res_conv
|
|
||||||
self.drop_block = drop_block
|
|
||||||
self.drop_path = drop_path
|
|
||||||
|
|
||||||
def zero_init_last_bn(self):
|
|
||||||
nn.init.zeros_(self.bn3.weight)
|
|
||||||
|
|
||||||
def forward(self, x, return_x_2=True):
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
if self.drop_block is not None:
|
|
||||||
x = self.drop_block(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
|
|
||||||
x = self.conv2(x) #if x_t_r is None else self.conv2(x + x_t_r)
|
|
||||||
x = self.bn2(x)
|
|
||||||
if self.drop_block is not None:
|
|
||||||
x = self.drop_block(x)
|
|
||||||
x2 = self.act2(x)
|
|
||||||
|
|
||||||
x = self.conv3(x2)
|
|
||||||
x = self.bn3(x)
|
|
||||||
if self.drop_block is not None:
|
|
||||||
x = self.drop_block(x)
|
|
||||||
|
|
||||||
if self.drop_path is not None:
|
|
||||||
x = self.drop_path(x)
|
|
||||||
|
|
||||||
if self.res_conv:
|
|
||||||
residual = self.residual_conv(residual)
|
|
||||||
residual = self.residual_bn(residual)
|
|
||||||
|
|
||||||
x += residual
|
|
||||||
x = self.act3(x)
|
|
||||||
|
|
||||||
if return_x_2:
|
|
||||||
return x, x2
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Mean(nn.Module):
|
|
||||||
def __init__(self, dim, keep_dim=False):
|
|
||||||
super(Mean, self).__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.keep_dim = keep_dim
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return input.mean(self.dim, self.keep_dim)
|
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
|
||||||
"""
|
|
||||||
Implementation of MLP with 1*1 convolutions. Input: tensor with shape [B, C, H, W]
|
|
||||||
"""
|
|
||||||
def __init__(self, in_features, hidden_features=None,
|
|
||||||
out_features=None, act_layer=nn.GELU, drop=0.):
|
|
||||||
super().__init__()
|
|
||||||
out_features = out_features or in_features
|
|
||||||
hidden_features = hidden_features or in_features
|
|
||||||
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
|
|
||||||
self.act = act_layer()
|
|
||||||
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
|
|
||||||
self.drop = nn.Dropout(drop)
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNormChannel(nn.Module):
|
|
||||||
"""
|
|
||||||
LayerNorm only for Channel Dimension.
|
|
||||||
Input: tensor in shape [B, C, H, W]
|
|
||||||
"""
|
|
||||||
def __init__(self, num_channels, eps=1e-05):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
u = x.mean(1, keepdim=True)
|
|
||||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
|
||||||
x = (x - u) / torch.sqrt(s + self.eps)
|
|
||||||
x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
|
|
||||||
+ self.bias.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm(nn.GroupNorm):
|
|
||||||
"""
|
|
||||||
Group Normalization with 1 group.
|
|
||||||
Input: tensor in shape [B, C, H, W]
|
|
||||||
"""
|
|
||||||
def __init__(self, num_channels, **kwargs):
|
|
||||||
super().__init__(1, num_channels, **kwargs)
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
'''
|
|
||||||
@File : __init__.py.py
|
|
||||||
@Contact : zpyovo@hotmail.com
|
|
||||||
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
||||||
@Description :
|
|
||||||
|
|
||||||
@Modify Time @Author @Version @Desciption
|
|
||||||
------------ ------- -------- -----------
|
|
||||||
2023/8/2 15:55 Pengyu Zhang 1.0 None
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
import importlib
|
|
||||||
from basicsr.utils import scandir
|
|
||||||
from os import path as osp
|
|
||||||
|
|
||||||
# automatically scan and import arch modules for registry
|
|
||||||
# scan all the files that end with '_arch.py' under the archs folder
|
|
||||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
|
||||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_modles.py')]
|
|
||||||
# import all the arch modules
|
|
||||||
_arch_modules = [importlib.import_module(f'yolo5x_qr.models.{file_name}') for file_name in arch_filenames]
|
|
||||||
|
|
||||||
@ -1,111 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
Experimental modules
|
|
||||||
"""
|
|
||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from ..utils.downloads import attempt_download
|
|
||||||
|
|
||||||
|
|
||||||
class Sum(nn.Module):
|
|
||||||
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
|
||||||
def __init__(self, n, weight=False): # n: number of inputs
|
|
||||||
super().__init__()
|
|
||||||
self.weight = weight # apply weights boolean
|
|
||||||
self.iter = range(n - 1) # iter object
|
|
||||||
if weight:
|
|
||||||
self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
y = x[0] # no weight
|
|
||||||
if self.weight:
|
|
||||||
w = torch.sigmoid(self.w) * 2
|
|
||||||
for i in self.iter:
|
|
||||||
y = y + x[i + 1] * w[i]
|
|
||||||
else:
|
|
||||||
for i in self.iter:
|
|
||||||
y = y + x[i + 1]
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
class MixConv2d(nn.Module):
|
|
||||||
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
|
|
||||||
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
|
|
||||||
super().__init__()
|
|
||||||
n = len(k) # number of convolutions
|
|
||||||
if equal_ch: # equal c_ per group
|
|
||||||
i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
|
|
||||||
c_ = [(i == g).sum() for g in range(n)] # intermediate channels
|
|
||||||
else: # equal weight.numel() per group
|
|
||||||
b = [c2] + [0] * n
|
|
||||||
a = np.eye(n + 1, n, k=-1)
|
|
||||||
a -= np.roll(a, 1, axis=1)
|
|
||||||
a *= np.array(k) ** 2
|
|
||||||
a[0] = 1
|
|
||||||
c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
|
|
||||||
|
|
||||||
self.m = nn.ModuleList([
|
|
||||||
nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
|
|
||||||
self.bn = nn.BatchNorm2d(c2)
|
|
||||||
self.act = nn.SiLU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.ModuleList):
|
|
||||||
# Ensemble of models
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x, augment=False, profile=False, visualize=False):
|
|
||||||
y = [module(x, augment, profile, visualize)[0] for module in self]
|
|
||||||
# y = torch.stack(y).max(0)[0] # max ensemble
|
|
||||||
# y = torch.stack(y).mean(0) # mean ensemble
|
|
||||||
y = torch.cat(y, 1) # nms ensemble
|
|
||||||
return y, None # inference, train output
|
|
||||||
|
|
||||||
|
|
||||||
def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|
||||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
|
||||||
from ..models.yolo import Detect, Model
|
|
||||||
|
|
||||||
model = Ensemble()
|
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
|
||||||
|
|
||||||
# Model compatibility updates
|
|
||||||
if not hasattr(ckpt, 'stride'):
|
|
||||||
ckpt.stride = torch.tensor([32.])
|
|
||||||
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
|
||||||
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
|
||||||
|
|
||||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
|
||||||
|
|
||||||
# Module compatibility updates
|
|
||||||
for m in model.modules():
|
|
||||||
t = type(m)
|
|
||||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
|
||||||
m.inplace = inplace # torch 1.7.0 compatibility
|
|
||||||
if t is Detect and not isinstance(m.anchor_grid, list):
|
|
||||||
delattr(m, 'anchor_grid')
|
|
||||||
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
|
||||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
|
||||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
|
||||||
|
|
||||||
# Return model
|
|
||||||
if len(model) == 1:
|
|
||||||
return model[-1]
|
|
||||||
|
|
||||||
# Return detection ensemble
|
|
||||||
print(f'Ensemble created with {weights}\n')
|
|
||||||
for k in 'names', 'nc', 'yaml':
|
|
||||||
setattr(model, k, getattr(model[0], k))
|
|
||||||
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
|
||||||
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
|
||||||
return model
|
|
||||||
@ -1,396 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
YOLO-specific modules
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
$ python models/yolo.py --cfg yolov5s.yaml
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import contextlib
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import sys
|
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
FILE = Path(__file__).resolve()
|
|
||||||
ROOT = FILE.parents[1] # YOLOv5 root directory
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT)) # add ROOT to PATH
|
|
||||||
if platform.system() != 'Windows':
|
|
||||||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
|
||||||
|
|
||||||
from thirdTool.yolo5x_qr.models.common import *
|
|
||||||
from thirdTool.yolo5x_qr.models.experimental import *
|
|
||||||
from thirdTool.yolo5x_qr.utils.autoanchor import check_anchor_order
|
|
||||||
from thirdTool.yolo5x_qr.utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
|
|
||||||
from thirdTool.yolo5x_qr.utils.plots import feature_visualization
|
|
||||||
from thirdTool.yolo5x_qr.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
|
|
||||||
time_sync)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import thop # for FLOPs computation
|
|
||||||
except ImportError:
|
|
||||||
thop = None
|
|
||||||
|
|
||||||
|
|
||||||
class Detect(nn.Module):
|
|
||||||
# YOLOv5 Detect head for detection models
|
|
||||||
stride = None # strides computed during build
|
|
||||||
dynamic = False # force grid reconstruction
|
|
||||||
export = False # export mode
|
|
||||||
|
|
||||||
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
|
||||||
super().__init__()
|
|
||||||
self.nc = nc # number of classes
|
|
||||||
self.no = nc + 5 # number of outputs per anchor
|
|
||||||
self.nl = len(anchors) # number of detection layers
|
|
||||||
self.na = len(anchors[0]) // 2 # number of anchors
|
|
||||||
self.grid = [torch.empty(0) for _ in range(self.nl)] # init grid
|
|
||||||
self.anchor_grid = [torch.empty(0) for _ in range(self.nl)] # init anchor grid
|
|
||||||
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
|
||||||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
|
||||||
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
z = [] # inference output
|
|
||||||
for i in range(self.nl):
|
|
||||||
x[i] = self.m[i](x[i]) # conv
|
|
||||||
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
|
|
||||||
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
|
||||||
|
|
||||||
if not self.training: # inference
|
|
||||||
if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
|
||||||
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
|
||||||
|
|
||||||
if isinstance(self, Segment): # (boxes + masks)
|
|
||||||
xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
|
|
||||||
xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i] # xy
|
|
||||||
wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i] # wh
|
|
||||||
y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
|
|
||||||
else: # Detect (boxes only)
|
|
||||||
xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
|
|
||||||
xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
|
|
||||||
wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
|
|
||||||
y = torch.cat((xy, wh, conf), 4)
|
|
||||||
z.append(y.view(bs, self.na * nx * ny, self.no))
|
|
||||||
|
|
||||||
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
|
|
||||||
|
|
||||||
def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
|
|
||||||
d = self.anchors[i].device
|
|
||||||
t = self.anchors[i].dtype
|
|
||||||
shape = 1, self.na, ny, nx, 2 # grid shape
|
|
||||||
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
|
||||||
yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x) # torch>=0.7 compatibility
|
|
||||||
grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
|
|
||||||
anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
|
|
||||||
return grid, anchor_grid
|
|
||||||
|
|
||||||
|
|
||||||
class Segment(Detect):
|
|
||||||
# YOLOv5 Segment head for segmentation models
|
|
||||||
def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
|
|
||||||
super().__init__(nc, anchors, ch, inplace)
|
|
||||||
self.nm = nm # number of masks
|
|
||||||
self.npr = npr # number of protos
|
|
||||||
self.no = 5 + nc + self.nm # number of outputs per anchor
|
|
||||||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
|
||||||
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
|
||||||
self.detect = Detect.forward
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
p = self.proto(x[0])
|
|
||||||
x = self.detect(self, x)
|
|
||||||
return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(nn.Module):
|
|
||||||
# YOLOv5 base model
|
|
||||||
def forward(self, x, profile=False, visualize=False):
|
|
||||||
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
|
||||||
|
|
||||||
def _forward_once(self, x, profile=False, visualize=False):
|
|
||||||
y, dt = [], [] # outputs
|
|
||||||
for m in self.model:
|
|
||||||
if m.f != -1: # if not from previous layer
|
|
||||||
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
|
||||||
if profile:
|
|
||||||
self._profile_one_layer(m, x, dt)
|
|
||||||
x = m(x) # run
|
|
||||||
y.append(x if m.i in self.save else None) # save output
|
|
||||||
if visualize:
|
|
||||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _profile_one_layer(self, m, x, dt):
|
|
||||||
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
|
||||||
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
|
||||||
t = time_sync()
|
|
||||||
for _ in range(10):
|
|
||||||
m(x.copy() if c else x)
|
|
||||||
dt.append((time_sync() - t) * 100)
|
|
||||||
if m == self.model[0]:
|
|
||||||
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
|
||||||
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
|
||||||
if c:
|
|
||||||
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
|
||||||
|
|
||||||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
|
||||||
LOGGER.info('Fusing layers... ')
|
|
||||||
for m in self.model.modules():
|
|
||||||
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
|
||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
|
||||||
delattr(m, 'bn') # remove batchnorm
|
|
||||||
m.forward = m.forward_fuse # update forward
|
|
||||||
self.info()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def info(self, verbose=False, img_size=640): # print model information
|
|
||||||
model_info(self, verbose, img_size)
|
|
||||||
|
|
||||||
def _apply(self, fn):
|
|
||||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
|
||||||
self = super()._apply(fn)
|
|
||||||
m = self.model[-1] # Detect()
|
|
||||||
if isinstance(m, (Detect, Segment)):
|
|
||||||
m.stride = fn(m.stride)
|
|
||||||
m.grid = list(map(fn, m.grid))
|
|
||||||
if isinstance(m.anchor_grid, list):
|
|
||||||
m.anchor_grid = list(map(fn, m.anchor_grid))
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(BaseModel):
|
|
||||||
# YOLOv5 detection model
|
|
||||||
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
|
||||||
super().__init__()
|
|
||||||
if isinstance(cfg, dict):
|
|
||||||
self.yaml = cfg # model dict
|
|
||||||
else: # is *.yaml
|
|
||||||
import yaml # for torch hub
|
|
||||||
self.yaml_file = Path(cfg).name
|
|
||||||
with open(cfg, encoding='ascii', errors='ignore') as f:
|
|
||||||
self.yaml = yaml.safe_load(f) # model dict
|
|
||||||
|
|
||||||
# Define model
|
|
||||||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
|
||||||
if nc and nc != self.yaml['nc']:
|
|
||||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
|
||||||
self.yaml['nc'] = nc # override yaml value
|
|
||||||
if anchors:
|
|
||||||
LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
|
|
||||||
self.yaml['anchors'] = round(anchors) # override yaml value
|
|
||||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
|
||||||
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
|
||||||
self.inplace = self.yaml.get('inplace', True)
|
|
||||||
|
|
||||||
# Build strides, anchors
|
|
||||||
m = self.model[-1] # Detect()
|
|
||||||
if isinstance(m, (Detect, Segment)):
|
|
||||||
s = 256 # 2x min stride
|
|
||||||
m.inplace = self.inplace
|
|
||||||
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
|
|
||||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
|
||||||
check_anchor_order(m)
|
|
||||||
m.anchors /= m.stride.view(-1, 1, 1)
|
|
||||||
self.stride = m.stride
|
|
||||||
self._initialize_biases() # only run once
|
|
||||||
|
|
||||||
# Init weights, biases
|
|
||||||
initialize_weights(self)
|
|
||||||
self.info()
|
|
||||||
LOGGER.info('')
|
|
||||||
|
|
||||||
def forward(self, x, augment=False, profile=False, visualize=False):
|
|
||||||
if augment:
|
|
||||||
return self._forward_augment(x) # augmented inference, None
|
|
||||||
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
|
||||||
|
|
||||||
def _forward_augment(self, x):
|
|
||||||
img_size = x.shape[-2:] # height, width
|
|
||||||
s = [1, 0.83, 0.67] # scales
|
|
||||||
f = [None, 3, None] # flips (2-ud, 3-lr)
|
|
||||||
y = [] # outputs
|
|
||||||
for si, fi in zip(s, f):
|
|
||||||
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
|
||||||
yi = self._forward_once(xi)[0] # forward
|
|
||||||
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
|
||||||
yi = self._descale_pred(yi, fi, si, img_size)
|
|
||||||
y.append(yi)
|
|
||||||
y = self._clip_augmented(y) # clip augmented tails
|
|
||||||
return torch.cat(y, 1), None # augmented inference, train
|
|
||||||
|
|
||||||
def _descale_pred(self, p, flips, scale, img_size):
|
|
||||||
# de-scale predictions following augmented inference (inverse operation)
|
|
||||||
if self.inplace:
|
|
||||||
p[..., :4] /= scale # de-scale
|
|
||||||
if flips == 2:
|
|
||||||
p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
|
|
||||||
elif flips == 3:
|
|
||||||
p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
|
|
||||||
else:
|
|
||||||
x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
|
|
||||||
if flips == 2:
|
|
||||||
y = img_size[0] - y # de-flip ud
|
|
||||||
elif flips == 3:
|
|
||||||
x = img_size[1] - x # de-flip lr
|
|
||||||
p = torch.cat((x, y, wh, p[..., 4:]), -1)
|
|
||||||
return p
|
|
||||||
|
|
||||||
def _clip_augmented(self, y):
    """Trim redundant grid cells from the first (largest-scale) and last
    (smallest-scale) augmented outputs so concatenated TTA predictions
    are not double-counted."""
    # Clip YOLOv5 augmented inference tails
    nl = self.model[-1].nl  # number of detection layers (P3-P5)
    g = sum(4 ** x for x in range(nl))  # grid points
    e = 1  # exclude layer count
    i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices
    y[0] = y[0][:, :-i]  # large
    i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
    y[-1] = y[-1][:, i:]  # small
    return y
|
|
||||||
|
|
||||||
def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
    """Initialize Detect() conv biases per the focal-loss prior trick.

    Args:
        cf: optional per-class frequency tensor; when None a uniform
            ~0.6/nc class prior is assumed.
    """
    # https://arxiv.org/abs/1708.02002 section 3.3
    # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
    m = self.model[-1]  # Detect() module
    for mi, s in zip(m.m, m.stride):  # from
        b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
        b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
        b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls
        mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
|
||||||
|
|
||||||
|
|
||||||
# Backwards-compatible alias: older code imports `Model` from this module.
Model = DetectionModel  # retain YOLOv5 'Model' class for backwards compatibility
|
|
||||||
|
|
||||||
|
|
||||||
class SegmentationModel(DetectionModel):
    # YOLOv5 segmentation model
    def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, anchors=None):
        """Identical to DetectionModel.__init__; only the default cfg differs."""
        super().__init__(cfg, ch, nc, anchors)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationModel(BaseModel):
    # YOLOv5 classification model
    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):  # yaml, model, number of classes, cutoff index
        """Build from an existing detection model (if given) or from a yaml cfg."""
        super().__init__()
        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        # Create a YOLOv5 classification model from a YOLOv5 detection model
        if isinstance(model, DetectMultiBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:cutoff]  # backbone: keep only the first `cutoff` layers
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
        model.model[-1] = c  # replace the detection head with a classification head
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg):
        # Create a YOLOv5 classification model from a *.yaml file
        # NOTE(review): not implemented — leaves self.model as None.
        self.model = None
|
|
||||||
|
|
||||||
|
|
||||||
def parse_model(d, ch):  # model_dict, input_channels(3)
    """Build an nn.Sequential from a YOLOv5 model.yaml dictionary.

    Args:
        d: parsed model yaml (anchors, nc, depth/width multiples, backbone, head).
        ch: list of input channel counts; grows to one entry per constructed layer.

    Returns:
        (nn.Sequential model, sorted list of layer indices whose outputs must be saved).
    """
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        # NOTE(review): eval() on the yaml 'activation' string — only load trusted configs.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, EVCBlock, ODConv_3rd, ConvNextBlock}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)  # apply width multiple, rounded to a multiple of 8

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x, EVCBlock}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        elif m is SOCA:
            c1, c2 = ch[f], args[0]
            if c2 != no:
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, *args[1:]]
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params (NOTE: shadows any numpy alias locally)
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # discard the seed input-channel entry after the first layer
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # CLI for building/profiling a model from a yaml config.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
    parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    print_args(vars(opt))
    device = select_device(opt.device)

    # Create model (random input used only for profiling paths)
    im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
    model = Model(opt.cfg).to(device)

    # Options
    if opt.line_profile:  # profile layer by layer
        model(im, profile=True)

    elif opt.profile:  # profile forward-backward
        results = profile(input=im, ops=[model], n=3)

    elif opt.test:  # test all models
        for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
            try:
                _ = Model(cfg)
            except Exception as e:
                print(f'Error in {cfg}: {e}')

    else:  # report fused model summary
        model.fuse()
|
|
||||||
@ -1,151 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
|
||||||
|
|
||||||
Usage - sources:
|
|
||||||
$ python detect.py --weights yolov5s.pt --source 0 # webcam
|
|
||||||
img.jpg # image
|
|
||||||
vid.mp4 # video
|
|
||||||
screen # screenshot
|
|
||||||
path/ # directory
|
|
||||||
list.txt # list of images
|
|
||||||
list.streams # list of streams
|
|
||||||
'path/*.jpg' # glob
|
|
||||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
|
||||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
|
||||||
|
|
||||||
Usage - formats:
|
|
||||||
$ python detect.py --weights yolov5s.pt # PyTorch
|
|
||||||
yolov5s.torchscript # TorchScript
|
|
||||||
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
|
||||||
yolov5s_openvino_model # OpenVINO
|
|
||||||
yolov5s.engine # TensorRT
|
|
||||||
yolov5s.mlmodel # CoreML (macOS-only)
|
|
||||||
yolov5s_saved_model # TensorFlow SavedModel
|
|
||||||
yolov5s.pb # TensorFlow GraphDef
|
|
||||||
yolov5s.tflite # TensorFlow Lite
|
|
||||||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
|
||||||
yolov5s_paddle_model # PaddlePaddle
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
FILE = Path(__file__).resolve()
|
|
||||||
ROOT = FILE.parents[0] # YOLOv5 root directory
|
|
||||||
if str(ROOT) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT)) # add ROOT to PATH
|
|
||||||
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
|
||||||
|
|
||||||
from models.common import DetectMultiBackend
|
|
||||||
from utils.general import (LOGGER, Profile, check_img_size, cv2,
|
|
||||||
non_max_suppression, scale_boxes)
|
|
||||||
from utils.augmentations import (letterbox)
|
|
||||||
from utils.torch_utils import select_device, smart_inference_mode
|
|
||||||
|
|
||||||
|
|
||||||
@smart_inference_mode()
def run(
        model='',  # model path or triton URL (in practice: a loaded DetectMultiBackend)
        imagePath='',  # file/dir/URL/glob/screen/0(webcam)
        imgsz=(512, 512),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1,  # maximum detections per image
        save_crop=True,  # save cropped prediction boxes
        save_dir='data/QR2023_roi/images/',  # directory crops are written to
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
):
    """Detect objects in one image and save a crop of the top-left quadrant
    of each detected box under *save_dir* (named after the source image).

    NOTE(review): *model* must already be a constructed backend (it is used as
    an object, not a path, despite the default) — confirm with callers.
    """
    if save_crop:
        # Directories
        Path(save_dir).mkdir(parents=True, exist_ok=True)

    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)  # check image size
    im0 = cv2.imread(imagePath)  # BGR
    im = letterbox(im0, imgsz, stride=32, auto=True)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous
    p = imagePath

    bs = 1  # batch size
    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    dt = (Profile(), Profile(), Profile())  # pre-process / inference / NMS timers

    with dt[0]:
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim

    # Inference
    with dt[1]:
        pred = model(im, augment=augment, visualize=False)

    # NMS
    with dt[2]:
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

    # Process predictions
    for i, det in enumerate(pred):  # per image
        p = Path(p)
        save_path = os.path.join(save_dir, p.name)

        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

            # Write results
            for *xyxy, conf, cls in reversed(det):
                x_min, y_min, x_max, y_max = (int(v) for v in xyxy[:4])
                # Half the box extent in each direction -> the top-left quadrant.
                # (Previously misnamed quarter_width/quarter_height.)
                half_width = (x_max - x_min) // 2
                half_height = (y_max - y_min) // 2

                # Save results (cropped quadrant anchored at the box's top-left corner)
                if save_crop:
                    # Slice instead of converting the whole frame to PIL each
                    # iteration (the old code also clobbered im0 with a PIL image).
                    crop = im0[y_min:y_min + half_height, x_min:x_min + half_width]
                    # NOTE(review): im0 is BGR; saving without a channel swap
                    # preserves the original behavior — confirm downstream use.
                    Image.fromarray(np.uint8(crop)).save(save_path)

        # Print time (inference-only)
        LOGGER.info(f"{p}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # check_requirements(exclude=('tensorboard', 'thop'))
    # Dataset manifest: one absolute image path per line.
    meta_info = '/project/dataset/QR2023/terminal-box/meta_info_big_box_terminal.txt'
    with open(meta_info) as fin:
        paths = [line.strip() for line in fin]
    data_len = len(paths)

    # Load model (hard-coded weights path and CUDA device index)
    weights = '/project/yolov5-qr/runs/train_QR/exp10/weights/qr_roi_cloud_detect_20230831.pt'
    device = '4'
    device = select_device(device)
    qrbox_model = DetectMultiBackend(weights, device=device)

    # Run detection + crop-saving over every listed image.
    for index in range(0, data_len):
        imagePath = paths[index]
        run(model=qrbox_model, imagePath=imagePath)
|
|
||||||
@ -1,82 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
utils/initialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import platform
|
|
||||||
import threading
|
|
||||||
|
|
||||||
|
|
||||||
def emojis(str=''):
    """Return a platform-safe version of *str*: non-ASCII (emoji) characters
    are stripped on Windows, which often cannot render them."""
    if platform.system() == 'Windows':
        return str.encode().decode('ascii', 'ignore')
    return str
|
|
||||||
|
|
||||||
|
|
||||||
class TryExcept(contextlib.ContextDecorator):
    """YOLOv5 TryExcept: suppress any exception, printing a one-line report.

    Usage: ``@TryExcept()`` decorator or ``with TryExcept():`` context manager.
    """

    def __init__(self, msg=''):
        self.msg = msg  # optional prefix for the printed report

    def __enter__(self):
        pass

    def __exit__(self, exc_type, value, traceback):
        # Always return True so the exception (if any) is swallowed.
        if value:
            separator = ': ' if self.msg else ''
            print(emojis(f'{self.msg}{separator}{value}'))
        return True
|
|
||||||
|
|
||||||
|
|
||||||
def threaded(func):
    """Decorator: run *func* in a daemon thread, returning the Thread object.

    Usage: ``@threaded``; the caller may ``.join()`` the returned thread.
    """

    def wrapper(*args, **kwargs):
        worker = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        worker.start()
        return worker

    return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def join_threads(verbose=False):
    """Join all non-current threads, e.g. atexit.register(lambda: join_threads())."""
    current = threading.current_thread()
    for thread in threading.enumerate():
        if thread is current:
            continue
        if verbose:
            print(f'Joining thread {thread.name}')
        thread.join()
|
|
||||||
|
|
||||||
|
|
||||||
def notebook_init(verbose=True):
    """Check system software and hardware for notebook environments.

    Returns the IPython ``display`` module when available (so callers can
    clear output later), else None.
    """
    print('Checking setup...')

    import os
    import shutil

    from ..utils.general import check_font, is_colab
    from ..utils.torch_utils import select_device  # imports

    check_font()

    if is_colab():
        shutil.rmtree('/content/sample_data', ignore_errors=True)  # remove colab /sample_data directory

    # System info
    display = None
    if verbose:
        gb = 1 << 30  # bytes to GiB (1024 ** 3)
        # BUG FIX: `import psutil` was commented out, so the reference below
        # raised NameError whenever verbose=True. Import locally and degrade
        # gracefully when the package is absent.
        try:
            import psutil
            ram = psutil.virtual_memory().total
        except ImportError:
            ram = 0  # psutil unavailable; report 0 GB rather than crash
        total, used, free = shutil.disk_usage('/')
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display
            display.clear_output()
        s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)'
    else:
        s = ''

    select_device(newline=False)
    print(emojis(f'Setup complete ✅ {s}'))
    return display
|
|
||||||
@ -1,397 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
Image augmentation functions
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import torchvision.transforms.functional as TF
|
|
||||||
|
|
||||||
from ..utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box, xywhn2xyxy
|
|
||||||
from ..utils.metrics import bbox_ioa
|
|
||||||
|
|
||||||
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
|
||||||
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
|
||||||
|
|
||||||
|
|
||||||
class Albumentations:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self, size=640):
        """Build the optional albumentations pipeline; stays a no-op (transform=None)
        if the package is missing or construction fails."""
        self.transform = None
        prefix = colorstr('albumentations: ')
        try:
            import albumentations as A
            check_version(A.__version__, '1.0.3', hard=True)  # version requirement

            # NOTE: this local 'T' shadows the module-level torchvision.transforms alias.
            T = [
                A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(p=0.0),
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0)]  # transforms
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

            LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
        except ImportError:  # package not installed, skip
            pass
        except Exception as e:
            LOGGER.info(f'{prefix}{e}')

    def __call__(self, im, labels, p=1.0):
        """Apply the pipeline with probability *p*; labels are rows of
        [cls, x, y, w, h] in yolo (normalized) format."""
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
        return im, labels
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
    # Normalize RGB images x per ImageNet stats in BCHW format, i.e. x = (x - mean) / std
    # (original comment said "Denormalize", which contradicted the formula and the call)
    return TF.normalize(x, mean, std, inplace=inplace)
|
|
||||||
|
|
||||||
|
|
||||||
def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo ImageNet normalization on a BCHW batch: x = x * std + mean (per channel)."""
    for channel, (m, s) in enumerate(zip(mean, std)):
        x[:, channel] = x[:, channel] * s + m
    return x
|
|
||||||
|
|
||||||
|
|
||||||
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
    """Random HSV augmentation, applied to *im* in place via cv2 LUTs.

    Gains of 0 disable the corresponding channel's jitter entirely.
    """
    # HSV color-space augmentation
    if hgain or sgain or vgain:
        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
        hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
        dtype = im.dtype  # uint8

        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)  # OpenCV 8-bit hue wraps at 180
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed
|
|
||||||
|
|
||||||
|
|
||||||
def hist_equalize(im, clahe=True, bgr=False):
    """Equalize the luma (Y) histogram of a 3-channel 0-255 image.

    Args:
        clahe: use contrast-limited adaptive equalization instead of global.
        bgr: True if *im* is BGR, else RGB (controls the YUV conversions).
    """
    # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB
|
|
||||||
|
|
||||||
|
|
||||||
def replicate(im, labels):
    """Duplicate the smaller half of the labeled boxes at random locations,
    pasting pixels into *im* in place and appending matching labels."""
    # Replicate labels
    h, w = im.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # im4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return im, labels
|
|
||||||
|
|
||||||
|
|
||||||
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resize *im* to fit *new_shape* preserving aspect ratio, padding the rest.

    Returns:
        (padded image, (w_ratio, h_ratio), (dw, dh)) where dw/dh are the
        half-paddings applied to each side.
    """
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    # +/-0.1 rounding splits an odd padding pixel between the two sides.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)
|
|
||||||
|
|
||||||
|
|
||||||
def random_perspective(im,
                       targets=(),
                       segments=(),
                       degrees=10,
                       translate=.1,
                       scale=.1,
                       shear=10,
                       perspective=0.0,
                       border=(0, 0)):
    """Random affine/perspective warp of image and labels.

    Builds a combined 3x3 transform (center, perspective, rotation+scale,
    shear, translation), warps the image, then maps targets (and optional
    segments) through the same transform and filters degenerate boxes.
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = im.shape[0] + border[0] * 2  # shape(h,w,c)
    width = im.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -im.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -im.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(im[:, :, ::-1])  # base
    # ax[1].imshow(im2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments) and len(segments) == n
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))  # homogeneous coordinates
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes from the axis-aligned hull of the warped corners
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return im, targets
|
|
||||||
|
|
||||||
|
|
||||||
def copy_paste(im, labels, segments, p=0.5):
    """Copy-Paste augmentation: mirror a random subset of segment instances
    left-right and paste them where they don't obscure existing labels."""
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    if p and n:
        h, w, c = im.shape  # height, width, channels
        im_new = np.zeros(im.shape, np.uint8)
        for j in random.sample(range(n), k=round(p * n)):
            l, s = labels[j], segments[j]
            box = w - l[3], l[2], w - l[1], l[4]  # left-right mirrored box
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[l[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)

        result = cv2.flip(im, 1)  # augment segments (flip left-right)
        i = cv2.flip(im_new, 1).astype(bool)
        im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug

    return im, labels, segments
|
|
||||||
|
|
||||||
|
|
||||||
def cutout(im, labels, p=0.5):
    """Cutout augmentation: paint random rectangles over *im* (in place) and
    drop labels that become mostly obscured; returns surviving labels."""
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    if random.random() < p:
        h, w = im.shape[:2]
        scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
        for s in scales:
            mask_h = random.randint(1, int(h * s))  # create random masks
            mask_w = random.randint(1, int(w * s))

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # apply random color mask
            im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

            # return unobscured labels
            if len(labels) and s > 0.03:
                box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
                ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))  # intersection over area
                labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
|
|
||||||
|
|
||||||
|
|
||||||
def mixup(im, labels, im2, labels2):
    """MixUp augmentation (https://arxiv.org/pdf/1710.09412.pdf): blend two
    images with a Beta-sampled ratio and concatenate their labels."""
    ratio = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
    blended = im * ratio + im2 * (1 - ratio)
    im = blended.astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
    return im, labels
|
|
||||||
|
|
||||||
|
|
||||||
def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    """Vectorized box filter for augmentation: box1 = before, box2 = after.

    Keeps boxes that remain wider/taller than *wh_thr* pixels, retain more
    than *area_thr* of their original area, and have aspect ratio < *ar_thr*.
    """
    w1 = box1[2] - box1[0]
    h1 = box1[3] - box1[1]
    w2 = box2[2] - box2[0]
    h2 = box2[3] - box2[1]
    aspect = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    big_enough = (w2 > wh_thr) & (h2 > wh_thr)
    area_kept = w2 * h2 / (w1 * h1 + eps) > area_thr
    return big_enough & area_kept & (aspect < ar_thr)  # candidates
|
|
||||||
|
|
||||||
|
|
||||||
def classify_albumentations(
        augment=True,
        size=224,
        scale=(0.08, 1.0),
        ratio=(0.75, 1.0 / 0.75),  # 0.75, 1.33
        hflip=0.5,
        vflip=0.0,
        jitter=0.4,
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
        auto_aug=False):
    """Build an albumentations classification pipeline.

    Returns an A.Compose, or None if the package is missing or building fails
    (both failure paths only log, matching the detection Albumentations class).
    """
    # YOLOv5 classification Albumentations (optional, only used if package is installed)
    prefix = colorstr('albumentations: ')
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
        check_version(A.__version__, '1.0.3', hard=True)  # version requirement
        if augment:  # Resize and crop
            T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
            if auto_aug:
                # TODO: implement AugMix, AutoAug & RandAug in albumentation
                LOGGER.info(f'{prefix}auto augmentations are currently not supported')
            else:
                if hflip > 0:
                    T += [A.HorizontalFlip(p=hflip)]
                if vflip > 0:
                    T += [A.VerticalFlip(p=vflip)]
                if jitter > 0:
                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
                    T += [A.ColorJitter(*color_jitter, 0)]
        else:  # Use fixed crop for eval set (reproducibility)
            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
        LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
        return A.Compose(T)

    except ImportError:  # package not installed, skip
        LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
    except Exception as e:
        LOGGER.info(f'{prefix}{e}')
|
|
||||||
|
|
||||||
|
|
||||||
def classify_transforms(size=224):
    # Transforms to apply if albumentations not installed
    """Return a minimal eval preprocessing pipeline: center crop -> tensor -> ImageNet normalize.

    size must be a square int edge length (asserted below).
    NOTE(review): `T` is presumably torchvision.transforms imported at module level — confirm.
    """
    assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
    # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
|
||||||
|
|
||||||
|
|
||||||
class LetterBox:
    # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
    """Resize an image preserving aspect ratio and pad it (value 114) to the target canvas."""

    def __init__(self, size=(640, 640), auto=False, stride=32):
        """size: int or (h, w) target; auto: round the canvas to the next stride multiple
        of the resized image instead of the fixed size; stride: used only with auto."""
        super().__init__()
        self.h, self.w = (size, size) if isinstance(size, int) else size
        self.auto = auto  # pass max size integer, automatically solve for short side using stride
        self.stride = stride  # used with auto

    def __call__(self, im):  # im = np.array HWC
        imh, imw = im.shape[:2]
        r = min(self.h / imh, self.w / imw)  # ratio of new/old
        h, w = round(imh * r), round(imw * r)  # resized image
        # Bug fix: the original one-liner `(genexpr) if self.auto else self.h, self.w`
        # parsed as `((genexpr if auto else self.h), self.w)`, binding the generator
        # object itself to hs when auto=True and raising TypeError below.
        if self.auto:
            hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w))
        else:
            hs, ws = self.h, self.w
        top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
        # Canvas uses the (possibly stride-rounded) dims; identical to before when auto=False.
        im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
        im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
        return im_out
|
|
||||||
|
|
||||||
|
|
||||||
class CenterCrop:
    # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
    """Crop the largest centered square from an image and resize it to the target size."""

    def __init__(self, size=640):
        super().__init__()
        if isinstance(size, int):
            self.h = self.w = size
        else:
            self.h, self.w = size

    def __call__(self, im):  # im = np.array HWC
        imh, imw = im.shape[:2]
        side = min(imh, imw)  # edge of the largest centered square
        y0 = (imh - side) // 2
        x0 = (imw - side) // 2
        square = im[y0:y0 + side, x0:x0 + side]
        return cv2.resize(square, (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
|
|
||||||
class ToTensor:
    # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
    """Convert a BGR HWC uint8 numpy image to a normalized RGB CHW float torch tensor."""

    def __init__(self, half=False):
        super().__init__()
        self.half = half  # emit fp16 instead of fp32

    def __call__(self, im):  # im = np.array HWC in BGR order
        chw_rgb = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])  # HWC->CHW, BGR->RGB, contiguous copy
        tensor = torch.from_numpy(chw_rgb)
        tensor = tensor.half() if self.half else tensor.float()  # uint8 -> fp16/32
        return tensor / 255.0  # scale 0-255 to 0.0-1.0
|
|
||||||
@ -1,169 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
AutoAnchor utils
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from ..utils import TryExcept
|
|
||||||
from ..utils.general import LOGGER, TQDM_BAR_FORMAT, colorstr
|
|
||||||
|
|
||||||
PREFIX = colorstr('AutoAnchor: ')
|
|
||||||
|
|
||||||
|
|
||||||
def check_anchor_order(m):
    """Ensure anchor areas in Detect() module *m* grow in the same direction as its strides.

    Flips m.anchors in place along the layer axis when the ordering disagrees.
    """
    mean_area = m.anchors.prod(-1).mean(-1).view(-1)  # mean anchor area per output layer
    area_delta = mean_area[-1] - mean_area[0]
    stride_delta = m.stride[-1] - m.stride[0]
    # A sign mismatch means anchors run small->large while strides run large->small (or vice versa).
    if area_delta and area_delta.sign() != stride_delta.sign():
        LOGGER.info(f'{PREFIX}Reversing anchor order')
        m.anchors[:] = m.anchors.flip(0)
|
|
||||||
|
|
||||||
|
|
||||||
@TryExcept(f'{PREFIX}ERROR')
def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    """Check how well the model's anchors fit *dataset*; evolve and install new ones if BPR <= 0.98.

    dataset: object exposing .shapes (n,2 image sizes) and .labels (per-image label arrays).
    model: model whose last module is Detect() holding .anchors and .stride.
    thr: anchor/label wh-ratio threshold (hyp['anchor_t']).
    imgsz: training image size used to scale label wh to pixels.
    Mutates model anchors in place on success; returns None.
    """
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)  # letterboxed image shapes
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    # Label wh in augmented pixel space (labels[:, 3:5] are normalized wh).
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        # k: (n, 2) anchor wh in pixels; returns (best-possible-recall, anchors-above-threshold).
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1 / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1 / thr).float().mean()  # best possible recall
        return bpr, aat

    stride = m.stride.to(m.anchors.device).view(-1, 1, 1)  # model strides
    anchors = m.anchors.clone() * stride  # current anchors in pixel space
    bpr, aat = metric(anchors.cpu().view(-1, 2))
    s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
    if bpr > 0.98:  # threshold to recompute
        LOGGER.info(f'{s}Current anchors are a good fit to dataset ✅')
    else:
        LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
        na = m.anchors.numel() // 2  # number of anchors
        anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors only when the evolved set is strictly better
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchors[:] = anchors.clone().view_as(m.anchors)
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= stride  # back to grid space for the Detect() module
            s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
        else:
            s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
        LOGGER.info(s)
|
|
||||||
|
|
||||||
|
|
||||||
def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            dataset: path to data.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm
            verbose: print all results

        Return:
            k: kmeans evolved anchors, (n, 2) float32 array of pixel wh pairs

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()
    """
    from scipy.cluster.vq import kmeans

    npr = np.random
    thr = 1 / thr  # invert once: metrics below compare against the reciprocal threshold

    def metric(k, wh):  # compute metrics
        # Ratio metric between label wh and anchor wh; returns (all ratios, best per label).
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        # Mean best ratio, zeroing labels whose best anchor falls below threshold.
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k, verbose=True):
        # Sort anchors by area, log recall stats, and return the sorted anchors.
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
            f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
            f'past_thr={x[x > thr].mean():.3f}-mean: '
        for x in k:
            s += '%i,%i, ' % (round(x[0]), round(x[1]))
        if verbose:
            LOGGER.info(s[:-2])
        return k

    if isinstance(dataset, str):  # *.yaml file: load a dataset from its train split
        with open(dataset, errors='ignore') as f:
            data_dict = yaml.safe_load(f)  # model dict
        from utils.dataloaders import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)

    # Get label wh in pixels at training resolution
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter degenerate labels
    i = (wh0 < 3.0).any(1).sum()
    if i:
        LOGGER.info(f'{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size')
    wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32)  # filter > 2 pixels
    # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans init (whitened); fall back to random anchors when kmeans cannot deliver n points
    try:
        LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
        assert n <= len(wh)  # apply overdetermined constraint
        s = wh.std(0)  # sigmas for whitening
        k = kmeans(wh / s, n, iter=30)[0] * s  # points
        assert n == len(k)  # kmeans may return fewer points than requested if wh is insufficient or too similar
    except Exception:
        LOGGER.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init')
        k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size  # random init
    wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
    k = print_results(k, verbose=False)

    # Plot (left as reference; disabled)
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0], 400)
    # ax[1].hist(wh[wh[:, 1]<100, 1], 400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve: hill-climb via random multiplicative mutation, keeping only improvements
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT)  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
            if verbose:
                print_results(k, verbose)

    return print_results(k).astype(np.float32)
|
|
||||||
@ -1,127 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
Download utils
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
import subprocess
import urllib
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path

import requests
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def is_url(url, check=True):
    """Return True if *url* is a well-formed URL and (when check=True) responds with HTTP 200.

    url: anything convertible to str.
    check: when False, only validate the URL shape without touching the network.
    Never raises for malformed URLs or network failures — returns False instead.
    """
    try:
        url = str(url)
        result = urllib.parse.urlparse(url)
        assert all([result.scheme, result.netloc])  # check if is url
        return (urllib.request.urlopen(url).getcode() == 200) if check else True  # check if exists online
    except (AssertionError, urllib.error.URLError):
        # URLError also covers HTTPError plus DNS/connection failures, which previously
        # escaped the narrower `urllib.request.HTTPError` clause and crashed callers.
        return False
|
|
||||||
|
|
||||||
|
|
||||||
def gsutil_getsize(url=''):
    """Return the size in bytes of a gs:// object via `gsutil du`, or 0 when unknown.

    https://cloud.google.com/storage/docs/gsutil/commands/du
    Raises subprocess.CalledProcessError if gsutil exits non-zero.
    """
    # Bug fix: the argument list was passed with shell=True, which on POSIX runs
    # only 'gsutil' (the remaining list items become shell arguments, not gsutil's).
    output = subprocess.check_output(['gsutil', 'du', url], encoding='utf-8')
    if output:
        return int(output.split()[0])  # first token of `du` output is the byte count
    return 0
|
|
||||||
|
|
||||||
|
|
||||||
def url_getsize(url='https://ultralytics.com/images/bus.jpg'):
    """Return the downloadable size of *url* in bytes, or -1 when no Content-Length header is sent."""
    head = requests.head(url, allow_redirects=True)
    return int(head.headers.get('content-length', -1))
|
|
||||||
|
|
||||||
|
|
||||||
def curl_download(url, filename, *, silent: bool = False) -> bool:
    """
    Download a file from a url to a filename using curl.

    Retries up to 9 times, resumes partial downloads (-C -), and follows
    redirects (-L). Returns True when curl exits with status 0.
    """
    silent_option = 'sS' if silent else ''  # silent
    command = [
        'curl',
        '-#',  # progress bar
        f'-{silent_option}L',
        url,
        '--output',
        filename,
        '--retry',
        '9',
        '-C',
        '-',  # resume from where a previous transfer left off
    ]
    proc = subprocess.run(command)
    return proc.returncode == 0
|
|
||||||
|
|
||||||
|
|
||||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
    """Download *url* to *file*, falling back to curl on *url2* (or *url*) on failure.

    file: destination path. url/url2: primary and backup sources.
    min_bytes: downloads smaller than this are treated as failed and deleted.
    error_msg: extra text logged when both attempts fail. Returns None.
    """
    from ..utils.general import LOGGER

    file = Path(file)
    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
    try:  # url1
        LOGGER.info(f'Downloading {url} to {file}...')
        # Show the torch progress bar only when the logger is at INFO or more verbose.
        torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
        assert file.exists() and file.stat().st_size > min_bytes, assert_msg  # check
    except Exception as e:  # url2
        if file.exists():
            file.unlink()  # remove partial downloads
        LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
        # curl download, retry and resume on fail
        curl_download(url2 or url, file)
    finally:
        # Final verification runs after either path; an undersized file is deleted and reported.
        if not file.exists() or file.stat().st_size < min_bytes:  # check
            if file.exists():
                file.unlink()  # remove partial downloads
            LOGGER.info(f'ERROR: {assert_msg}\n{error_msg}')
        LOGGER.info('')
|
|
||||||
|
|
||||||
|
|
||||||
def attempt_download(file, repo='ultralytics/yolov5', release='v7.0'):
    # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v7.0', etc.
    """Resolve *file* locally, from a direct URL, or from a GitHub release asset.

    file: local path, URL, or bare asset name (e.g. 'yolov5s.pt').
    repo: GitHub 'owner/name' to query for release assets.
    release: release tag to try first; falls back to latest, then local git tags.
    Returns the resolved path as str (a plain name for the URL branch).
    """
    from ..utils.general import LOGGER

    def github_assets(repository, version='latest'):
        # Return GitHub repo tag (i.e. 'v7.0') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
        if version != 'latest':
            version = f'tags/{version}'  # i.e. tags/v7.0
        response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json()  # github api
        return response['tag_name'], [x['name'] for x in response['assets']]  # tag, assets

    file = Path(str(file).strip().replace("'", ''))
    if not file.exists():
        # URL specified
        name = Path(urllib.parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
        if str(file).startswith(('http:/', 'https:/')):  # download
            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
            file = name.split('?')[0]  # parse authentication https://url.com/file.txt?auth...
            if Path(file).is_file():
                LOGGER.info(f'Found {url} locally at {file}')  # file already exists
            else:
                safe_download(file=file, url=url, min_bytes=1E5)
            return file

        # GitHub assets: default asset list used if every API/git lookup below fails
        assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')]  # default
        try:
            tag, assets = github_assets(repo, release)
        except Exception:
            try:
                tag, assets = github_assets(repo)  # latest release
            except Exception:
                try:
                    # Last resort: newest local git tag (requires a git checkout).
                    tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
                except Exception:
                    tag = release

        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
        if name in assets:
            safe_download(file,
                          url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
                          min_bytes=1E5,
                          error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')

    return str(file)
|
|
||||||
@ -1,378 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
Model validation metrics
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ..utils import TryExcept, threaded
|
|
||||||
|
|
||||||
|
|
||||||
def fitness(x):
    """Scalar fitness per row of metrics matrix x, shape (m, >=4).

    Weighted combination of [P, R, mAP@0.5, mAP@0.5:0.95]; mAP@0.5:0.95 dominates.
    """
    weights = [0.0, 0.0, 0.1, 0.9]  # [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * weights).sum(1)
|
|
||||||
|
|
||||||
|
|
||||||
def smooth(y, f=0.05):
    """Box-filter *y* with a window spanning fraction *f* of its length.

    The window size is forced odd and the signal is edge-padded, so the output
    has the same length as the input.
    """
    window = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    pad = np.ones(window // 2)
    padded = np.concatenate((pad * y[0], y, pad * y[-1]), 0)  # replicate edge values
    kernel = np.ones(window) / window
    return np.convolve(padded, kernel, mode='valid')
|
|
||||||
|
|
||||||
|
|
||||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=''):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (nparray, nx1 or nx10).
        conf: Objectness value from 0-1 (nparray).
        pred_cls: Predicted object classes (nparray).
        target_cls: True object classes (nparray).
        plot: Plot precision-recall curve at mAP@0.5
        save_dir: Plot save directory
        names: dict {class_index: class_name} used for plot legends
        eps: small constant avoiding division by zero
        prefix: filename prefix for saved plots
    # Returns
        (tp, fp, p, r, f1, ap, unique_classes) — per-class counts/metrics taken at
        the confidence that maximizes smoothed mean F1; ap has one column per IoU level.
    """

    # Sort by objectness (descending) so cumsums below sweep the confidence threshold
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes present in the ground truth
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue

        # Accumulate FPs and TPs along the confidence-sorted detections
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (n_l + eps)  # recall curve
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve, one value per IoU column of tp
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
            if plot and j == 0:
                py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
    names = dict(enumerate(names))  # to dict
    if plot:
        plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
        plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
        plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
        plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')

    # Evaluate all point metrics at the confidence index maximizing smoothed mean F1
    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall: The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, plus the sentinel-padded precision envelope and recall curves
    """
    # Pad with sentinel values so the curve spans recall 0..1
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Precision envelope: make precision monotonically non-increasing in recall
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interpolation (COCO protocol)
        ap = np.trapz(np.interp(x, mrec, mpre), x)
    else:  # 'continuous'
        changed = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[changed + 1] - mrec[changed]) * mpre[changed + 1])

    return ap, mpre, mrec
|
|
||||||
|
|
||||||
|
|
||||||
class ConfusionMatrix:
    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
    """Accumulates a (nc+1) x (nc+1) detection confusion matrix.

    Rows are predicted classes, columns true classes; the extra last row/column
    represents background (missed labels / spurious detections).
    """

    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        # nc: number of classes; conf: detection confidence threshold;
        # iou_thres: minimum IoU for a detection to match a label.
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres

    def process_batch(self, detections, labels):
        """
        Update the confusion matrix from one image's detections and labels.
        Boxes are expected in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class — or None when there are no detections
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        if detections is None:
            # No detections: every label counts as a background false negative.
            gt_classes = labels.int()
            for gc in gt_classes:
                self.matrix[self.nc, gc] += 1  # background FN
            return

        detections = detections[detections[:, 4] > self.conf]  # confidence filter
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = box_iou(labels[:, 1:], detections[:, :4])  # (M labels) x (N detections)

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # matches: (label_idx, detection_idx, iou) rows above the IoU threshold.
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one assignment: keep the highest-IoU match per detection,
                # then per label (np.unique keeps the first occurrence after sorting by IoU desc).
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0  # any match at all
        m0, m1, _ = matches.transpose().astype(int)  # matched label indices, detection indices
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
            else:
                self.matrix[self.nc, gc] += 1  # true background (label missed)

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[dc, self.nc] += 1  # predicted background (spurious detection)

    def tp_fp(self):
        # Per-class true/false positive counts derived from the matrix.
        tp = self.matrix.diagonal()  # true positives
        fp = self.matrix.sum(1) - tp  # false positives
        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
        return tp[:-1], fp[:-1]  # remove background class

    @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
    def plot(self, normalize=True, save_dir='', names=()):
        # Save a seaborn heatmap of the (optionally column-normalized) matrix.
        # NOTE(review): `names + ['background']` requires names to be a list — a tuple
        # argument with 0 < len == nc would raise TypeError here; confirm callers pass lists.
        import seaborn as sn

        array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns
        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

        fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
        nc, nn = self.nc, len(names)  # number of classes, names
        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
        ticklabels = (names + ['background']) if labels else 'auto'
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
            sn.heatmap(array,
                       ax=ax,
                       annot=nc < 30,
                       annot_kws={
                           'size': 8},
                       cmap='Blues',
                       fmt='.2f',
                       square=True,
                       vmin=0.0,
                       xticklabels=ticklabels,
                       yticklabels=ticklabels).set_facecolor((1, 1, 1))
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
        ax.set_title('Confusion Matrix')
        fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
        plt.close(fig)

    def print(self):
        # Dump the raw matrix rows to stdout, one space-separated line per predicted class.
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
|
|
||||||
|
|
||||||
# by AI&CV Wasserstein ####### start #########################
def Wasserstein(box1, box2, xywh=True):
    """Squared Wasserstein-like distance between box1 and each box in box2.

    box1: tensor of 4 components; box2: tensor (n, 4), transposed internally to (4, n).
    xywh=True converts corner-format (x1, y1, x2, y2) boxes to center/size form;
    xywh=False treats components as (cx, cy, w, h) directly.
    Returns center L2 distance plus half-extent squared-difference term, shape (n,).
    NOTE(review): despite the flag name, the xywh=True branch decodes corners —
    confirm callers pass corner boxes there.
    """
    box2 = box2.T
    if xywh:
        b1_cx, b1_cy = (box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2
        b1_w, b1_h = box1[2] - box1[0], box1[3] - box1[1]
        # Bug fix: original used (box2[0] + box2[0]) for b2_cx and overwrote b1_w/b1_h
        # instead of defining b2_w/b2_h, raising NameError below whenever xywh=True.
        b2_cx, b2_cy = (box2[0] + box2[2]) / 2, (box2[1] + box2[3]) / 2
        b2_w, b2_h = box2[2] - box2[0], box2[3] - box2[1]
    else:
        b1_cx, b1_cy, b1_w, b1_h = box1[0], box1[1], box1[2], box1[3]
        b2_cx, b2_cy, b2_w, b2_h = box2[0], box2[1], box2[2], box2[3]
    cx_L2Norm = torch.pow((b1_cx - b2_cx), 2)
    cy_L2Norm = torch.pow((b1_cy - b2_cy), 2)
    p1 = cx_L2Norm + cy_L2Norm
    w_FroNorm = torch.pow((b1_w - b2_w) / 2, 2)
    h_FroNorm = torch.pow((b1_h - b2_h) / 2, 2)
    p2 = w_FroNorm + h_FroNorm
    return p1 + p2
# by AI&CV Wasserstein ####### end #########################
|
|
||||||
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
    """IoU / GIoU / DIoU / CIoU between box1 and box2, broadcasting box1 over box2's rows.

    xywh=True: boxes are (cx, cy, w, h); False: (x1, y1, x2, y2).
    At most one of GIoU/DIoU/CIoU should be True; CIoU takes precedence over DIoU,
    which takes precedence over GIoU. eps guards divisions by zero.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)  # heights clamped to avoid atan(w/0) in CIoU
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)  # aspect-ratio term
                with torch.no_grad():
                    # alpha is treated as a constant weight; gradients flow only through v and iou.
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU
|
|
||||||
|
|
||||||
|
|
||||||
def box_iou(box1, box2, eps=1e-7):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # Broadcast to (N, M, 2) corner pairs and intersect.
    lt = torch.max(box1[:, None, :2], box2[None, :, :2])  # pairwise max of top-left corners
    rb = torch.min(box1[:, None, 2:], box2[None, :, 2:])  # pairwise min of bottom-right corners
    inter = (rb - lt).clamp(0).prod(2)  # (N, M) intersection areas

    area1 = (box1[:, 2:] - box1[:, :2]).prod(1)  # (N,)
    area2 = (box2[:, 2:] - box2[:, :2]).prod(1)  # (M,)
    union = area1[:, None] + area2[None, :] - inter + eps
    return inter / union
|
|
||||||
|
|
||||||
|
|
||||||
def bbox_ioa(box1, box2, eps=1e-7):
|
|
||||||
""" Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
|
|
||||||
box1: np.array of shape(4)
|
|
||||||
box2: np.array of shape(nx4)
|
|
||||||
returns: np.array of shape(n)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Get the coordinates of bounding boxes
|
|
||||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1
|
|
||||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
|
||||||
|
|
||||||
# Intersection area
|
|
||||||
inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
|
|
||||||
(np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
|
|
||||||
|
|
||||||
# box2 area
|
|
||||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
|
|
||||||
|
|
||||||
# Intersection over box2 area
|
|
||||||
return inter_area / box2_area
|
|
||||||
|
|
||||||
|
|
||||||
def wh_iou(wh1, wh2, eps=1e-7):
|
|
||||||
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
|
|
||||||
wh1 = wh1[:, None] # [N,1,2]
|
|
||||||
wh2 = wh2[None] # [1,M,2]
|
|
||||||
inter = torch.min(wh1, wh2).prod(2) # [N,M]
|
|
||||||
return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps) # iou = inter / (area1 + area2 - inter)
|
|
||||||
|
|
||||||
|
|
||||||
# Plots ----------------------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@threaded
|
|
||||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
|
||||||
# Precision-recall curve
|
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
|
||||||
py = np.stack(py, axis=1)
|
|
||||||
|
|
||||||
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
|
||||||
for i, y in enumerate(py.T):
|
|
||||||
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
|
|
||||||
else:
|
|
||||||
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
|
|
||||||
|
|
||||||
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
|
|
||||||
ax.set_xlabel('Recall')
|
|
||||||
ax.set_ylabel('Precision')
|
|
||||||
ax.set_xlim(0, 1)
|
|
||||||
ax.set_ylim(0, 1)
|
|
||||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
|
||||||
ax.set_title('Precision-Recall Curve')
|
|
||||||
fig.savefig(save_dir, dpi=250)
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
|
|
||||||
@threaded
|
|
||||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
|
||||||
# Metric-confidence curve
|
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
|
||||||
|
|
||||||
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
|
||||||
for i, y in enumerate(py):
|
|
||||||
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
|
|
||||||
else:
|
|
||||||
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
|
|
||||||
|
|
||||||
y = smooth(py.mean(0), 0.05)
|
|
||||||
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
|
|
||||||
ax.set_xlabel(xlabel)
|
|
||||||
ax.set_ylabel(ylabel)
|
|
||||||
ax.set_xlim(0, 1)
|
|
||||||
ax.set_ylim(0, 1)
|
|
||||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
|
||||||
ax.set_title(f'{ylabel}-Confidence Curve')
|
|
||||||
fig.savefig(save_dir, dpi=250)
|
|
||||||
plt.close(fig)
|
|
||||||
@ -1,560 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
Plotting utils
|
|
||||||
"""
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from copy import copy
|
|
||||||
from pathlib import Path
|
|
||||||
from urllib.error import URLError
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import seaborn as sn
|
|
||||||
import torch
|
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
|
||||||
|
|
||||||
from ..utils import TryExcept, threaded
|
|
||||||
from ..utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
|
|
||||||
is_ascii, xywh2xyxy, xyxy2xywh)
|
|
||||||
from ..utils.metrics import fitness
|
|
||||||
# from utils.segment.general import scale_image
|
|
||||||
|
|
||||||
# Settings
|
|
||||||
RANK = int(os.getenv('RANK', -1))
|
|
||||||
matplotlib.rc('font', **{'size': 11})
|
|
||||||
matplotlib.use('Agg') # for writing to files only
|
|
||||||
|
|
||||||
|
|
||||||
class Colors:
|
|
||||||
# Ultralytics color palette https://ultralytics.com/
|
|
||||||
def __init__(self):
|
|
||||||
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
|
||||||
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
|
||||||
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
|
||||||
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
|
||||||
self.n = len(self.palette)
|
|
||||||
|
|
||||||
def __call__(self, i, bgr=False):
|
|
||||||
c = self.palette[int(i) % self.n]
|
|
||||||
return (c[2], c[1], c[0]) if bgr else c
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def hex2rgb(h): # rgb order (PIL)
|
|
||||||
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
|
||||||
|
|
||||||
|
|
||||||
colors = Colors() # create instance for 'from utils.plots import colors'
|
|
||||||
|
|
||||||
|
|
||||||
def check_pil_font(font=FONT, size=10):
|
|
||||||
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
|
||||||
font = Path(font)
|
|
||||||
font = font if font.exists() else (CONFIG_DIR / font.name)
|
|
||||||
try:
|
|
||||||
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
|
||||||
except Exception: # download if missing
|
|
||||||
try:
|
|
||||||
check_font(font)
|
|
||||||
return ImageFont.truetype(str(font), size)
|
|
||||||
except TypeError:
|
|
||||||
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
|
|
||||||
except URLError: # not online
|
|
||||||
return ImageFont.load_default()
|
|
||||||
|
|
||||||
|
|
||||||
class Annotator:
|
|
||||||
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
|
||||||
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
|
||||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
|
||||||
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
|
||||||
self.pil = pil or non_ascii
|
|
||||||
if self.pil: # use PIL
|
|
||||||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
|
||||||
self.draw = ImageDraw.Draw(self.im)
|
|
||||||
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
|
|
||||||
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
|
||||||
else: # use cv2
|
|
||||||
self.im = im
|
|
||||||
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
|
||||||
|
|
||||||
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
|
||||||
# Add one xyxy box to image with label
|
|
||||||
if self.pil or not is_ascii(label):
|
|
||||||
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
|
||||||
if label:
|
|
||||||
w, h = self.font.getsize(label) # text width, height (WARNING: deprecated) in 9.2.0
|
|
||||||
# _, _, w, h = self.font.getbbox(label) # text width, height (New)
|
|
||||||
outside = box[1] - h >= 0 # label fits outside box
|
|
||||||
self.draw.rectangle(
|
|
||||||
(box[0], box[1] - h if outside else box[1], box[0] + w + 1,
|
|
||||||
box[1] + 1 if outside else box[1] + h + 1),
|
|
||||||
fill=color,
|
|
||||||
)
|
|
||||||
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
|
||||||
self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
|
|
||||||
else: # cv2
|
|
||||||
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
|
||||||
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
|
||||||
if label:
|
|
||||||
tf = max(self.lw - 1, 1) # font thickness
|
|
||||||
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
|
||||||
outside = p1[1] - h >= 3
|
|
||||||
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
|
||||||
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
|
||||||
cv2.putText(self.im,
|
|
||||||
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
|
||||||
0,
|
|
||||||
self.lw / 3,
|
|
||||||
txt_color,
|
|
||||||
thickness=tf,
|
|
||||||
lineType=cv2.LINE_AA)
|
|
||||||
|
|
||||||
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
|
||||||
"""Plot masks at once.
|
|
||||||
Args:
|
|
||||||
masks (tensor): predicted masks on cuda, shape: [n, h, w]
|
|
||||||
colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
|
|
||||||
im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
|
|
||||||
alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
|
|
||||||
"""
|
|
||||||
if self.pil:
|
|
||||||
# convert to numpy first
|
|
||||||
self.im = np.asarray(self.im).copy()
|
|
||||||
if len(masks) == 0:
|
|
||||||
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
|
||||||
colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
|
|
||||||
colors = colors[:, None, None] # shape(n,1,1,3)
|
|
||||||
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
|
||||||
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
|
||||||
|
|
||||||
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
|
||||||
mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
|
|
||||||
|
|
||||||
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
|
||||||
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
|
||||||
im_gpu = im_gpu * inv_alph_masks[-1] + mcs
|
|
||||||
im_mask = (im_gpu * 255).byte().cpu().numpy()
|
|
||||||
self.im[:] = im_mask if retina_masks else scale_image(im_gpu.shape, im_mask, self.im.shape)
|
|
||||||
if self.pil:
|
|
||||||
# convert im back to PIL and update draw
|
|
||||||
self.fromarray(self.im)
|
|
||||||
|
|
||||||
def rectangle(self, xy, fill=None, outline=None, width=1):
|
|
||||||
# Add rectangle to image (PIL-only)
|
|
||||||
self.draw.rectangle(xy, fill, outline, width)
|
|
||||||
|
|
||||||
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
|
|
||||||
# Add text to image (PIL-only)
|
|
||||||
if anchor == 'bottom': # start y from font bottom
|
|
||||||
w, h = self.font.getsize(text) # text width, height
|
|
||||||
xy[1] += 1 - h
|
|
||||||
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
|
||||||
|
|
||||||
def fromarray(self, im):
|
|
||||||
# Update self.im from a numpy array
|
|
||||||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
|
||||||
self.draw = ImageDraw.Draw(self.im)
|
|
||||||
|
|
||||||
def result(self):
|
|
||||||
# Return annotated image as array
|
|
||||||
return np.asarray(self.im)
|
|
||||||
|
|
||||||
|
|
||||||
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
|
||||||
"""
|
|
||||||
x: Features to be visualized
|
|
||||||
module_type: Module type
|
|
||||||
stage: Module stage within model
|
|
||||||
n: Maximum number of feature maps to plot
|
|
||||||
save_dir: Directory to save results
|
|
||||||
"""
|
|
||||||
if 'Detect' not in module_type:
|
|
||||||
batch, channels, height, width = x.shape # batch, channels, height, width
|
|
||||||
if height > 1 and width > 1:
|
|
||||||
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
|
||||||
|
|
||||||
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
|
||||||
n = min(n, channels) # number of plots
|
|
||||||
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
|
|
||||||
ax = ax.ravel()
|
|
||||||
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
|
||||||
for i in range(n):
|
|
||||||
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
|
||||||
ax[i].axis('off')
|
|
||||||
|
|
||||||
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
|
||||||
plt.savefig(f, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
|
||||||
|
|
||||||
|
|
||||||
def hist2d(x, y, n=100):
|
|
||||||
# 2d histogram used in labels.png and evolve.png
|
|
||||||
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
|
|
||||||
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
|
|
||||||
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
|
|
||||||
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
|
|
||||||
return np.log(hist[xidx, yidx])
|
|
||||||
|
|
||||||
|
|
||||||
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
|
||||||
from scipy.signal import butter, filtfilt
|
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
|
|
||||||
def butter_lowpass(cutoff, fs, order):
|
|
||||||
nyq = 0.5 * fs
|
|
||||||
normal_cutoff = cutoff / nyq
|
|
||||||
return butter(order, normal_cutoff, btype='low', analog=False)
|
|
||||||
|
|
||||||
b, a = butter_lowpass(cutoff, fs, order=order)
|
|
||||||
return filtfilt(b, a, data) # forward-backward filter
|
|
||||||
|
|
||||||
|
|
||||||
def output_to_target(output, max_det=300):
|
|
||||||
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
|
|
||||||
targets = []
|
|
||||||
for i, o in enumerate(output):
|
|
||||||
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
|
|
||||||
j = torch.full((conf.shape[0], 1), i)
|
|
||||||
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
|
|
||||||
return torch.cat(targets, 0).numpy()
|
|
||||||
|
|
||||||
|
|
||||||
@threaded
|
|
||||||
def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
|
|
||||||
# Plot image grid with labels
|
|
||||||
if isinstance(images, torch.Tensor):
|
|
||||||
images = images.cpu().float().numpy()
|
|
||||||
if isinstance(targets, torch.Tensor):
|
|
||||||
targets = targets.cpu().numpy()
|
|
||||||
|
|
||||||
max_size = 1920 # max image size
|
|
||||||
max_subplots = 16 # max image subplots, i.e. 4x4
|
|
||||||
bs, _, h, w = images.shape # batch size, _, height, width
|
|
||||||
bs = min(bs, max_subplots) # limit plot images
|
|
||||||
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
|
||||||
if np.max(images[0]) <= 1:
|
|
||||||
images *= 255 # de-normalise (optional)
|
|
||||||
|
|
||||||
# Build Image
|
|
||||||
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
|
||||||
for i, im in enumerate(images):
|
|
||||||
if i == max_subplots: # if last batch has fewer images than we expect
|
|
||||||
break
|
|
||||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
|
||||||
im = im.transpose(1, 2, 0)
|
|
||||||
mosaic[y:y + h, x:x + w, :] = im
|
|
||||||
|
|
||||||
# Resize (optional)
|
|
||||||
scale = max_size / ns / max(h, w)
|
|
||||||
if scale < 1:
|
|
||||||
h = math.ceil(scale * h)
|
|
||||||
w = math.ceil(scale * w)
|
|
||||||
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
|
|
||||||
|
|
||||||
# Annotate
|
|
||||||
fs = int((h + w) * ns * 0.01) # font size
|
|
||||||
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
|
|
||||||
for i in range(i + 1):
|
|
||||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
|
||||||
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
|
||||||
if paths:
|
|
||||||
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
|
||||||
if len(targets) > 0:
|
|
||||||
ti = targets[targets[:, 0] == i] # image targets
|
|
||||||
boxes = xywh2xyxy(ti[:, 2:6]).T
|
|
||||||
classes = ti[:, 1].astype('int')
|
|
||||||
labels = ti.shape[1] == 6 # labels if no conf column
|
|
||||||
conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
|
|
||||||
|
|
||||||
if boxes.shape[1]:
|
|
||||||
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
|
|
||||||
boxes[[0, 2]] *= w # scale to pixels
|
|
||||||
boxes[[1, 3]] *= h
|
|
||||||
elif scale < 1: # absolute coords need scale if image scales
|
|
||||||
boxes *= scale
|
|
||||||
boxes[[0, 2]] += x
|
|
||||||
boxes[[1, 3]] += y
|
|
||||||
for j, box in enumerate(boxes.T.tolist()):
|
|
||||||
cls = classes[j]
|
|
||||||
color = colors(cls)
|
|
||||||
cls = names[cls] if names else cls
|
|
||||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
|
||||||
label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
|
|
||||||
annotator.box_label(box, label, color=color)
|
|
||||||
annotator.im.save(fname) # save
|
|
||||||
|
|
||||||
|
|
||||||
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
|
|
||||||
# Plot LR simulating training for full epochs
|
|
||||||
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
|
|
||||||
y = []
|
|
||||||
for _ in range(epochs):
|
|
||||||
scheduler.step()
|
|
||||||
y.append(optimizer.param_groups[0]['lr'])
|
|
||||||
plt.plot(y, '.-', label='LR')
|
|
||||||
plt.xlabel('epoch')
|
|
||||||
plt.ylabel('LR')
|
|
||||||
plt.grid()
|
|
||||||
plt.xlim(0, epochs)
|
|
||||||
plt.ylim(0)
|
|
||||||
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
|
||||||
def plot_val_txt(): # from utils.plots import *; plot_val()
|
|
||||||
# Plot val.txt histograms
|
|
||||||
x = np.loadtxt('val.txt', dtype=np.float32)
|
|
||||||
box = xyxy2xywh(x[:, :4])
|
|
||||||
cx, cy = box[:, 0], box[:, 1]
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
|
|
||||||
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
|
|
||||||
ax.set_aspect('equal')
|
|
||||||
plt.savefig('hist2d.png', dpi=300)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
|
|
||||||
ax[0].hist(cx, bins=600)
|
|
||||||
ax[1].hist(cy, bins=600)
|
|
||||||
plt.savefig('hist1d.png', dpi=200)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
|
|
||||||
# Plot targets.txt histograms
|
|
||||||
x = np.loadtxt('targets.txt', dtype=np.float32).T
|
|
||||||
s = ['x targets', 'y targets', 'width targets', 'height targets']
|
|
||||||
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
|
||||||
ax = ax.ravel()
|
|
||||||
for i in range(4):
|
|
||||||
ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
|
|
||||||
ax[i].legend()
|
|
||||||
ax[i].set_title(s[i])
|
|
||||||
plt.savefig('targets.jpg', dpi=200)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_val_study()
|
|
||||||
# Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
|
|
||||||
save_dir = Path(file).parent if file else Path(dir)
|
|
||||||
plot2 = False # plot additional results
|
|
||||||
if plot2:
|
|
||||||
ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
|
|
||||||
|
|
||||||
fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
|
|
||||||
# for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
|
|
||||||
for f in sorted(save_dir.glob('study*.txt')):
|
|
||||||
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
|
|
||||||
x = np.arange(y.shape[1]) if x is None else np.array(x)
|
|
||||||
if plot2:
|
|
||||||
s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
|
|
||||||
for i in range(7):
|
|
||||||
ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
|
|
||||||
ax[i].set_title(s[i])
|
|
||||||
|
|
||||||
j = y[3].argmax() + 1
|
|
||||||
ax2.plot(y[5, 1:j],
|
|
||||||
y[3, 1:j] * 1E2,
|
|
||||||
'.-',
|
|
||||||
linewidth=2,
|
|
||||||
markersize=8,
|
|
||||||
label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
|
|
||||||
|
|
||||||
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
|
|
||||||
'k.-',
|
|
||||||
linewidth=2,
|
|
||||||
markersize=8,
|
|
||||||
alpha=.25,
|
|
||||||
label='EfficientDet')
|
|
||||||
|
|
||||||
ax2.grid(alpha=0.2)
|
|
||||||
ax2.set_yticks(np.arange(20, 60, 5))
|
|
||||||
ax2.set_xlim(0, 57)
|
|
||||||
ax2.set_ylim(25, 55)
|
|
||||||
ax2.set_xlabel('GPU Speed (ms/img)')
|
|
||||||
ax2.set_ylabel('COCO AP val')
|
|
||||||
ax2.legend(loc='lower right')
|
|
||||||
f = save_dir / 'study.png'
|
|
||||||
print(f'Saving {f}...')
|
|
||||||
plt.savefig(f, dpi=300)
|
|
||||||
|
|
||||||
|
|
||||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
|
||||||
def plot_labels(labels, names=(), save_dir=Path('')):
|
|
||||||
# plot dataset labels
|
|
||||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
|
||||||
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
|
||||||
nc = int(c.max() + 1) # number of classes
|
|
||||||
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
|
||||||
|
|
||||||
# seaborn correlogram
|
|
||||||
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
|
||||||
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# matplotlib labels
|
|
||||||
matplotlib.use('svg') # faster
|
|
||||||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
|
||||||
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
|
||||||
with contextlib.suppress(Exception): # color histogram bars by class
|
|
||||||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
|
||||||
ax[0].set_ylabel('instances')
|
|
||||||
if 0 < len(names) < 30:
|
|
||||||
ax[0].set_xticks(range(len(names)))
|
|
||||||
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
|
||||||
else:
|
|
||||||
ax[0].set_xlabel('classes')
|
|
||||||
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
|
||||||
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
|
||||||
|
|
||||||
# rectangles
|
|
||||||
labels[:, 1:3] = 0.5 # center
|
|
||||||
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
|
|
||||||
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
|
|
||||||
for cls, *box in labels[:1000]:
|
|
||||||
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
|
||||||
ax[1].imshow(img)
|
|
||||||
ax[1].axis('off')
|
|
||||||
|
|
||||||
for a in [0, 1, 2, 3]:
|
|
||||||
for s in ['top', 'right', 'left', 'bottom']:
|
|
||||||
ax[a].spines[s].set_visible(False)
|
|
||||||
|
|
||||||
plt.savefig(save_dir / 'labels.jpg', dpi=200)
|
|
||||||
matplotlib.use('Agg')
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
|
||||||
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
|
|
||||||
# Show classification image grid with labels (optional) and predictions (optional)
|
|
||||||
from ..utils.augmentations import denormalize
|
|
||||||
|
|
||||||
names = names or [f'class{i}' for i in range(1000)]
|
|
||||||
blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
|
|
||||||
dim=0) # select batch index 0, block by channels
|
|
||||||
n = min(len(blocks), nmax) # number of plots
|
|
||||||
m = min(8, round(n ** 0.5)) # 8 x 8 default
|
|
||||||
fig, ax = plt.subplots(math.ceil(n / m), m) # 8 rows x n/8 cols
|
|
||||||
ax = ax.ravel() if m > 1 else [ax]
|
|
||||||
# plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
|
||||||
for i in range(n):
|
|
||||||
ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
|
|
||||||
ax[i].axis('off')
|
|
||||||
if labels is not None:
|
|
||||||
s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
|
|
||||||
ax[i].set_title(s, fontsize=8, verticalalignment='top')
|
|
||||||
plt.savefig(f, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
if verbose:
|
|
||||||
LOGGER.info(f'Saving {f}')
|
|
||||||
if labels is not None:
|
|
||||||
LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
|
|
||||||
if pred is not None:
|
|
||||||
LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
|
||||||
# Plot evolve.csv hyp evolution results
|
|
||||||
evolve_csv = Path(evolve_csv)
|
|
||||||
data = pd.read_csv(evolve_csv)
|
|
||||||
keys = [x.strip() for x in data.columns]
|
|
||||||
x = data.values
|
|
||||||
f = fitness(x)
|
|
||||||
j = np.argmax(f) # max fitness index
|
|
||||||
plt.figure(figsize=(10, 12), tight_layout=True)
|
|
||||||
matplotlib.rc('font', **{'size': 8})
|
|
||||||
print(f'Best results from row {j} of {evolve_csv}:')
|
|
||||||
for i, k in enumerate(keys[7:]):
|
|
||||||
v = x[:, 7 + i]
|
|
||||||
mu = v[j] # best single result
|
|
||||||
plt.subplot(6, 5, i + 1)
|
|
||||||
plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
|
|
||||||
plt.plot(mu, f.max(), 'k+', markersize=15)
|
|
||||||
plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
|
|
||||||
if i % 5 != 0:
|
|
||||||
plt.yticks([])
|
|
||||||
print(f'{k:>15}: {mu:.3g}')
|
|
||||||
f = evolve_csv.with_suffix('.png') # filename
|
|
||||||
plt.savefig(f, dpi=200)
|
|
||||||
plt.close()
|
|
||||||
print(f'Saved {f}')
|
|
||||||
|
|
||||||
|
|
||||||
def plot_results(file='path/to/results.csv', dir=''):
|
|
||||||
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
|
|
||||||
save_dir = Path(file).parent if file else Path(dir)
|
|
||||||
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
|
||||||
ax = ax.ravel()
|
|
||||||
files = list(save_dir.glob('results*.csv'))
|
|
||||||
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
|
||||||
for f in files:
|
|
||||||
try:
|
|
||||||
data = pd.read_csv(f)
|
|
||||||
s = [x.strip() for x in data.columns]
|
|
||||||
x = data.values[:, 0]
|
|
||||||
for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
|
|
||||||
y = data.values[:, j].astype('float')
|
|
||||||
# y[y == 0] = np.nan # don't show zero values
|
|
||||||
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
|
|
||||||
ax[i].set_title(s[j], fontsize=12)
|
|
||||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
|
||||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.info(f'Warning: Plotting error for {f}: {e}')
|
|
||||||
ax[1].legend()
|
|
||||||
fig.savefig(save_dir / 'results.png', dpi=200)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
|
||||||
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
|
|
||||||
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
|
|
||||||
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
|
|
||||||
s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
|
|
||||||
files = list(Path(save_dir).glob('frames*.txt'))
|
|
||||||
for fi, f in enumerate(files):
|
|
||||||
try:
|
|
||||||
results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
|
|
||||||
n = results.shape[1] # number of rows
|
|
||||||
x = np.arange(start, min(stop, n) if stop else n)
|
|
||||||
results = results[:, x]
|
|
||||||
t = (results[0] - results[0].min()) # set t0=0s
|
|
||||||
results[0] = x
|
|
||||||
for i, a in enumerate(ax):
|
|
||||||
if i < len(results):
|
|
||||||
label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
|
|
||||||
a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
|
|
||||||
a.set_title(s[i])
|
|
||||||
a.set_xlabel('time (s)')
|
|
||||||
# if fi == len(files) - 1:
|
|
||||||
# a.set_ylim(bottom=0)
|
|
||||||
for side in ['top', 'right']:
|
|
||||||
a.spines[side].set_visible(False)
|
|
||||||
else:
|
|
||||||
a.remove()
|
|
||||||
except Exception as e:
|
|
||||||
print(f'Warning: Plotting error for {f}; {e}')
|
|
||||||
ax[1].legend()
|
|
||||||
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
|
|
||||||
|
|
||||||
|
|
||||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
|
||||||
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
|
||||||
xyxy = torch.tensor(xyxy).view(-1, 4)
|
|
||||||
b = xyxy2xywh(xyxy) # boxes
|
|
||||||
if square:
|
|
||||||
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
|
||||||
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
|
||||||
xyxy = xywh2xyxy(b).long()
|
|
||||||
clip_boxes(xyxy, im.shape)
|
|
||||||
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
|
||||||
if save:
|
|
||||||
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
|
||||||
f = str(increment_path(file).with_suffix('.jpg'))
|
|
||||||
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
|
||||||
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
|
||||||
return crop
|
|
||||||
@ -1,433 +0,0 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
||||||
"""
|
|
||||||
PyTorch utils
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
import warnings
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
from ..utils.general import LOGGER, check_version, colorstr, file_date, git_describe
|
|
||||||
|
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
|
||||||
RANK = int(os.getenv('RANK', -1))
|
|
||||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
|
||||||
|
|
||||||
try:
|
|
||||||
import thop # for FLOPs computation
|
|
||||||
except ImportError:
|
|
||||||
thop = None
|
|
||||||
|
|
||||||
# Suppress PyTorch warnings
|
|
||||||
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
|
||||||
|
|
||||||
|
|
||||||
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
|
||||||
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
|
||||||
def decorate(fn):
|
|
||||||
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
|
||||||
|
|
||||||
return decorate
|
|
||||||
|
|
||||||
|
|
||||||
def smartCrossEntropyLoss(label_smoothing=0.0):
    """Build nn.CrossEntropyLoss; `label_smoothing` is honored only on torch>=1.10.0."""
    if not check_version(torch.__version__, '1.10.0'):
        # Older torch has no label_smoothing kwarg; warn when the caller asked for it.
        if label_smoothing > 0:
            LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
        return nn.CrossEntropyLoss()
    return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
|
||||||
|
|
||||||
|
|
||||||
def smart_DDP(model):
    """Wrap `model` in DistributedDataParallel, guarding against the broken torch==1.12.0 release."""
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    ddp_kwargs = {'device_ids': [LOCAL_RANK], 'output_device': LOCAL_RANK}
    if check_version(torch.__version__, '1.11.0'):
        ddp_kwargs['static_graph'] = True  # static_graph is only available from torch 1.11
    return DDP(model, **ddp_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_classifier_output(model, n=1000):
    """Update a TorchVision classification model head to `n` output classes if required.

    Mutates `model` in place. Handles three head shapes: a YOLOv5 Classify() head,
    a bare nn.Linear head (ResNet/EfficientNet), or an nn.Sequential head containing
    a Linear or Conv2d layer. Does nothing when the head already has `n` outputs.
    """
    from ..models.common import Classify
    # the classifier head is assumed to be the last named child of the (possibly wrapped) model
    name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
    if isinstance(m, Classify):  # YOLOv5 Classify() head
        if m.linear.out_features != n:
            m.linear = nn.Linear(m.linear.in_features, n)
    elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
        if m.out_features != n:
            setattr(model, name, nn.Linear(m.in_features, n))
    elif isinstance(m, nn.Sequential):
        types = [type(x) for x in m]
        if nn.Linear in types:
            i = types.index(nn.Linear)  # nn.Linear index
            if m[i].out_features != n:
                m[i] = nn.Linear(m[i].in_features, n)
        elif nn.Conv2d in types:
            i = types.index(nn.Conv2d)  # nn.Conv2d index
            if m[i].out_channels != n:
                # preserve kernel/stride and the presence/absence of a bias term
                m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager: non-master ranks wait until the local master (rank 0/-1) has run the body."""
    is_master = local_rank in (-1, 0)
    if not is_master:
        dist.barrier(device_ids=[local_rank])  # wait here until rank 0 finishes the body
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])  # release the ranks waiting above
|
|
||||||
|
|
||||||
|
|
||||||
def device_count():
    """Count CUDA devices via nvidia-smi; returns 0 on any failure. Linux/Windows only."""
    system = platform.system()
    assert system in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
    # count the lines that `nvidia-smi -L` prints, one per GPU
    cmd = 'nvidia-smi -L | wc -l' if system == 'Linux' else 'nvidia-smi -L | find /c /v ""'  # Windows
    try:
        out = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout
        return int(out.decode().split()[-1])
    except Exception:
        return 0
|
|
||||||
|
|
||||||
|
|
||||||
def select_device(device='', batch_size=0, newline=True, model_path=None):
    """Select and return a torch.device from a request string, logging a summary banner.

    Args:
        device: '' (auto), 'cpu', 'mps', '0', or '0,1,2,3' (CUDA device ids).
        batch_size: when >0 and several GPUs are requested, must divide evenly by GPU count.
        newline: if False, the trailing newline is stripped from the logged summary.
        model_path: optional model file path; only its basename is shown in the banner.

    Returns:
        torch.device: 'cuda:0', 'mps', or 'cpu'.
    """
    # BUG FIX: model_path defaults to None but was dereferenced unconditionally
    # (model_path.split(...) -> AttributeError); callers such as profile() pass no model_path.
    model_name = model_path.split('/')[-1] if model_path else 'model'
    s = f'{model_name} 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{torch.version.cuda} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)
|
|
||||||
|
|
||||||
|
|
||||||
def time_sync():
    """PyTorch-accurate wall clock: flush pending CUDA work before reading the time."""
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.synchronize()
    return time.time()
|
|
||||||
|
|
||||||
|
|
||||||
def profile(input, ops, n=10, device=None):
    """ YOLOv5 speed/memory/FLOPs profiler
    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations

    Returns a list with one entry per (input, op) pair:
    [params, GFLOPs, GPU-mem GB, fwd ms, bwd ms, in-shape, out-shape], or None on failure.
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True  # needed so backward timing is measurable
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                # thop may be None (optional dependency) or unsupported for this op
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        # print(e)  # for debug
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                # a failed op still yields a row (None) so results align with ops
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results
|
|
||||||
|
|
||||||
|
|
||||||
def is_parallel(model):
    """Return True if `model` is a DataParallel or DistributedDataParallel wrapper (exact type match)."""
    parallel_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return type(model) in parallel_types
|
|
||||||
|
|
||||||
|
|
||||||
def de_parallel(model):
    """Return the underlying single-GPU model, unwrapping a DP/DDP container if present."""
    if is_parallel(model):
        return model.module
    return model
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_weights(model):
    """Apply YOLOv5 default init: tune BatchNorm eps/momentum and make activations inplace."""
    for module in model.modules():
        cls = type(module)
        if cls is nn.Conv2d:
            # keep PyTorch's default conv init (kaiming init is intentionally disabled)
            pass
        elif cls is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif cls in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU):
            module.inplace = True
|
|
||||||
|
|
||||||
|
|
||||||
def find_modules(model, mclass=nn.Conv2d):
    """Return the indices of layers in model.module_list that are instances of `mclass`."""
    indices = []
    for idx, layer in enumerate(model.module_list):
        if isinstance(layer, mclass):
            indices.append(idx)
    return indices
|
|
||||||
|
|
||||||
|
|
||||||
def sparsity(model):
    """Global fraction of zero-valued parameters in `model` (returned as a 0-dim tensor)."""
    total, zeros = 0, 0
    for param in model.parameters():
        total += param.numel()
        zeros += (param == 0).sum()  # tensor accumulator, so the ratio stays a tensor
    return zeros / total
|
|
||||||
|
|
||||||
|
|
||||||
def prune(model, amount=0.3):
    """L1-unstructured prune every Conv2d weight to the requested global sparsity, made permanent."""
    import torch.nn.utils.prune as prune
    for name, module in model.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        prune.l1_unstructured(module, name='weight', amount=amount)  # prune
        prune.remove(module, 'weight')  # bake the mask into the weight tensor
    LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
|
|
||||||
|
|
||||||
|
|
||||||
def fuse_conv_and_bn(conv, bn):
    """Fuse a Conv2d and its following BatchNorm2d into a single equivalent Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/ for the derivation.
    Returns a new gradient-free Conv2d on the same device as `conv`; inputs are unchanged.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: W_fused = diag(gamma / sqrt(var + eps)) @ W_conv
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias: b_fused = W_bn @ b_conv + (beta - gamma * mean / sqrt(var + eps))
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
|
|
||||||
|
|
||||||
|
|
||||||
def model_info(model, verbose=False, imgsz=640):
    """Log a one-line model summary (layers, parameters, gradients, GFLOPs).

    Args:
        model: nn.Module to summarize.
        verbose: additionally print a per-parameter table.
        imgsz: int or [h, w] image size used to scale the FLOPs estimate.
    """
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPs (best effort; thop may be missing or the model untraceable)
        p = next(model.parameters())
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs'  # 640x640 GFLOPs
    except Exception:
        fs = ''

    name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
    # BUG FIX: `name` was computed but never used — the banner always read 'model summary: ...'
    LOGGER.info(f'{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
|
|
||||||
|
|
||||||
|
|
||||||
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """Scale img(bs,3,h,w) by `ratio`; unless same_shape, pad up to the next gs-multiple."""
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    new_h, new_w = int(h * ratio), int(w * ratio)  # resized size
    img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        # grow the padding target to gs-multiples so strided layers downstream line up
        h = math.ceil(h * ratio / gs) * gs
        w = math.ceil(w * ratio / gs) * gs
    return F.pad(img, [0, w - new_w, 0, h - new_h], value=0.447)  # value = imagenet mean
|
|
||||||
|
|
||||||
|
|
||||||
def copy_attr(a, b, include=(), exclude=()):
    """Copy b's public instance attributes onto a; optional `include` whitelist and `exclude` blacklist."""
    for key, value in b.__dict__.items():
        if len(include) and key not in include:
            continue  # a non-empty include list restricts copying to those names
        if key.startswith('_') or key in exclude:
            continue  # private or explicitly excluded
        setattr(a, key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    """Build a YOLOv5 optimizer with 3 parameter groups.

    Group 0: weights with weight decay; group 1: normalization-layer weights, no decay;
    group 2: biases, no decay. Supported names: Adam, AdamW, RMSProp, SGD.

    Raises:
        NotImplementedError: for an unrecognized optimizer name.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    for v in model.modules():
        for p_name, p in v.named_parameters(recurse=0):  # recurse=0: each module contributes only its own params
            if p_name == 'bias':  # bias (no decay)
                g[2].append(p)
            elif p_name == 'weight' and isinstance(v, bn):  # weight (no decay)
                g[1].append(p)
            else:
                g[0].append(p)  # weight (with decay)

    # construct on the bias group (g[2]); the other groups are appended below
    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
    elif name == 'AdamW':
        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == 'RMSProp':
        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
    return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
    """torch.hub.load() wrapper: version-aware kwargs plus a force_reload retry on failure."""
    version_kwargs = (
        ('1.9.1', 'skip_validation', True),  # validation causes GitHub API rate limit errors
        ('1.12.0', 'trust_repo', True),  # argument required starting in torch 0.12
    )
    for minimum, key, value in version_kwargs:
        if check_version(torch.__version__, minimum):
            kwargs[key] = value
    try:
        return torch.hub.load(repo, model, **kwargs)
    except Exception:
        # a stale cached repo can break loading; retry with a fresh clone
        return torch.hub.load(repo, model, force_reload=True, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
    """Resume training state from a partially trained checkpoint dict.

    Restores optimizer state, best fitness and (optionally) EMA weights from `ckpt`.

    Returns:
        (best_fitness, start_epoch, epochs) — `epochs` may be extended when fine-tuning
        past an already-completed schedule.
    """
    best_fitness = 0.0
    start_epoch = ckpt['epoch'] + 1
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
        best_fitness = ckpt['best_fitness']
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
        ema.updates = ckpt['updates']
    if resume:
        # start_epoch == 0 means the checkpoint already completed its schedule
        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
    if epochs < start_epoch:
        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
        epochs += ckpt['epoch']  # finetune additional epochs
    return best_fitness, start_epoch, epochs
|
|
||||||
|
|
||||||
|
|
||||||
class EarlyStopping:
    """Stop training once fitness has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=30):
        # best fitness observed so far (e.g. mAP) and the epoch it occurred at
        self.best_fitness = 0.0
        self.best_epoch = 0
        self.patience = patience or float('inf')  # patience=0 disables early stopping entirely
        self.possible_stop = False  # True when a stop may trigger on the next epoch

    def __call__(self, epoch, fitness):
        """Record this epoch's fitness; return True when training should stop."""
        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        stale = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = stale >= (self.patience - 1)
        should_stop = stale >= self.patience
        if should_stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return should_stop
|
|
||||||
|
|
||||||
|
|
||||||
class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA as a frozen FP32 copy of the (de-parallelized) model
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay ramps from 0 toward `decay` as updates grow, to help early epochs
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)  # EMA weights are never trained directly

    def update(self, model):
        # Update EMA parameters in place: ema = d * ema + (1 - d) * model
        self.updates += 1
        d = self.decay(self.updates)

        msd = de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes (plain attributes, not weights; see copy_attr)
        copy_attr(self.ema, model, include, exclude)
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
import os
|
|
||||||
import oss2
|
|
||||||
|
|
||||||
def get_model(name):
    """Return the local path of model file `name`, downloading it from OSS on first use.

    Downloads to `<this dir>/model/<name>` via a temporary file that is atomically
    renamed into place, so readers never see a partial file.

    Raises:
        FileNotFoundError: if the object does not exist in the bucket.
    """
    model_dir = os.path.join(os.path.dirname(__file__), "model")
    os.makedirs(model_dir, exist_ok=True)
    fn = os.path.join(model_dir, name)
    if os.path.exists(fn):
        return fn  # already cached from a previous run
    tmp = fn + ".tmp"
    try:
        data = oss_get(name, "emblem-models")
        if data is None:
            # BUG FIX: previously tf.write(None) raised TypeError and left a stale .tmp file
            raise FileNotFoundError(f"model {name!r} not found in bucket 'emblem-models'")
        print(f"Downloading {name}...")
        with open(tmp, 'wb') as tf:
            tf.write(data)
        print(f"Downloaded {name}")
        os.rename(tmp, fn)  # atomic publish of the finished download
    except Exception:
        # never leave a partial .tmp behind on failure
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    return fn
|
|
||||||
|
|
||||||
def oss_get(name, bucket=None):
    """Fetch object `name` from the given OSS bucket; return its bytes, or None if the key is missing."""
    try:
        obj = oss_bucket(bucket).get_object(name)
        return obj.read()
    except oss2.exceptions.NoSuchKey:
        return None
|
|
||||||
|
|
||||||
def oss_bucket(bucketname):
    """Build an oss2.Bucket client for `bucketname`.

    SECURITY FIX: the OSS access key id/secret were hard-coded in this file — that
    credential is leaked and must be rotated. They are now read from the
    OSS_ACCESS_KEY_ID / OSS_ACCESS_KEY_SECRET environment variables; a KeyError
    here means the environment is not configured.
    """
    key_id = os.environ["OSS_ACCESS_KEY_ID"]
    key_secret = os.environ["OSS_ACCESS_KEY_SECRET"]
    auth = oss2.Auth(key_id, key_secret)
    bucket = oss2.Bucket(auth, 'oss-rg-china-mainland.aliyuncs.com', bucketname)
    return bucket
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # CLI smoke test: fetch (or locate) the super-resolution prototxt and print its local path
    print(get_model('sr.prototxt'))
|
|
||||||
@ -346,10 +346,10 @@ Page({
|
|||||||
});
|
});
|
||||||
const ctx = wx.createCameraContext();
|
const ctx = wx.createCameraContext();
|
||||||
this.log(`camera set initial zoom to ${initial_zoom}x, will zoom in to ${rule.zoom}x when qr is found`);
|
this.log(`camera set initial zoom to ${initial_zoom}x, will zoom in to ${rule.zoom}x when qr is found`);
|
||||||
ctx.setZoom({ zoom: initial_zoom });
|
ctx.setZoom({ zoom: 2 });
|
||||||
this.on_qr_found = () => {
|
this.on_qr_found = () => {
|
||||||
this.log(`qr found, zoom to ${rule.zoom}x`);
|
this.log(`qr found, zoom to ${rule.zoom}x`);
|
||||||
ctx.setZoom({ zoom: rule.zoom });
|
ctx.setZoom({ zoom: 6 });
|
||||||
this.setData({
|
this.setData({
|
||||||
zoom: rule.zoom,
|
zoom: rule.zoom,
|
||||||
qrmarkers_class: "",
|
qrmarkers_class: "",
|
||||||
|
|||||||
@ -36,9 +36,8 @@ def main():
|
|||||||
'-t', '0',
|
'-t', '0',
|
||||||
'emblemapi.wsgi:application'
|
'emblemapi.wsgi:application'
|
||||||
] + gunicorn_args, cwd=os.path.join(BASE_DIR, 'api'))
|
] + gunicorn_args, cwd=os.path.join(BASE_DIR, 'api'))
|
||||||
detection = subprocess.Popen(["python3", "app.py"], cwd=os.path.join(BASE_DIR, "detection"))
|
|
||||||
|
|
||||||
procs = [nginx, gunicorn, detection]
|
procs = [nginx, gunicorn]
|
||||||
atexit.register(lambda: [x.kill() for x in procs])
|
atexit.register(lambda: [x.kill() for x in procs])
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@ -11,5 +11,4 @@ pip3 install -r requirements.txt
|
|||||||
|
|
||||||
tmux new-window -n serve "cd api && ./manage.py runserver; $SHELL -i"
|
tmux new-window -n serve "cd api && ./manage.py runserver; $SHELL -i"
|
||||||
tmux split-window "cd web && npm run serve; $SHELL -i"
|
tmux split-window "cd web && npm run serve; $SHELL -i"
|
||||||
tmux split-window "cd detection && python3 app.py; $SHELL -i"
|
|
||||||
$SHELL -i
|
$SHELL -i
|
||||||
|
|||||||