diff --git a/detection/app.py b/detection/app.py index 8313ca4..71e243c 100755 --- a/detection/app.py +++ b/detection/app.py @@ -14,6 +14,7 @@ import json from qr_verify_Tool.affinity import roi_affinity_siml from qr_verify_Tool.roi_img_process import * from qr_verify_Tool.models import * +from utils import get_model app = Flask(__name__) # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -21,13 +22,13 @@ app = Flask(__name__) from matplotlib.ticker import NullLocator device = os.environ.get("CUDA_DEVICE", "cpu") -qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path='model/qr_cloud_detect_20230928.pt', device=device) -qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path='model/qr_roi_cloud_detect_20230928.pt', device=device) +qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_cloud_detect_20230928.pt'), device=device) +qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_roi_cloud_detect_20230928.pt'), device=device) from thirdTool import qrsrgan -dot_realesrgan = qrsrgan.RealsrGan(model_path='model/net_g_2000.pth', device=device) +dot_realesrgan = qrsrgan.RealsrGan(model_path=get_model('net_g_2000.pth'), device=device) roi_generator = GeneratorUNet() -roi_generator.load_state_dict(torch.load("model/roi_net_g_20240306.pth", map_location=torch.device('cpu'))) +roi_generator.load_state_dict(torch.load(get_model("roi_net_g_20240306.pth"), map_location=torch.device('cpu'))) @app.route('/upload', methods=['POST', 'GET']) def upload(): diff --git a/detection/requirements.txt b/detection/requirements.txt index 5e3a482..a304291 100644 --- a/detection/requirements.txt +++ b/detection/requirements.txt @@ -6,3 +6,4 @@ opencv-python==4.5.5.62 timm==0.9.2 pandas seaborn +oss2 diff --git a/detection/utils.py b/detection/utils.py new file mode 100644 index 0000000..c5c5fa6 --- /dev/null +++ b/detection/utils.py @@ -0,0 +1,25 @@ +import os +import oss2 + +def get_model(name): + fn = os.path.join(os.path.dirname(__file__), "model", name) + if os.path.exists(fn): + return fn + with open(fn + ".tmp", 'wb') as tf: + tf.write(oss_get(name, "emblem-models")) + os.rename(fn + ".tmp", fn) + return fn + +def oss_get(name, bucket=None): + try: + return oss_bucket(bucket).get_object(name).read() + except oss2.exceptions.NoSuchKey: + return None + +def oss_bucket(bucketname): + auth = oss2.Auth('LTAI5tC2qXGxwHZUZP7DoD1A', 'qPo9O6ZvEfqo4t8oflGEm0DoxLHJhm') + bucket = oss2.Bucket(auth, 'oss-rg-china-mainland.aliyuncs.com', bucketname) + return bucket + +if __name__ == "__main__": + print(get_model('sr.prototxt'))