From ccf0fe5abacb82fbeeede3cd2e1fbb4252af889b Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Sun, 22 Dec 2024 10:46:58 +0000 Subject: [PATCH] improve ci --- .gitlab-ci.yml | 15 +++++++++++++++ Dockerfile | 1 + Makefile | 6 +++++- detection/app.py | 50 ++++++++++++++++++++++++++++++++---------------- 4 files changed, 55 insertions(+), 17 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 17c621a..35bfc98 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,5 @@ stages: + - download-models - test-and-build - build-docker - deploy @@ -8,6 +9,17 @@ cache: paths: - opencv/ - emtest/target + +download-models: + stage: download-models + script: + - make download-models + artifacts: + paths: + - detection/model + cache: + key: models + paths: - detection/model test: @@ -21,6 +33,8 @@ test: - make opencv -j$(nproc --ignore=2) - make -C alg qrtool -j$(nproc --ignore=2) - make test + dependencies: + - download-models build-alg: stage: test-and-build @@ -56,6 +70,7 @@ build-docker: dependencies: - build-web - build-alg + - download-models except: - main diff --git a/Dockerfile b/Dockerfile index ccc1ffb..4223fb3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,7 @@ RUN pip3 install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 ADD requirements.txt requirements.txt RUN pip3 install -r requirements.txt ADD detection /emblem/detection +ADD models /emblem/detection/models ADD alg /emblem/alg ADD api /emblem/api ADD web /emblem/web diff --git a/Makefile b/Makefile index 9d15e17..6ea8907 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ ALG_FILES = \ ) \ ) -docker-build: build/Dockerfile build/packages.txt build/requirements.txt \ +docker-build: build/Dockerfile build/packages.txt build/requirements.txt download-models \ build/nginx.conf $(WEB_FILES) $(API_FILES) $(ALG_FILES) $(DETECTION_FILES) $(SCRIPTS_FILES) $(DATASET_FILES) find build docker build --network=host -t $(IMAGE) build @@ -147,3 +147,7 @@ opencv.js: opencv/src/LICENSE FORCE alg/qrtool: make -C alg qrtool + + +download-models: FORCE + cd detection && python3 app.py --download-models-only \ No newline at end of file diff --git a/detection/app.py b/detection/app.py index 5bcc5c2..7fd39ce 100755 --- a/detection/app.py +++ b/detection/app.py @@ -15,20 +15,29 @@ from qr_verify_Tool.affinity import roi_affinity_siml from qr_verify_Tool.roi_img_process import * from qr_verify_Tool.models import * from utils import get_model +from thirdTool import qrsrgan app = Flask(__name__) # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # os.environ["CUDA_VISIBLE_DEVICES"] = "0" -from matplotlib.ticker import NullLocator -device = os.environ.get("CUDA_DEVICE", "cpu") -qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_cloud_detect_20230928.pt'), device=device) -qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_roi_cloud_detect_20230928.pt'), device=device) +class Detection(object): + def get_models(self): + self.qr_cloud_detect_model = get_model('qr_cloud_detect_20230928.pt') + self.qr_roi_cloud_detect_model = get_model('qr_roi_cloud_detect_20230928.pt') + self.qr_net_g_2000_model = get_model('net_g_2000.pth') + self.roi_net_g_model = get_model("roi_net_g_20240306.pth") -from thirdTool import qrsrgan -dot_realesrgan = qrsrgan.RealsrGan(model_path=get_model('net_g_2000.pth'), device=device) -roi_generator = GeneratorUNet() -roi_generator.load_state_dict(torch.load(get_model("roi_net_g_20240306.pth"), map_location=torch.device('cpu'))) + def init_models(self): + device = os.environ.get("CUDA_DEVICE", "cpu") + self.qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=self.qr_cloud_detect_model, device=device) + self.qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=self.qr_roi_cloud_detect_model, device=device) + + self.dot_realesrgan = qrsrgan.RealsrGan(model_path=self.qr_net_g_2000_model, device=device) + self.roi_generator = GeneratorUNet() + self.roi_generator.load_state_dict(torch.load(self.roi_net_g_model, map_location=torch.device('cpu'))) + +detection = Detection() @app.route('/upload', methods=['POST', 'GET']) def upload(): @@ -64,10 +73,10 @@ def qr_roi_cloud_comparison(): std_roi_im_fix = roi_img_processing(std_roi_img) #提取特征提取区域 - ter_roi_points, qr_roi_conf = qr_roi_detect_now(ter_img) + ter_roi_points, qr_roi_conf = detection.qr_roi_detect_now(ter_img) ter_roi_img = cropped_image(ter_img,ter_roi_points,shift=-3,size=1) # 图像还原计算 - ter_roi_im_re = roi_generator(ter_roi_img) + ter_roi_im_re = detection.roi_generator(ter_roi_img) # roi特征区域处理 ter_roi_im_re_fix = roi_img_processing(ter_roi_im_re) @@ -97,7 +106,7 @@ def dot_detection_top(): threshold = request.form.get('threshold') angle = request.form.get('angle') img = Image.open(image_file) - points, conf = qr_box_detect_now(img) + points, conf = detection.qr_box_detect_now(img) def angle_ok(a): if a is None: @@ -142,7 +151,7 @@ def dot_detection_top(): return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"} #采用模型生成实现图像增强,进行网线角度检测 - dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, dot_realesrgan, save_dots=False, + dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, detection.dot_realesrgan, save_dots=False, res_code=None) if dl_sr_lines_arctan is not None: if angle_ok(dl_sr_lines_arctan): @@ -166,7 +175,7 @@ def dot_detection_bottom(): threshold = request.form.get('threshold') angle = request.form.get('angle') img = Image.open(image_file) - points, conf = qr_box_detect_now(img) + points, conf = detection.qr_box_detect_now(img) def angle_ok(a): if a is None: @@ -218,7 +227,7 @@ def dot_detection_bottom(): return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"} #采用模型生成实现图像增强,进行网线角度检测 - dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, dot_realesrgan, save_dots=False, + dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, detection.dot_realesrgan, save_dots=False, res_code=None) if dl_sr_lines_arctan is not None: if angle_ok(dl_sr_lines_arctan): @@ -241,9 +250,18 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", "-l", default="0.0.0.0") parser.add_argument("--port", "-p", type=int, default=6006) + parser.add_argument("--download-models-only", action="store_true") return parser.parse_args() -if __name__ == '__main__': +def main(): args = parse_args() - print("starting detection server on %s:%d" % (args.host, args.port)) + detection.get_models() + if args.download_models_only: + print("Models loaded successfully") + return + detection.init_models() + print(f"Starting server at {args.host}:{args.port}") app.run(host=args.host, port=args.port, debug=True) + +if __name__ == '__main__': + main() \ No newline at end of file