themblem/emblem5/ai/server.py
2026-01-27 22:34:08 +00:00

336 lines
11 KiB
Python
Executable File

#! /usr/bin/env python3
import flask
import oss2
import os
from PIL import Image
import re
from common import *
import io
import time
from ossclient import *
import argparse
import hashlib
'''
Emblem inference service.
This provides a simple HTTP API for torchvision model inference, plus helper
endpoints for uploading reports and frames and for browsing local scans.
Models are downloaded on demand from a predefined Aliyun OSS bucket.
'''
app = flask.Flask(__name__)
# Local directory where model files fetched from OSS are stored by download_model.
local_model_dir = 'models'
# Root of the local data tree (served by /api/data); scans live under scans/.
data_dir = 'data'
# One sub-directory per scan, listed and served by the /api/scans, /api/sbs
# and /api/allinone endpoints.
scans_dir = os.path.join(data_dir, 'scans')
# Created at import time so downloads can be written into it.
os.makedirs(local_model_dir, exist_ok=True)
# Cache for loaded models to avoid reloading on every request.
# Maps model_path -> the (model, transforms) pair returned by load_model.
# NOTE(review): no lock guards this dict; concurrent first requests for the
# same model may each call load_model.
_model_cache = {}
def get_file_md5(fname, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is read in *chunk_size*-byte pieces so large model files are not
    loaded into memory at once, and the file handle is always closed.

    Args:
        fname: path of the file to hash.
        chunk_size: number of bytes read per iteration (default 1 MiB).

    Returns:
        The 32-character lowercase hex digest.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def download_model(model_name):
    """Make sure *model_name* is available locally and fingerprint it.

    If the file is not already present under ``local_model_dir`` it is fetched
    from the ``emblem-models`` OSS bucket. The download is written to a unique
    sibling file and then renamed into place, so a failed or interrupted
    transfer never leaves a truncated file behind.

    Args:
        model_name: object key in the ``emblem-models`` bucket, also used as
            the file name under ``local_model_dir``. It may come from the
            request form, so it must resolve to a path inside that directory.

    Returns:
        ``(model_path, md5_hex)`` for the local copy.
        NOTE(review): the MD5 is recomputed on every call, i.e. once per
        verify request.

    Raises:
        ValueError: if *model_name* would resolve outside ``local_model_dir``.
    """
    import uuid

    model_path = os.path.join(local_model_dir, model_name)
    root = os.path.realpath(local_model_dir)
    resolved = os.path.realpath(model_path)
    # model_name can be client-controlled (the 'model' form field), so refuse
    # names such as '../x' or absolute paths that escape the model directory.
    if resolved == root or os.path.commonpath([root, resolved]) != root:
        raise ValueError(f'invalid model name: {model_name!r}')
    if not os.path.exists(model_path):
        obj = oss_get_object('https://oss-rg-china-mainland.aliyuncs.com', 'emblem-models', model_name)
        # Write to a unique temp name and atomically rename, so the exists()
        # check above can never accept a partially written model.
        tmp_path = f'{model_path}.{uuid.uuid4().hex}.part'
        try:
            with open(tmp_path, 'wb') as f:
                f.write(obj.read())
            os.replace(tmp_path, model_path)
        finally:
            # Only still present if the write or rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    return model_path, get_file_md5(model_path)
report_dir = '/tmp/emblem-reports/'


@app.route('/api/v5/report', methods=['POST'])
def report():
    """Store every uploaded file of the request under ``report_dir``.

    File names are client-supplied, so only their final path component is
    used. A name such as ``../../etc/x`` or ``/etc/x`` therefore cannot write
    outside ``report_dir``. Uploads whose name reduces to nothing usable
    (empty, ``.`` or ``..``) are skipped. Existing files with the same name
    are overwritten.

    Returns:
        ``{"ok": True}``.
    """
    os.makedirs(report_dir, exist_ok=True)
    for file in flask.request.files.values():
        name = os.path.basename(file.filename or '')
        if name in ('', '.', '..'):
            continue
        file.save(os.path.join(report_dir, name))
    return {
        "ok": True,
    }
@app.route('/api/v5/report/<name>', methods=['GET'])
def report_file(name):
    """Serve a previously uploaded report file from ``report_dir``.

    ``flask.send_from_directory`` refuses names that would escape
    ``report_dir`` and answers 404 for missing files instead of raising.
    """
    return flask.send_from_directory(report_dir, name)
@app.route('/api/v5/qr_verify', methods=['POST'])
def qr_verify():
    """HTTP entry point for QR verification.

    Delegates to ``do_qr_verify``. Any exception is turned into an
    ``{"ok": False, "error": <message>}`` body so clients always receive a
    structured reply. The traceback is logged so the failure is not lost.
    """
    try:
        return do_qr_verify()
    except Exception as e:
        app.logger.exception('qr_verify failed')
        return {
            "ok": False,
            "error": str(e),
        }
def find_best_frame(files):
    """Pick the sharpest uploaded frame that contains a detectable QR code.

    Args:
        files: mapping of uploads (e.g. ``flask.request.files``). Each value
            is opened with ``PIL.Image.open`` and must have a ``filename``.

    Returns:
        ``(best_frame, clarities)``. ``best_frame`` is the image whose QR
        region has the highest ``calc_clarity`` score, or ``None`` if no frame
        yielded a QR code. ``clarities`` maps the filename of each QR-bearing
        upload to its score.
    """
    best_frame = None
    best_clarity = 0
    clarities = {}
    for file in files.values():
        img = Image.open(file)
        img_qrcode, img_qr = extract_qr(img)
        if not img_qrcode:
            continue  # no QR code detected in this frame
        clarity = calc_clarity(img_qr)
        clarities[file.filename] = clarity
        # Compare with None explicitly so the first QR-bearing frame is always
        # taken, independent of how an image object behaves in a bool context.
        if best_frame is None or clarity > best_clarity:
            best_frame = img
            best_clarity = clarity
    return best_frame, clarities
def do_qr_verify():
    """Verify the uploaded frames of the current request against the std image.

    Steps:
    1. Read the optional ``model`` form field (default ``default_model``) and
       load that model, cached in ``_model_cache``.
    2. Pick the sharpest frame containing a QR code.
    3. Fetch the reference ("std") image for the decoded code.
    4. Run ``verify_frame`` on the frame/std pair.

    Returns:
        ``{"ok": False, "error": ...}`` when no frame has a QR code, no std
        image is available, or the two codes differ. Otherwise
        ``{"ok": True, "result": {...}}`` with timing, prediction, formatted
        probabilities, per-frame clarities and the model identity.
    """
    start_time = time.time()
    fd = flask.request.form
    model_name = fd.get('model', default_model)
    model_path, model_md5 = download_model(model_name)
    # Cache models by path to avoid reloading on every request.
    if model_path not in _model_cache:
        _model_cache[model_path] = load_model(model_path)
    model, transforms = _model_cache[model_path]
    frame_image, clarities = find_best_frame(flask.request.files)
    if frame_image is None:
        # find_best_frame found no QR code in any upload; stop here instead of
        # handing None on to extract_qr.
        return {
            "ok": False,
            "error": "No QR code found in uploaded frames",
        }
    frame_qrcode, _ = extract_qr(frame_image)
    std_bytes = get_qr_image_bytes(frame_qrcode)
    # Test the fetched bytes before decoding: Image.open raises on empty or
    # invalid data rather than returning a falsy image, so checking the opened
    # image could never report a missing std image.
    if not std_bytes:
        return {
            "ok": False,
            "error": f"No std image: {frame_qrcode}",
        }
    std_image = Image.open(io.BytesIO(std_bytes))
    std_qrcode, _ = extract_qr(std_image)
    if frame_qrcode != std_qrcode:
        return {
            "ok": False,
            "error": "QR code mismatch",
        }
    predicted_class, probabilities = verify_frame(model, transforms, frame_image, std_image)
    return {
        "ok": True,
        "result": {
            "process_time": time.time() - start_time,
            "predicted_class": predicted_class,
            # probabilities is iterated as (key, value) pairs; only the values
            # are reported, formatted as percentages.
            "probabilities": ', '.join(f'{v:.2%}' for _, v in probabilities),
            "clarities": clarities,
            "model_md5": model_md5,
            "model_name": model_name,
        }
    }
def make_qrsbs(frame_qrcode, frame_image):
    """Render *frame_image* side by side with the std QR image, as JPEG bytes.

    Args:
        frame_qrcode: the code decoded from *frame_image*, used to look up the
            std image via ``get_qr_image``.
        frame_image: the captured frame.

    Returns:
        The JPEG-encoded side-by-side image, or ``None`` when ``get_qr_image``
        returns nothing for *frame_qrcode*.
    """
    qr_img = get_qr_image(frame_qrcode)
    if not qr_img:
        return None
    ret = make_side_by_side_img(frame_image, qr_img)
    # Encode in memory rather than via NamedTemporaryFile(delete=False), which
    # left a file in the temp directory on every call.
    buf = io.BytesIO()
    ret.save(buf, format='JPEG')
    return buf.getvalue()
@app.route('/api/v5/qrsbs', methods=['POST'])
def qrsb():
    """Return the frame/std side-by-side JPEG for an uploaded ``frame``.

    Rendered images are cached in the ``emblem-cache`` OSS bucket under
    ``qrsbs_v1/<qrcode>.jpg``. Any failure to read the cache is treated as a
    miss. Writing the cache is best-effort.

    Responses:
        JSON ``{"ok": False, ...}`` if no ``frame`` upload is present; 404 if no
        QR code is detected or no std image exists for it; otherwise the JPEG.
    """
    frame = flask.request.files.get('frame')
    if not frame:
        return {
            "ok": False,
            "error": "frame is required",
        }
    frame_image = Image.open(frame)
    frame_qrcode, _ = extract_qr(frame_image)
    if not frame_qrcode:
        # Without a decoded code there is no meaningful cache key or std image.
        flask.abort(404, 'QR code not found')
    cache_key = f'qrsbs_v1/{frame_qrcode}.jpg'
    try:
        obj = oss_get_object('https://oss-cn-guangzhou.aliyuncs.com', 'emblem-cache', cache_key)
        # oss_get_object returns a readable object (model downloads call
        # .read() on it); accept raw bytes as well.
        qrsbs = obj if isinstance(obj, (bytes, bytearray)) else obj.read()
    except Exception:
        app.logger.info('cache miss: %s', cache_key)
        qrsbs = make_qrsbs(frame_qrcode, frame_image)
        if not qrsbs:
            flask.abort(404, 'QR code not found')
        try:
            oss_put_object('https://oss-cn-guangzhou.aliyuncs.com', 'emblem-cache', cache_key, qrsbs)
        except Exception:
            # The cache is an optimisation; still serve the rendered image.
            app.logger.exception('failed to cache %s', cache_key)
    return flask.send_file(io.BytesIO(qrsbs), mimetype='image/jpeg')
@app.route('/api/v5/infer/version', methods=['GET'])
def version():
    """Report the version string of this inference service."""
    payload = {'version': '1.0.0'}
    return payload
@app.route('/api/v5/frame/<session_id>/<frame_id>', methods=['POST'])
def frame(session_id, frame_id):
    """Upload the raw request body as one frame of a capture session.

    The body is stored unmodified in the ``emblem-frames-prod`` bucket under
    the key ``v5/<session_id>/<frame_id>``.
    """
    payload = flask.request.get_data()
    credentials = oss2.Auth(oss_ak, oss_sk)
    target = oss2.Bucket(credentials, 'https://oss-cn-shenzhen.aliyuncs.com', 'emblem-frames-prod')
    object_key = f'v5/{session_id}/{frame_id}'
    target.put_object(object_key, payload)
    return {"ok": True}
def make_data_url(fname, mimetype='image/jpeg'):
    """Return the contents of *fname* embedded in a base64 ``data:`` URL.

    Args:
        fname: path of the file to embed.
        mimetype: media type placed in the URL; defaults to ``image/jpeg``.

    Returns:
        A string of the form ``data:<mimetype>;base64,<payload>``.
    """
    import base64

    with open(fname, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode('ascii')
    return f'data:{mimetype};base64,{encoded}'
def bottomcorner(img):
    """Crop the bottom-right region covering the last quarter of each axis.

    The crop box starts at ``(width - width // 4, height - height // 4)`` and
    extends to the image's bottom-right corner; the result of ``img.crop`` is
    returned.
    """
    w, h = img.width, img.height
    left = w - w // 4
    top = h - h // 4
    box = (left, top, w, h)
    return img.crop(box)
def prepare_clarities(scan_dir):
    """Score the sharpness of a scan's frame and std QR crops.

    ``frame-qr.jpg`` and ``std-qr.jpg`` in *scan_dir* are each scored with
    ``calc_clarity`` on their bottom-right corner (see ``bottomcorner``).

    Returns:
        dict with ``frame`` and ``std`` (scores truncated to int) and
        ``relative``, the frame score as an integer percentage of the std
        score, or ``None`` when the std score is 0 and the ratio is undefined.
    """
    with Image.open(os.path.join(scan_dir, 'frame-qr.jpg')) as frame_img:
        frame_clarity = int(calc_clarity(bottomcorner(frame_img)))
    with Image.open(os.path.join(scan_dir, 'std-qr.jpg')) as std_img:
        std_clarity = int(calc_clarity(bottomcorner(std_img)))
    return {
        # int() truncation can produce 0; avoid a ZeroDivisionError that would
        # abort the whole /api/scans listing.
        'relative': int(frame_clarity / std_clarity * 100) if std_clarity else None,
        'frame': frame_clarity,
        'std': std_clarity,
    }
def prepare_scan(scan, predicted_classes=None):
    """Build the JSON summary of one scan directory for the scan browser.

    Args:
        scan: directory name under ``scans_dir``.
        predicted_classes: optional mapping of scan name -> predicted class
            (as loaded from a verify-result file). Scans not in it get
            ``None``.

    Returns:
        ``None`` if the scan directory does not exist. Otherwise a dict with
        the scan name, its file list, the comma-separated labels from
        ``metadata.json`` as a list (empty when unlabelled), the predicted
        class, clarity scores, and the ``[width, height]`` of
        ``frame-qr.jpg`` and ``std-qr.jpg``.
    """
    if predicted_classes is None:
        predicted_classes = {}
    scan_dir = os.path.join(scans_dir, scan)
    if not os.path.exists(scan_dir):
        return None
    files = os.listdir(scan_dir)
    mdfile = os.path.join(scan_dir, 'metadata.json')
    if os.path.exists(mdfile):
        with open(mdfile, 'r') as f:
            md = json.load(f)
    else:
        md = {}
    # Only the dimensions are needed; close the files instead of leaking handles.
    with Image.open(os.path.join(scan_dir, 'frame-qr.jpg')) as frame_qr:
        frame_qr_size = [frame_qr.width, frame_qr.height]
    with Image.open(os.path.join(scan_dir, 'std-qr.jpg')) as std_qr:
        std_qr_size = [std_qr.width, std_qr.height]
    # ''.split(',') is [''], so drop empty entries to report "no labels" as [].
    labels = [label for label in (md.get('labels') or '').split(',') if label]
    return {
        'name': scan,
        'files': files,
        'labels': labels,
        'predicted_class': predicted_classes.get(scan),
        'clarities': prepare_clarities(scan_dir),
        'frame_qr_size': frame_qr_size,
        'std_qr_size': std_qr_size,
    }
def match_one_term(md, term, predicted_class):
    """Return whether scan metadata *md* satisfies one search *term*.

    Supported terms:
      * ``mispredicted`` / ``incorrect``: the scan has labels and the predicted
        class (truthy -> ``pos``, falsy -> ``neg``) is not among them.
      * ``correct``: the scan has labels and the predicted class is among them.
      * ``pos`` / ``neg``: the labels contain that label.
      * ``succeeded`` / ``failed``: ``md['succeeded']`` equals True / False.
      * ``A-B`` (two non-negative integers): ``int(md['id'])`` is in [A, B].

    Label checks are substring tests on the ``labels`` string. Prediction-based
    terms are False when *predicted_class* is None. Unknown terms, and id
    ranges on metadata without a numeric ``id``, are False.
    """
    # A JSON null label value is treated like "no labels".
    labels = md.get('labels') or ''
    if term in ('mispredicted', 'incorrect', 'correct'):
        if predicted_class is None or not labels:
            return False
        pcname = 'pos' if predicted_class else 'neg'
        predicted_in_labels = pcname in labels
        return predicted_in_labels if term == 'correct' else not predicted_in_labels
    if term in ('pos', 'neg'):
        return term in labels
    if term == 'succeeded':
        return md.get('succeeded') == True  # noqa: E712 (keep == semantics)
    if term == 'failed':
        return md.get('succeeded') == False  # noqa: E712 (keep == semantics)
    id_range = re.fullmatch(r'([0-9]+)-([0-9]+)', term)
    if id_range:
        try:
            scan_id = int(md.get('id'))
        except (TypeError, ValueError):
            # A missing or non-numeric id cannot fall inside any range.
            return False
        return int(id_range.group(1)) <= scan_id <= int(id_range.group(2))
    return False
def match_query(md, q, predicted_class):
    """Check *md* against every whitespace-separated term of the query *q*.

    An empty or missing query matches everything. Terms are evaluated left
    to right with ``match_one_term``. Evaluation stops at the first failing
    term, and its (falsy) result is returned; otherwise the result is True.
    """
    if not q:
        return True
    for term in q.split():
        verdict = match_one_term(md, term, predicted_class)
        if not verdict:
            return verdict
    return True
def search_scans(q, predicted_classes):
    """Return the names of complete scans that match query *q*, shuffled.

    A scan counts as complete when its directory contains both
    ``fetch-state.json`` and ``std-qr.jpg``. Its ``metadata.json`` is matched
    against *q* with ``match_query`` together with the scan's entry in
    *predicted_classes*. A missing ``metadata.json`` is treated as empty,
    the same way ``prepare_scan`` handles it. The list is shuffled with
    ``random.shuffle``, so a caller that truncates it gets a random subset.
    """
    ret = []
    all_scans = os.listdir(scans_dir)
    random.shuffle(all_scans)
    for scan in all_scans:
        scan_dir = os.path.join(scans_dir, scan)
        state_file = os.path.join(scan_dir, 'fetch-state.json')
        std_qr_file = os.path.join(scan_dir, 'std-qr.jpg')
        if not os.path.exists(state_file) or not os.path.exists(std_qr_file):
            continue
        md_file = os.path.join(scan_dir, 'metadata.json')
        if os.path.exists(md_file):
            with open(md_file, 'r') as f:
                md = json.load(f)
        else:
            md = {}
        if match_query(md, q, predicted_classes.get(scan)):
            ret.append(scan)
    return ret
def highlight_nongray(image):
    """Return the HSV saturation channel of *image* as a PIL image.

    Gray pixels have low saturation, so colored (non-gray) areas appear
    brighter in the result.

    Args:
        image: pixel array in BGR channel order, as required by
            ``cv2.cvtColor(..., cv2.COLOR_BGR2HSV)``. Presumably a uint8 numpy
            array from OpenCV; TODO confirm at call sites.
    """
    # NOTE(review): no saturation threshold is applied; callers get the raw
    # channel rather than a binary mask.
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    return Image.fromarray(hsv_image[:, :, 1])
@app.route('/api/sbs/<scan>.jpg', methods=['GET'])
def sbs(scan):
    """Serve a scan's ``frame.jpg`` and ``std.jpg`` side by side as a JPEG.

    Responds with a JSON ``{"ok": False, ...}`` body if the scan directory
    does not exist.
    """
    scan_dir = os.path.join(scans_dir, scan)
    if not os.path.exists(scan_dir):
        return {
            "ok": False,
            "error": f"Scan {scan} not found",
        }
    buf = io.BytesIO()
    with Image.open(os.path.join(scan_dir, 'frame.jpg')) as frame_img:
        with Image.open(os.path.join(scan_dir, 'std.jpg')) as std_img:
            ret = make_side_by_side_img_with_margins(frame_img, std_img)
            # Encode in memory; a NamedTemporaryFile(delete=False) would leave
            # one file in the temp directory per request.
            ret.save(buf, format='JPEG')
    buf.seek(0)
    return flask.send_file(buf, mimetype='image/jpeg')
@app.route('/api/allinone/<scan>.jpg', methods=['GET'])
def allinone(scan):
    """Serve a composite JPEG: the side-by-side image on top, the raw frame below.

    The canvas is as wide as the side-by-side image; the frame is pasted at
    its left edge beneath it. Responds with a JSON ``{"ok": False, ...}`` body
    if the scan directory does not exist.
    """
    scan_dir = os.path.join(scans_dir, scan)
    if not os.path.exists(scan_dir):
        return {
            "ok": False,
            "error": f"Scan {scan} not found",
        }
    buf = io.BytesIO()
    with Image.open(os.path.join(scan_dir, 'frame.jpg')) as frame_img:
        with Image.open(os.path.join(scan_dir, 'std.jpg')) as std_img:
            sbs_img = make_side_by_side_img_with_margins(frame_img, std_img)
            ret = Image.new('RGB', (sbs_img.width, sbs_img.height + frame_img.height))
            ret.paste(sbs_img, (0, 0))
            ret.paste(frame_img, (0, sbs_img.height))
            # Encode in memory so no temporary file is left behind per request.
            ret.save(buf, format='JPEG')
    buf.seek(0)
    return flask.send_file(buf, mimetype='image/jpeg')
@app.route('/api/scans', methods=['GET'])
def scans():
    """List up to 200 scans matching the query, for the scan browser.

    Query args:
        filter: space-separated search terms (see ``match_one_term``).
        verify_result: optional name of a JSON file directly inside
            ``data_dir`` mapping scan names to predicted classes.

    Responds with 400 if ``verify_result`` is not a plain file name.
    """
    q = flask.request.args.get('filter', None)
    verify_result = flask.request.args.get('verify_result', None)
    if verify_result:
        # The name comes from the query string; accept only plain file names
        # so values like '../../x.json' cannot read files outside data_dir.
        if os.path.basename(verify_result) != verify_result or verify_result in ('.', '..'):
            flask.abort(400, 'invalid verify_result')
        with open(os.path.join(data_dir, verify_result), 'r') as f:
            predicted_classes = json.load(f)
    else:
        predicted_classes = {}
    matched = search_scans(q, predicted_classes)[:200]
    return {
        "scans": [prepare_scan(s, predicted_classes) for s in matched],
    }
@app.route('/api/verify_results', methods=['GET'])
def verify_results():
    """List the ``*.json`` verify-result files in ``data_dir``, sorted by name."""
    return {
        'results': sorted(x for x in os.listdir(data_dir) if x.endswith('.json')),
    }
@app.route('/api/data/<path:path>', methods=['GET'])
def data(path):
    """Serve a file from ``data_dir`` by its relative *path*.

    ``flask.send_from_directory`` refuses paths that escape ``data_dir``
    (e.g. ``../../etc/passwd``, which ``<path:...>`` would otherwise accept)
    and answers 404 for missing files.
    """
    return flask.send_from_directory(os.path.abspath(data_dir), path)
def parse_args():
    """Parse command-line options for running the server.

    Options: ``--port`` (int, default 6500) and ``--debug`` (flag).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--port', {'type': int, 'default': 6500}),
        ('--debug', {'action': 'store_true'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: run Flask's built-in server on all interfaces,
    # using the port and debug flag from the command line.
    cli_args = parse_args()
    app.run(host='0.0.0.0', port=cli_args.port, debug=cli_args.debug)