328 lines
11 KiB
Python
Executable File
328 lines
11 KiB
Python
Executable File
#! /usr/bin/env python3
|
|
import flask
|
|
import oss2
|
|
import os
|
|
from PIL import Image
|
|
import re
|
|
from common import *
|
|
import io
|
|
import time
|
|
from ossclient import *
|
|
import argparse
|
|
import hashlib
|
|
|
|
'''
|
|
Emblem infer service.
|
|
|
|
This provides a simple HTTP API for torchvision model inference.
|
|
|
|
The model is downloaded from a predefined Aliyun OSS bucket.
|
|
|
|
'''
|
|
|
|
# The Flask application; every route handler below registers on it.
app = flask.Flask(__name__)

# Model files fetched from OSS are cached in this directory, which is
# created at import time.
local_model_dir = 'models'
os.makedirs(local_model_dir, exist_ok=True)

# Root of the local dataset and the directory holding one folder per scan.
data_dir = 'data'
scans_dir = os.path.join(data_dir, 'scans')
|
|
|
|
def get_file_md5(fname, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is hashed in *chunk_size*-byte pieces so large model files are
    never held in memory at once, and the handle is closed deterministically.
    The digest is reported to clients as ``model_md5`` to identify which model
    file served a request.

    Args:
        fname: path of the file to hash.
        chunk_size: number of bytes read per iteration.

    Returns:
        The 32-character lowercase hex digest.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
def download_model(model_name):
    """Ensure *model_name* is cached under ``local_model_dir``; return its path and MD5.

    When the file is not present locally it is fetched from the
    ``emblem-models`` OSS bucket. The download is written to a temporary file
    in the destination directory and atomically renamed into place. A failed
    or interrupted download therefore never leaves a truncated file that
    later requests would treat as a valid cached model.

    Args:
        model_name: object key of the model in the bucket, also used as the
            local file name. It comes from the request form in
            ``do_qr_verify``, so it is validated to stay inside
            ``local_model_dir``.

    Returns:
        ``(model_path, md5_hex)`` of the local copy.

    Raises:
        ValueError: if *model_name* would resolve outside ``local_model_dir``.
    """
    import tempfile

    model_path = os.path.join(local_model_dir, model_name)
    # model_name is client-controlled: reject '..' segments and absolute paths
    # that would read or write files outside the model cache.
    base_dir = os.path.realpath(local_model_dir)
    resolved = os.path.realpath(model_path)
    if resolved == base_dir or os.path.commonpath([base_dir, resolved]) != base_dir:
        raise ValueError(f'Invalid model name: {model_name!r}')

    if not os.path.exists(model_path):
        obj = oss_get_object('https://oss-rg-china-mainland.aliyuncs.com', 'emblem-models', model_name)
        payload = obj.read()
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(model_path), suffix='.part')
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(payload)
            # Atomic on POSIX: readers see either no file or the complete one.
            os.replace(tmp_path, model_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    return model_path, get_file_md5(model_path)
|
|
|
|
# Directory where uploaded client reports are stored and served back from.
report_dir = '/tmp/emblem-reports/'

@app.route('/api/v5/report', methods=['POST'])
def report():
    """Store every uploaded file of the request in ``report_dir``.

    Only the final path component of each client-supplied filename is used,
    so names such as ``../../etc/x`` or ``/etc/x`` cannot write outside
    ``report_dir``. All names are validated before anything is written.
    Every file is stored, including several sent under the same form field.
    ``request.files.values()`` would yield only the first of those.

    Returns:
        ``{"ok": True}`` on success, or ``({"ok": False, "error": ...}, 400)``
        when an uploaded file has no usable filename.
    """
    os.makedirs(report_dir, exist_ok=True)
    uploads = []
    for _, upload in flask.request.files.items(multi=True):
        name = os.path.basename(upload.filename or '')
        if name in ('', '.', '..'):
            return {
                "ok": False,
                "error": f"Invalid report filename: {upload.filename!r}",
            }, 400
        uploads.append((name, upload))
    for name, upload in uploads:
        upload.save(os.path.join(report_dir, name))
    return {
        "ok": True,
    }
|
|
|
|
@app.route('/api/v5/report/<name>', methods=['GET'])
def report_file(name):
    """Serve a previously uploaded report file from ``report_dir``.

    ``send_from_directory`` refuses names that would resolve outside
    ``report_dir``. It responds with 404 when the file does not exist,
    rather than letting a missing path surface as a 500 error.
    """
    return flask.send_from_directory(report_dir, name)
|
|
|
|
@app.route('/api/v5/qr_verify', methods=['POST'])
def qr_verify():
    """HTTP entry point for QR verification; the work is done in ``do_qr_verify``.

    Any exception is reported to the client as ``{"ok": False, "error": ...}``
    with HTTP 200. It is also logged with its traceback so failures stay
    diagnosable on the server instead of disappearing into the response body.
    """
    try:
        return do_qr_verify()
    except Exception as e:
        app.logger.exception('qr_verify failed')
        return {
            "ok": False,
            "error": str(e),
        }
|
|
|
|
def find_best_frame(files):
    """Pick the uploaded frame whose QR region is the sharpest.

    Args:
        files: mapping of uploaded files (``flask.request.files`` at the only
            call site). Each value is opened with ``PIL.Image.open`` and must
            expose ``.filename``.

    Returns:
        ``(best_frame, clarities)``.
        ``best_frame`` is the image with the highest ``calc_clarity`` score
        among frames in which ``extract_qr`` found a QR code, or ``None`` if
        no frame had one. ``clarities`` maps each such frame's filename to
        its score.
    """
    best_frame = None
    best_clarity = 0
    clarities = {}
    for upload in files.values():
        img = Image.open(upload)
        qrcode, qr_region = extract_qr(img)
        if not qrcode:
            continue
        clarity = calc_clarity(qr_region)
        clarities[upload.filename] = clarity
        # Explicit None test: the first QR-bearing frame always becomes the
        # candidate, whatever its score; later frames must beat it strictly.
        if best_frame is None or clarity > best_clarity:
            best_frame = img
            best_clarity = clarity
    return best_frame, clarities
|
|
|
|
def do_qr_verify():
    """Verify uploaded frames against the reference image for their QR code.

    Reads the optional ``model`` form field (default ``default_model``) and the
    uploaded frame files from the current request. It then:

    1. fetches and loads the model;
    2. picks the sharpest QR-bearing frame;
    3. fetches the standard image for the decoded code;
    4. checks that both images decode to the same code;
    5. runs ``verify_frame``.

    Returns:
        ``{"ok": True, "result": {...}}`` on success, or
        ``{"ok": False, "error": ...}`` for an expected failure (no QR code in
        any frame, no standard image, or a QR code mismatch).
    """
    start_time = time.time()
    fd = flask.request.form
    model_path, model_md5 = download_model(fd.get('model', default_model))
    model, transforms = load_model(model_path)

    frame_image, clarities = find_best_frame(flask.request.files)
    if frame_image is None:
        # Without this check extract_qr(None) would fail with an unrelated error.
        return {
            "ok": False,
            "error": "No QR code found in frames",
        }
    frame_qrcode, _ = extract_qr(frame_image)

    # Test the raw bytes: Image.open() either returns an image object or
    # raises, so a check on the opened image can never detect a missing one.
    std_bytes = get_qr_image_bytes(frame_qrcode)
    if not std_bytes:
        return {
            "ok": False,
            "error": f"No std image: {frame_qrcode}",
        }
    std_image = Image.open(io.BytesIO(std_bytes))

    std_qrcode, _ = extract_qr(std_image)
    if frame_qrcode != std_qrcode:
        return {
            "ok": False,
            "error": "QR code mismatch",
        }

    predicted_class, probabilities = verify_frame(model, transforms, frame_image, std_image)
    return {
        "ok": True,
        "result": {
            "process_time": time.time() - start_time,
            "predicted_class": predicted_class,
            # probabilities is iterated as (key, value) pairs; only values are shown.
            "probabilities": ', '.join(f'{p:.2%}' for _, p in probabilities),
            "clarities": clarities,
            "model_md5": model_md5,
        },
    }
|
|
|
|
def make_qrsbs(frame_qrcode, frame_image):
    """Render *frame_image* side by side with the reference QR image.

    Args:
        frame_qrcode: QR code value decoded from the frame.
        frame_image: the frame as a PIL image.

    Returns:
        JPEG-encoded bytes of the composite, or ``None`` when
        ``get_qr_image`` yields no reference image for *frame_qrcode*.
    """
    qr_img = get_qr_image(frame_qrcode)
    if not qr_img:
        return None
    composite = make_side_by_side_img(frame_image, qr_img)
    # Encode in memory instead of via a temporary file, so nothing is left
    # behind on disk for each request.
    buf = io.BytesIO()
    composite.save(buf, format='JPEG')
    return buf.getvalue()
|
|
|
|
@app.route('/api/v5/qrsbs', methods=['POST'])
def qrsb():
    """Return a JPEG of the uploaded ``frame`` beside its reference QR image.

    Composites are cached in the ``emblem-cache`` OSS bucket under
    ``qrsbs_v1/<qrcode>.jpg``. Any failure to fetch the cached object is
    treated as a cache miss, and the composite is rebuilt and stored.

    Returns:
        The JPEG response. The JSON error ``{"ok": False, ...}`` is returned
        when no ``frame`` file was uploaded. Aborts with 404 when the frame
        has no decodable QR code or no reference image exists for it.
    """
    # .get() so a missing field reaches the explicit check below instead of
    # raising a KeyError-based 400 error.
    frame = flask.request.files.get('frame')
    if not frame:
        return {
            "ok": False,
            "error": "frame is required",
        }
    frame_image = Image.open(frame)
    frame_qrcode, _ = extract_qr(frame_image)
    if not frame_qrcode:
        # Avoid building a meaningless cache key such as 'qrsbs_v1/None.jpg'.
        flask.abort(404, 'QR code not found')
    cache_key = f'qrsbs_v1/{frame_qrcode}.jpg'
    try:
        cached = oss_get_object('https://oss-cn-guangzhou.aliyuncs.com', 'emblem-cache', cache_key)
        # download_model() consumes oss_get_object's result with .read();
        # send_file below needs raw bytes, so read a stream if one was returned.
        qrsbs = cached.read() if hasattr(cached, 'read') else cached
    except Exception:
        print(f'cache miss: {cache_key}')
        qrsbs = make_qrsbs(frame_qrcode, frame_image)
        if not qrsbs:
            flask.abort(404, 'QR code not found')
        oss_put_object('https://oss-cn-guangzhou.aliyuncs.com', 'emblem-cache', cache_key, qrsbs)
    return flask.send_file(io.BytesIO(qrsbs), mimetype='image/jpeg')
|
|
|
|
@app.route('/api/v5/infer/version', methods=['GET'])
def version():
    """Report the version string of this inference service as JSON."""
    service_version = '1.0.0'
    return dict(version=service_version)
|
|
|
|
@app.route('/api/v5/frame/<session_id>/<frame_id>', methods=['POST'])
def frame(session_id, frame_id):
    """Upload the raw request body to the production frames bucket.

    The body is stored in ``emblem-frames-prod`` (Shenzhen endpoint) under the
    key ``v5/<session_id>/<frame_id>``, using the ``oss_ak``/``oss_sk``
    credentials.
    """
    payload = flask.request.get_data()
    bucket = oss2.Bucket(
        oss2.Auth(oss_ak, oss_sk),
        'https://oss-cn-shenzhen.aliyuncs.com',
        'emblem-frames-prod',
    )
    object_key = f'v5/{session_id}/{frame_id}'
    bucket.put_object(object_key, payload)
    return dict(ok=True)
|
|
|
|
def make_data_url(fname):
    """Return the contents of *fname* as a base64 ``data:`` URL.

    The MIME type is hard-coded to ``image/jpeg``, so presumably only JPEG
    files are passed in (not checked here).
    """
    with open(fname, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode()
    return f'data:image/jpeg;base64,{encoded}'
|
|
|
|
def bottomcorner(img):
    """Crop the bottom-right corner of *img*.

    The returned region covers the last ``width // 4`` columns and the last
    ``height // 4`` rows of the image.
    """
    w, h = img.width, img.height
    left = w - w // 4
    top = h - h // 4
    box = (left, top, w, h)
    return img.crop(box)
|
|
|
|
def prepare_clarities(scan_dir):
    """Compare the sharpness of a scan's frame QR image with its standard QR image.

    Clarity is ``calc_clarity`` applied to the bottom-right corner (see
    ``bottomcorner``) of ``frame-qr.jpg`` and ``std-qr.jpg`` in *scan_dir*,
    truncated to ``int``.

    Returns:
        dict with ``frame`` and ``std`` clarities and ``relative``, the frame
        clarity as an integer percentage of the standard one. ``relative`` is
        ``None`` when the standard clarity truncates to 0. Without that guard
        the division would raise and fail the whole scan listing.
    """
    # Close each file once its clarity is computed; calc_clarity runs inside
    # the with-block so the pixel data is still available.
    with Image.open(os.path.join(scan_dir, 'frame-qr.jpg')) as frame_img:
        frame_clarity = int(calc_clarity(bottomcorner(frame_img)))
    with Image.open(os.path.join(scan_dir, 'std-qr.jpg')) as std_img:
        std_clarity = int(calc_clarity(bottomcorner(std_img)))
    relative = int(frame_clarity / std_clarity * 100) if std_clarity else None
    return {
        'relative': relative,
        'frame': frame_clarity,
        'std': std_clarity,
    }
|
|
|
|
def prepare_scan(scan, predicted_classes=None):
    """Build the JSON summary of one scan directory under ``scans_dir``.

    Args:
        scan: directory name of the scan.
        predicted_classes: optional mapping of scan name to predicted class
            (loaded from a verify-result file). Scans without an entry get
            ``None``. Defaults to an empty mapping; ``None`` is used as the
            default to avoid a shared mutable default argument.

    Returns:
        ``None`` if the scan directory does not exist. Otherwise a dict with
        the directory listing, the comma-separated ``labels`` from
        ``metadata.json`` (if present) as a list, the predicted class,
        clarities, and the pixel sizes of ``frame-qr.jpg`` and ``std-qr.jpg``.
    """
    if predicted_classes is None:
        predicted_classes = {}
    scan_dir = os.path.join(scans_dir, scan)
    if not os.path.exists(scan_dir):
        return None
    files = os.listdir(scan_dir)
    mdfile = os.path.join(scan_dir, 'metadata.json')
    if os.path.exists(mdfile):
        with open(mdfile, 'r') as f:
            md = json.load(f)
    else:
        md = {}
    # Only the image headers are needed for the sizes; close the files promptly.
    with Image.open(os.path.join(scan_dir, 'frame-qr.jpg')) as frame_qr:
        frame_qr_size = [frame_qr.width, frame_qr.height]
    with Image.open(os.path.join(scan_dir, 'std-qr.jpg')) as std_qr:
        std_qr_size = [std_qr.width, std_qr.height]
    # Drop empty entries so an unlabeled scan reports [] rather than [''].
    labels = [label for label in (md.get('labels') or '').split(',') if label]
    return {
        'name': scan,
        'files': files,
        'labels': labels,
        'predicted_class': predicted_classes.get(scan),
        'clarities': prepare_clarities(scan_dir),
        'frame_qr_size': frame_qr_size,
        'std_qr_size': std_qr_size,
    }
|
|
|
|
def match_one_term(md, term, predicted_class):
    """Return whether scan metadata *md* satisfies a single filter *term*.

    Supported terms:

    * ``mispredicted`` / ``incorrect``: the scan is labelled, and the
      predicted class (``'pos'`` if truthy, else ``'neg'``) does not occur in
      the labels.
    * ``correct``: the scan is labelled, and the predicted class occurs in
      the labels.
    * ``pos`` / ``neg``: the labels contain that string.
    * ``succeeded`` / ``failed``: ``md['succeeded']`` equals ``True`` /
      ``False``.
    * ``A-B``: the integer value of ``md['id']`` lies in ``[A, B]``.

    Labels are matched as substrings of the raw comma-separated ``labels``
    string. Prediction-based terms never match when *predicted_class* is
    ``None``. Unknown terms, and id ranges on metadata without a numeric
    ``id``, return ``False`` instead of raising.
    """
    labels = md.get('labels') or ''
    if term in ('mispredicted', 'incorrect', 'correct'):
        if predicted_class is None or not labels:
            return False
        predicted_in_labels = ('pos' if predicted_class else 'neg') in labels
        return predicted_in_labels if term == 'correct' else not predicted_in_labels
    if term in ('pos', 'neg'):
        return term in labels
    if term == 'succeeded':
        return md.get('succeeded') == True  # noqa: E712 -- equality, so 1 also counts
    if term == 'failed':
        return md.get('succeeded') == False  # noqa: E712 -- equality, so 0 also counts
    id_range = re.fullmatch(r'([0-9]+)-([0-9]+)', term)
    if id_range:
        try:
            scan_id = int(md.get('id'))
        except (TypeError, ValueError):
            return False
        return int(id_range.group(1)) <= scan_id <= int(id_range.group(2))
    return False
|
|
|
|
def match_query(md, q, predicted_class):
    """Check *md* against every whitespace-separated term of query *q*.

    An empty or ``None`` query matches everything and yields ``True``. Terms
    are evaluated left to right with ``match_one_term``. Evaluation stops at
    the first falsy result, which is returned as-is; if every term matches,
    the last term's result is returned.
    """
    outcome = True
    for term in (q.split() if q else []):
        outcome = match_one_term(md, term, predicted_class)
        if not outcome:
            break
    return outcome
|
|
|
|
def search_scans(q, predicted_classes):
    """Return names of fetched scans that match filter query *q*, in random order.

    A scan is considered only if its directory contains both
    ``fetch-state.json`` and ``std-qr.jpg``. Its ``metadata.json`` is read
    and matched against *q* with ``match_query``, using the scan's entry in
    *predicted_classes*. A missing ``metadata.json`` is treated as empty
    metadata, as in ``prepare_scan``, instead of aborting the whole search.
    """
    candidates = os.listdir(scans_dir)
    # Presumably shuffled so the truncated listing in scans() samples across
    # the dataset rather than always showing the same directories.
    random.shuffle(candidates)
    matches = []
    for scan in candidates:
        scan_dir = os.path.join(scans_dir, scan)
        state_file = os.path.join(scan_dir, 'fetch-state.json')
        std_qr_file = os.path.join(scan_dir, 'std-qr.jpg')
        if not (os.path.exists(state_file) and os.path.exists(std_qr_file)):
            continue
        mdfile = os.path.join(scan_dir, 'metadata.json')
        if os.path.exists(mdfile):
            with open(mdfile, 'r') as f:
                md = json.load(f)
        else:
            md = {}
        if match_query(md, q, predicted_classes.get(scan)):
            matches.append(scan)
    return matches
|
|
|
|
def highlight_nongray(image):
    """Return the HSV saturation channel of *image* as a grayscale PIL image.

    Gray pixels have low saturation, so colored regions appear bright in the
    result.

    Args:
        image: array accepted by ``cv2.cvtColor`` with ``COLOR_BGR2HSV``,
            i.e. a 3-channel BGR image (presumably uint8; TODO confirm with
            callers).

    NOTE(review): despite the name, no saturation threshold is applied. If a
    binary mask is intended, compare the channel against a threshold (e.g.
    ``s_channel > 30``); confirm with callers before changing the output.
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    s_channel = hsv_image[:, :, 1]
    return Image.fromarray(s_channel)
|
|
|
|
@app.route('/api/sbs/<scan>.jpg', methods=['GET'])
def sbs(scan):
    """Serve a side-by-side JPEG of a scan's ``frame.jpg`` and ``std.jpg``.

    The composite is encoded in memory, so no temporary file is left on disk
    per request, and both source files are closed afterwards. An unknown scan
    gets a JSON error with HTTP 404.
    """
    scan_dir = os.path.join(scans_dir, scan)
    if not os.path.exists(scan_dir):
        return {
            "ok": False,
            "error": f"Scan {scan} not found",
        }, 404
    buf = io.BytesIO()
    with Image.open(os.path.join(scan_dir, 'frame.jpg')) as frame_img:
        with Image.open(os.path.join(scan_dir, 'std.jpg')) as std_img:
            composite = make_side_by_side_img_with_margins(frame_img, std_img)
            # Encode while the sources are still open in case the composite
            # lazily references their data.
            composite.save(buf, format='JPEG')
    buf.seek(0)
    return flask.send_file(buf, mimetype='image/jpeg')
|
|
|
|
@app.route('/api/allinone/<scan>.jpg', methods=['GET'])
def allinone(scan):
    """Serve one JPEG stacking the side-by-side view above the full frame.

    The top part is ``make_side_by_side_img_with_margins(frame, std)`` and
    ``frame.jpg`` is pasted below it. The result is encoded in memory, so no
    temporary file is left on disk per request. An unknown scan gets a JSON
    error with HTTP 404.
    """
    scan_dir = os.path.join(scans_dir, scan)
    if not os.path.exists(scan_dir):
        return {
            "ok": False,
            "error": f"Scan {scan} not found",
        }, 404
    with Image.open(os.path.join(scan_dir, 'frame.jpg')) as frame_img:
        with Image.open(os.path.join(scan_dir, 'std.jpg')) as std_img:
            sbs_img = make_side_by_side_img_with_margins(frame_img, std_img)
            combined = Image.new('RGB', (sbs_img.width, sbs_img.height + frame_img.height))
            combined.paste(sbs_img, (0, 0))
            combined.paste(frame_img, (0, sbs_img.height))
    buf = io.BytesIO()
    combined.save(buf, format='JPEG')
    buf.seek(0)
    return flask.send_file(buf, mimetype='image/jpeg')
|
|
|
|
@app.route('/api/scans', methods=['GET'])
def scans():
    """List up to 200 scans matching the ``filter`` query parameter.

    Query parameters:
        filter: space-separated terms understood by ``match_one_term``.
        verify_result: optional name of a JSON file in ``data_dir`` (as listed
            by ``/api/verify_results``) mapping scan names to predicted classes.

    ``verify_result`` must be a bare file name. Values containing a path
    separator, or equal to ``.``/``..``, are rejected with 400, so the request
    cannot read files outside ``data_dir``. A missing file yields 404.
    """
    q = flask.request.args.get('filter', None)
    verify_result = flask.request.args.get('verify_result', None)
    predicted_classes = {}
    if verify_result:
        if os.path.basename(verify_result) != verify_result or verify_result in ('.', '..'):
            flask.abort(400, 'Invalid verify_result')
        result_path = os.path.join(data_dir, verify_result)
        if not os.path.isfile(result_path):
            flask.abort(404, 'verify_result not found')
        with open(result_path, 'r') as f:
            predicted_classes = json.load(f)
    matching = search_scans(q, predicted_classes)[:200]
    return {
        "scans": [prepare_scan(s, predicted_classes) for s in matching],
    }
|
|
|
|
@app.route('/api/verify_results', methods=['GET'])
def verify_results():
    """List the ``.json`` verify-result files in ``data_dir``, sorted by name.

    Sorting makes the response deterministic; ``os.listdir`` order is arbitrary.
    """
    return {
        'results': sorted(name for name in os.listdir(data_dir) if name.endswith('.json')),
    }
|
|
|
|
@app.route('/api/data/<path:path>', methods=['GET'])
def data(path):
    """Serve a file from ``data_dir``.

    ``path`` comes from the URL and may contain ``/``. ``send_from_directory``
    rejects paths that would resolve outside ``data_dir`` (for example via
    ``../``) with 404; joining the path manually and calling ``send_file``
    would not. The directory is made absolute because Flask resolves a
    relative directory against the application root, not the working
    directory.
    """
    return flask.send_from_directory(os.path.abspath(data_dir), path)
|
|
|
|
def parse_args(argv=None):
    """Parse the service's command-line options.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as argparse does.

    Returns:
        ``argparse.Namespace`` with ``port`` (int, default 6500) and
        ``debug`` (bool, default ``False``).
    """
    parser = argparse.ArgumentParser(description='Emblem infer service.')
    parser.add_argument('--port', type=int, default=6500,
                        help='TCP port to listen on (default: %(default)s)')
    parser.add_argument('--debug', action='store_true',
                        help='run the Flask server in debug mode')
    return parser.parse_args(argv)
|
|
|
|
if __name__ == '__main__':
    cli = parse_args()
    # NOTE: listens on all interfaces. With --debug this also exposes the
    # Werkzeug interactive debugger to the network, so use it only on trusted hosts.
    app.run(debug=cli.debug, port=cli.port, host='0.0.0.0')