detection: Download model dynamically from oss bucket

This commit is contained in:
Fam Zheng 2024-09-01 22:10:16 +01:00
parent 8a4f10aeed
commit cc33eebbe5
3 changed files with 31 additions and 4 deletions

View File

@ -14,6 +14,7 @@ import json
from qr_verify_Tool.affinity import roi_affinity_siml from qr_verify_Tool.affinity import roi_affinity_siml
from qr_verify_Tool.roi_img_process import * from qr_verify_Tool.roi_img_process import *
from qr_verify_Tool.models import * from qr_verify_Tool.models import *
from utils import get_model
app = Flask(__name__) app = Flask(__name__)
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@ -21,13 +22,13 @@ app = Flask(__name__)
from matplotlib.ticker import NullLocator from matplotlib.ticker import NullLocator
device = os.environ.get("CUDA_DEVICE", "cpu") device = os.environ.get("CUDA_DEVICE", "cpu")
qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path='model/qr_cloud_detect_20230928.pt', device=device) qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_cloud_detect_20230928.pt'), device=device)
qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path='model/qr_roi_cloud_detect_20230928.pt', device=device) qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_roi_cloud_detect_20230928.pt'), device=device)
from thirdTool import qrsrgan from thirdTool import qrsrgan
dot_realesrgan = qrsrgan.RealsrGan(model_path='model/net_g_2000.pth', device=device) dot_realesrgan = qrsrgan.RealsrGan(model_path=get_model('net_g_2000.pth'), device=device)
roi_generator = GeneratorUNet() roi_generator = GeneratorUNet()
roi_generator.load_state_dict(torch.load("model/roi_net_g_20240306.pth", map_location=torch.device('cpu'))) roi_generator.load_state_dict(torch.load(get_model("roi_net_g_20240306.pth"), map_location=torch.device('cpu')))
@app.route('/upload', methods=['POST', 'GET']) @app.route('/upload', methods=['POST', 'GET'])
def upload(): def upload():

View File

@ -6,3 +6,4 @@ opencv-python==4.5.5.62
timm==0.9.2 timm==0.9.2
pandas pandas
seaborn seaborn
oss2

25
detection/utils.py Normal file
View File

@ -0,0 +1,25 @@
import os
import oss2
def get_model(name):
fn = os.path.join(os.path.dirname(__file__), "model", name)
if os.path.exists(fn):
return fn
with open(fn + ".tmp", 'wb') as tf:
tf.write(oss_get(name, "emblem-models"))
os.rename(fn + ".tmp", fn)
return fn
def oss_get(name, bucket=None):
try:
return oss_bucket(bucket).get_object(name).read()
except oss2.exceptions.NoSuchKey:
return None
def oss_bucket(bucketname):
auth = oss2.Auth('LTAI5tC2qXGxwHZUZP7DoD1A', 'qPo9O6ZvEfqo4t8oflGEm0DoxLHJhm')
bucket = oss2.Bucket(auth, 'oss-rg-china-mainland.aliyuncs.com', bucketname)
return bucket
if __name__ == "__main__":
print(get_model('sr.prototxt'))