improve ci

This commit is contained in:
Fam Zheng 2024-12-22 10:46:58 +00:00
parent 8703f99c0d
commit ccf0fe5aba
4 changed files with 55 additions and 17 deletions

View File

@ -1,4 +1,5 @@
stages: stages:
- download-models
- test-and-build - test-and-build
- build-docker - build-docker
- deploy - deploy
@ -8,6 +9,17 @@ cache:
paths: paths:
- opencv/ - opencv/
- emtest/target - emtest/target
download-models:
stage: download-models
script:
- make download-models
artifacts:
paths:
- detection/model
cache:
key: models
paths:
- detection/model - detection/model
test: test:
@ -21,6 +33,8 @@ test:
- make opencv -j$(nproc --ignore=2) - make opencv -j$(nproc --ignore=2)
- make -C alg qrtool -j$(nproc --ignore=2) - make -C alg qrtool -j$(nproc --ignore=2)
- make test - make test
dependencies:
- download-models
build-alg: build-alg:
stage: test-and-build stage: test-and-build
@ -56,6 +70,7 @@ build-docker:
dependencies: dependencies:
- build-web - build-web
- build-alg - build-alg
- download-models
except: except:
- main - main

View File

@ -5,6 +5,7 @@ RUN pip3 install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0
ADD requirements.txt requirements.txt ADD requirements.txt requirements.txt
RUN pip3 install -r requirements.txt RUN pip3 install -r requirements.txt
ADD detection /emblem/detection ADD detection /emblem/detection
ADD models /emblem/detection/models
ADD alg /emblem/alg ADD alg /emblem/alg
ADD api /emblem/api ADD api /emblem/api
ADD web /emblem/web ADD web /emblem/web

View File

@ -61,7 +61,7 @@ ALG_FILES = \
) \ ) \
) )
docker-build: build/Dockerfile build/packages.txt build/requirements.txt \ docker-build: build/Dockerfile build/packages.txt build/requirements.txt download-models \
build/nginx.conf $(WEB_FILES) $(API_FILES) $(ALG_FILES) $(DETECTION_FILES) $(SCRIPTS_FILES) $(DATASET_FILES) build/nginx.conf $(WEB_FILES) $(API_FILES) $(ALG_FILES) $(DETECTION_FILES) $(SCRIPTS_FILES) $(DATASET_FILES)
find build find build
docker build --network=host -t $(IMAGE) build docker build --network=host -t $(IMAGE) build
@ -147,3 +147,7 @@ opencv.js: opencv/src/LICENSE FORCE
alg/qrtool: alg/qrtool:
make -C alg qrtool make -C alg qrtool
download-models: FORCE
cd detection && python3 app.py --download-models-only

View File

@ -15,20 +15,29 @@ from qr_verify_Tool.affinity import roi_affinity_siml
from qr_verify_Tool.roi_img_process import * from qr_verify_Tool.roi_img_process import *
from qr_verify_Tool.models import * from qr_verify_Tool.models import *
from utils import get_model from utils import get_model
from thirdTool import qrsrgan
app = Flask(__name__) app = Flask(__name__)
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from matplotlib.ticker import NullLocator
class Detection(object):
def get_models(self):
self.qr_cloud_detect_model = get_model('qr_cloud_detect_20230928.pt')
self.qr_roi_cloud_detect_model = get_model('qr_roi_cloud_detect_20230928.pt')
self.qr_net_g_2000_model = get_model('net_g_2000.pth')
self.roi_net_g_model = get_model("roi_net_g_20240306.pth")
def init_models(self):
device = os.environ.get("CUDA_DEVICE", "cpu") device = os.environ.get("CUDA_DEVICE", "cpu")
self.qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=self.qr_cloud_detect_model, device=device)
self.qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=self.qr_roi_cloud_detect_model, device=device)
qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_cloud_detect_20230928.pt'), device=device) self.dot_realesrgan = qrsrgan.RealsrGan(model_path=self.qr_net_g_2000_model, device=device)
qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_roi_cloud_detect_20230928.pt'), device=device) self.roi_generator = GeneratorUNet()
self.roi_generator.load_state_dict(torch.load(self.roi_net_g_model, map_location=torch.device('cpu')))
from thirdTool import qrsrgan detection = Detection()
dot_realesrgan = qrsrgan.RealsrGan(model_path=get_model('net_g_2000.pth'), device=device)
roi_generator = GeneratorUNet()
roi_generator.load_state_dict(torch.load(get_model("roi_net_g_20240306.pth"), map_location=torch.device('cpu')))
@app.route('/upload', methods=['POST', 'GET']) @app.route('/upload', methods=['POST', 'GET'])
def upload(): def upload():
@ -64,10 +73,10 @@ def qr_roi_cloud_comparison():
std_roi_im_fix = roi_img_processing(std_roi_img) std_roi_im_fix = roi_img_processing(std_roi_img)
#提取特征提取区域 #提取特征提取区域
ter_roi_points, qr_roi_conf = qr_roi_detect_now(ter_img) ter_roi_points, qr_roi_conf = detection.qr_roi_detect_now(ter_img)
ter_roi_img = cropped_image(ter_img,ter_roi_points,shift=-3,size=1) ter_roi_img = cropped_image(ter_img,ter_roi_points,shift=-3,size=1)
# 图像还原计算 # 图像还原计算
ter_roi_im_re = roi_generator(ter_roi_img) ter_roi_im_re = detection.roi_generator(ter_roi_img)
# roi特征区域处理 # roi特征区域处理
ter_roi_im_re_fix = roi_img_processing(ter_roi_im_re) ter_roi_im_re_fix = roi_img_processing(ter_roi_im_re)
@ -97,7 +106,7 @@ def dot_detection_top():
threshold = request.form.get('threshold') threshold = request.form.get('threshold')
angle = request.form.get('angle') angle = request.form.get('angle')
img = Image.open(image_file) img = Image.open(image_file)
points, conf = qr_box_detect_now(img) points, conf = detection.qr_box_detect_now(img)
def angle_ok(a): def angle_ok(a):
if a is None: if a is None:
@ -142,7 +151,7 @@ def dot_detection_top():
return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"} return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}
#采用模型生成实现图像增强,进行网线角度检测 #采用模型生成实现图像增强,进行网线角度检测
dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, dot_realesrgan, save_dots=False, dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, detection.dot_realesrgan, save_dots=False,
res_code=None) res_code=None)
if dl_sr_lines_arctan is not None: if dl_sr_lines_arctan is not None:
if angle_ok(dl_sr_lines_arctan): if angle_ok(dl_sr_lines_arctan):
@ -166,7 +175,7 @@ def dot_detection_bottom():
threshold = request.form.get('threshold') threshold = request.form.get('threshold')
angle = request.form.get('angle') angle = request.form.get('angle')
img = Image.open(image_file) img = Image.open(image_file)
points, conf = qr_box_detect_now(img) points, conf = detection.qr_box_detect_now(img)
def angle_ok(a): def angle_ok(a):
if a is None: if a is None:
@ -218,7 +227,7 @@ def dot_detection_bottom():
return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"} return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}
#采用模型生成实现图像增强,进行网线角度检测 #采用模型生成实现图像增强,进行网线角度检测
dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, dot_realesrgan, save_dots=False, dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, detection.dot_realesrgan, save_dots=False,
res_code=None) res_code=None)
if dl_sr_lines_arctan is not None: if dl_sr_lines_arctan is not None:
if angle_ok(dl_sr_lines_arctan): if angle_ok(dl_sr_lines_arctan):
@ -241,9 +250,18 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", "-l", default="0.0.0.0") parser.add_argument("--host", "-l", default="0.0.0.0")
parser.add_argument("--port", "-p", type=int, default=6006) parser.add_argument("--port", "-p", type=int, default=6006)
parser.add_argument("--download-models-only", action="store_true")
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': def main():
args = parse_args() args = parse_args()
print("starting detection server on %s:%d" % (args.host, args.port)) detection.get_models()
if args.download_models_only:
print("Models loaded successfully")
return
detection.init_models()
print(f"Starting server at {args.host}:{args.port}")
app.run(host=args.host, port=args.port, debug=True) app.run(host=args.host, port=args.port, debug=True)
if __name__ == '__main__':
main()