detection: Download model dynamically from oss bucket
This commit is contained in:
parent
8a4f10aeed
commit
cc33eebbe5
@ -14,6 +14,7 @@ import json
|
|||||||
from qr_verify_Tool.affinity import roi_affinity_siml
|
from qr_verify_Tool.affinity import roi_affinity_siml
|
||||||
from qr_verify_Tool.roi_img_process import *
|
from qr_verify_Tool.roi_img_process import *
|
||||||
from qr_verify_Tool.models import *
|
from qr_verify_Tool.models import *
|
||||||
|
from utils import get_model
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
@ -21,13 +22,13 @@ app = Flask(__name__)
|
|||||||
from matplotlib.ticker import NullLocator
|
from matplotlib.ticker import NullLocator
|
||||||
device = os.environ.get("CUDA_DEVICE", "cpu")
|
device = os.environ.get("CUDA_DEVICE", "cpu")
|
||||||
|
|
||||||
qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path='model/qr_cloud_detect_20230928.pt', device=device)
|
qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_cloud_detect_20230928.pt'), device=device)
|
||||||
qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path='model/qr_roi_cloud_detect_20230928.pt', device=device)
|
qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_roi_cloud_detect_20230928.pt'), device=device)
|
||||||
|
|
||||||
from thirdTool import qrsrgan
|
from thirdTool import qrsrgan
|
||||||
dot_realesrgan = qrsrgan.RealsrGan(model_path='model/net_g_2000.pth', device=device)
|
dot_realesrgan = qrsrgan.RealsrGan(model_path=get_model('net_g_2000.pth'), device=device)
|
||||||
roi_generator = GeneratorUNet()
|
roi_generator = GeneratorUNet()
|
||||||
roi_generator.load_state_dict(torch.load("model/roi_net_g_20240306.pth", map_location=torch.device('cpu')))
|
roi_generator.load_state_dict(torch.load(get_model("roi_net_g_20240306.pth"), map_location=torch.device('cpu')))
|
||||||
|
|
||||||
@app.route('/upload', methods=['POST', 'GET'])
|
@app.route('/upload', methods=['POST', 'GET'])
|
||||||
def upload():
|
def upload():
|
||||||
|
|||||||
@ -6,3 +6,4 @@ opencv-python==4.5.5.62
|
|||||||
timm==0.9.2
|
timm==0.9.2
|
||||||
pandas
|
pandas
|
||||||
seaborn
|
seaborn
|
||||||
|
oss2
|
||||||
|
|||||||
25
detection/utils.py
Normal file
25
detection/utils.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import os
|
||||||
|
import oss2
|
||||||
|
|
||||||
|
def get_model(name):
|
||||||
|
fn = os.path.join(os.path.dirname(__file__), "model", name)
|
||||||
|
if os.path.exists(fn):
|
||||||
|
return fn
|
||||||
|
with open(fn + ".tmp", 'wb') as tf:
|
||||||
|
tf.write(oss_get(name, "emblem-models"))
|
||||||
|
os.rename(fn + ".tmp", fn)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def oss_get(name, bucket=None):
|
||||||
|
try:
|
||||||
|
return oss_bucket(bucket).get_object(name).read()
|
||||||
|
except oss2.exceptions.NoSuchKey:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def oss_bucket(bucketname):
|
||||||
|
auth = oss2.Auth('LTAI5tC2qXGxwHZUZP7DoD1A', 'qPo9O6ZvEfqo4t8oflGEm0DoxLHJhm')
|
||||||
|
bucket = oss2.Bucket(auth, 'oss-rg-china-mainland.aliyuncs.com', bucketname)
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(get_model('sr.prototxt'))
|
||||||
Loading…
x
Reference in New Issue
Block a user