442 lines
15 KiB
Python
Executable File
442 lines
15 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import requests
|
|
import argparse
|
|
import logging
|
|
import base64
|
|
from multiprocessing import Pool
|
|
from tqdm import tqdm
|
|
import oss2
|
|
|
|
class SubCommand(object):
    """Base class for CLI subcommands.

    Subclasses set ``name``, ``help`` and optionally ``aliases`` /
    ``want_argv``, and override ``setup_args`` and ``do``.  ``main()``
    assigns the parsed namespace to ``self.args`` before calling ``do``,
    so the env helpers below may rely on it.
    """

    help = ""
    aliases = []
    want_argv = False  # Whether the command accepts extra arguments

    # NOTE(security): API tokens are hard-coded below.  They can be
    # overridden at runtime via the THEMBLEM_TOKEN environment variable
    # (see get_token) so rotating them does not require editing this file.
    envs = {
        'prod': {
            'server': 'https://themblem.com',
            'token': '3ebd8c33-f46e-4b06-bda8-4c0f5f5eb530',
        },
        'dev': {
            'server': 'https://dev.themblem.com',
            'token': 'D91AB64B-C5CA-4B78-AC76-01722B8C8A5C',
        },
    }

    def setup_args(self, parser):
        """Add command-specific arguments to *parser*.  Default: none."""
        pass

    def do(self, args, argv):
        """Do command.  Subclasses must override."""
        print("Not implemented")

    def get_env(self):
        """Return the config dict for the env selected by --env/-E."""
        return self.envs[self.args.env]

    def get_server(self):
        """Return the API server base URL for the current env."""
        return self.get_env()['server']

    def get_token(self):
        """Return the API token; THEMBLEM_TOKEN overrides the built-in."""
        return os.environ.get('THEMBLEM_TOKEN', self.get_env()['token'])

    def make_headers(self):
        """Return the Authorization header dict for the current env."""
        return {
            'Authorization': 'token ' + self.get_token(),
        }
|
|
|
|
class ActivateCommand(SubCommand):
    """Activate one or more codes given their code strings."""

    name = "activate"
    want_argv = True
    help = "Activate code"

    def setup_args(self, parser):
        pass

    def do(self, args, argv):
        # Each extra positional argument is a code string to activate.
        for code in argv:
            print(code)
            self.activate(code)

    def activate(self, code):
        """PATCH the record for *code* to is_active=True and print the reply."""
        server = self.get_server()
        pk = self.get_pk(code)
        url = f'{server}/api/v1/code/{pk}/'
        r = requests.patch(url, headers=self.make_headers(), json={
            'is_active': True,
        })
        r.raise_for_status()  # fail loudly instead of printing an error body
        print(r.text)

    def get_pk(self, code):
        """Look up the primary key for *code* via the list endpoint.

        Raises ValueError when no matching code exists (previously this
        was an opaque IndexError on the empty result list).
        """
        server = self.get_server()
        url = f'{server}/api/v1/code/?code=' + code
        r = requests.get(url, headers=self.make_headers())
        r.raise_for_status()
        objects = r.json()['objects']
        if not objects:
            raise ValueError(f"No code found matching {code!r}")
        return objects[0]['id']
|
|
|
|
def _download_scan_data_worker(args_tuple):
    """Worker function for multiprocessing: persist one scan-data record.

    Writes ``{id}.json`` (always), ``{id}.txt`` (message, if any),
    ``{id}-frame.jpg`` (frame image, if any) and ``{id}-roi.jpg`` (ROI, if
    available) into *output_dir*.  Returns ``(scan_id, error)`` where
    *error* is None on success or the stringified exception on failure.
    """
    record, server, token, output_dir = args_tuple
    scan_id = record['id']

    try:
        # Always persist the raw record as pretty-printed JSON.
        with open(f"{output_dir}/{scan_id}.json", "w") as fp:
            fp.write(json.dumps(record, indent=2) + "\n")

        # Optional free-text message alongside the record.
        message = record.get('message')
        if message:
            with open(f"{output_dir}/{scan_id}.txt", "w") as fp:
                fp.write(message + "\n")

        # Optional frame image, fetched through the OSS image proxy.
        image = record.get('image')
        if image:
            resp = requests.get(
                f'{server}/api/v1/oss-image/?token={token}&name={image}')
            resp.raise_for_status()
            with open(f"{output_dir}/{scan_id}-frame.jpg", "wb") as fp:
                fp.write(resp.content)

        # Optional ROI image; silently skipped on any non-200 reply.
        code = record.get('code')
        if code:
            resp = requests.get(
                f'{server}/api/v1/code-feature-roi/?token={token}&code={code}')
            if resp.status_code == 200:
                with open(f"{output_dir}/{scan_id}-roi.jpg", "wb") as fp:
                    fp.write(resp.content)

        return scan_id, None
    except Exception as e:
        # Swallow everything so one bad record cannot kill the pool;
        # the parent process reports the accumulated errors.
        return scan_id, str(e)
|
|
|
|
class GetScanDataCommand(SubCommand):
    """Download scan-data records (JSON, message, frame image, ROI) to disk.

    With positional IDs, each record is fetched individually; otherwise the
    list endpoint is queried with the given filters.  Downloads run in a
    multiprocessing pool with a tqdm progress bar.
    """

    name = "get-scan-data"
    want_argv = True
    help = "Get scan data from api"

    def setup_args(self, parser):
        parser.add_argument("--output", "-o", required=True)
        parser.add_argument("--include-labels", help="Comma-separated list of labels to include")
        parser.add_argument("--exclude-labels", help="Comma-separated list of labels to exclude")
        parser.add_argument("--limit", type=int, help="Limit number of results")
        parser.add_argument("--offset", type=int, help="Offset for pagination")
        parser.add_argument("--query", "-q", help="Search query")
        parser.add_argument("--workers", "-w", type=int, default=8, help="Number of parallel workers (default: 8)")

    def do(self, args, argv):
        os.makedirs(args.output, exist_ok=True)

        server = self.get_server()
        token = self.get_env()['token']

        if argv:
            # If IDs are provided, fetch them individually
            scan_data_list = [self.get_scan_data(i) for i in argv]
        else:
            # If no IDs, query the list endpoint with filters
            scan_data_list = self.query_scan_data(args)

        if not scan_data_list:
            print("No scan data found.")
            return

        print(f"Found {len(scan_data_list)} scan data records")

        # Worker args must be picklable; the worker is a module-level function.
        worker_args = [(sd, server, token, args.output) for sd in scan_data_list]

        # Use multiprocessing with tqdm progress bar
        with Pool(processes=args.workers) as pool:
            results = list(tqdm(
                pool.imap(_download_scan_data_worker, worker_args),
                total=len(scan_data_list),
                desc="Downloading"
            ))

        # Report per-record failures collected by the workers.
        errors = [r for r in results if r[1] is not None]
        if errors:
            print(f"\nErrors occurred for {len(errors)} records:", file=sys.stderr)
            for scan_id, error in errors:
                print(f" {scan_id}: {error}", file=sys.stderr)

    def query_scan_data(self, args):
        """Query scan data list endpoint with filters, handling pagination.

        Returns a list of record dicts.  When --limit is given, a single
        request is made honoring --limit/--offset; otherwise all pages are
        fetched starting at --offset (or 0).
        """
        server = self.get_server()
        url = f'{server}/api/v1/scan-data/'
        base_params = {}

        if args.include_labels:
            base_params['include_labels'] = args.include_labels
        if args.exclude_labels:
            base_params['exclude_labels'] = args.exclude_labels
        if args.query:
            base_params['q'] = args.query

        all_objects = []
        offset = args.offset if args.offset else 0
        limit = args.limit if args.limit else None

        # If limit is set, respect it and don't paginate
        if limit:
            params = base_params.copy()
            params['limit'] = limit
            params['offset'] = offset
            print(f"Querying: {url} with params: {params}")
            r = requests.get(url, headers=self.make_headers(), params=params)
            r.raise_for_status()
            data = r.json()
            return data.get('objects', [])

        # Otherwise, fetch all pages
        while True:
            params = base_params.copy()
            params['offset'] = offset
            # Don't set limit to use API default
            print(f"Querying: {url} with params: {params}")
            r = requests.get(url, headers=self.make_headers(), params=params)
            r.raise_for_status()
            data = r.json()
            objects = data.get('objects', [])
            meta = data.get('meta', {})

            if not objects:
                break

            all_objects.extend(objects)

            # Check if we've fetched all pages
            total_count = meta.get('total_count', 0)
            current_limit = meta.get('limit', len(objects))
            offset += len(objects)

            print(f"Fetched {len(all_objects)}/{total_count} records...")

            # Stop once past the advertised total, or on a short page.
            if offset >= total_count or len(objects) < current_limit:
                break

        return all_objects

    def get_scan_data(self, i):
        """Fetch a single scan-data record by id and return its dict."""
        server = self.get_server()
        url = f'{server}/api/v1/scan-data/{i}/'
        print(url)
        r = requests.get(url, headers=self.make_headers())
        # Previously unchecked: an HTTP error surfaced later as a confusing
        # JSON decode failure.
        r.raise_for_status()
        return r.json()

    def get_image(self, name):
        """Download image bytes for *name* via the OSS image proxy."""
        server = self.get_server()
        token = self.get_env()['token']
        url = f'{server}/api/v1/oss-image/?token={token}&name={name}'
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def get_roi(self, code):
        """Download ROI image bytes for *code*; None when not found (404).

        Other HTTP errors now raise instead of returning the error body as
        if it were image data.
        """
        server = self.get_server()
        token = self.get_env()['token']
        url = f'{server}/api/v1/code-feature-roi/?token={token}&code={code}'
        r = requests.get(url)
        if r.status_code == 404:
            return None
        r.raise_for_status()
        return r.content
|
|
|
|
class UploadRoiCommand(SubCommand):
    """Upload ROI image files; filename should be {code}.jpg."""

    name = "upload-roi"
    want_argv = True
    help = "Upload roi, filename should be {code}.jpg"

    def setup_args(self, parser):
        parser.add_argument("file", action="append")

    def do(self, args, argv):
        # URL is loop-invariant; compute it once.
        url = self.get_server() + "/api/v1/code-feature-roi/"
        for f in args.file:
            print(f)
            print(url)
            # NOTE(review): the multipart field name is the local path as
            # given on the command line; presumably the server keys on the
            # filename — confirm against the API.
            with open(f, 'rb') as fp:  # was leaked: handle never closed
                r = requests.post(url, headers=self.make_headers(), files={
                    f: fp,
                })
            print(r.text)
|
|
|
|
class UserInfoCommand(SubCommand):
    """Fetch and print the authenticated user's info."""

    name = "userinfo"
    want_argv = True
    help = "Call the userinfo API"

    def setup_args(self, parser):
        pass

    def do(self, args, argv):
        # Single authenticated GET; any HTTP error status raises.
        endpoint = f"{self.get_server()}/api/v1/userinfo/"
        response = requests.get(endpoint, headers=self.make_headers())
        response.raise_for_status()
        print(response.json())
|
|
|
|
class UploadModelCommand(SubCommand):
    """Upload a model file to the OSS bucket used by the qr-verify API."""

    name = "upload-model"
    want_argv = False
    help = "Upload model file to OSS bucket for qr-verify API"

    def setup_args(self, parser):
        parser.add_argument("file", help="Model file to upload")

    def do(self, args, argv):
        # NOTE(security): OSS credentials are hard-coded (same as in
        # emblem5/ai/ossclient.py).  Environment variables override them so
        # keys can be rotated without editing this script.
        oss_ak = os.environ.get('OSS_ACCESS_KEY_ID',
                                'LTAI5tC2qXGxwHZUZP7DoD1A')
        oss_sk = os.environ.get('OSS_ACCESS_KEY_SECRET',
                                'qPo9O6ZvEfqo4t8oflGEm0DoxLHJhm')

        # OSS endpoint and bucket (same as in emblem5/ai/server.py download_model)
        endpoint = 'https://oss-rg-china-mainland.aliyuncs.com'
        bucket_name = 'emblem-models'

        model_file = args.file
        if not os.path.exists(model_file):
            print(f"Error: File not found: {model_file}", file=sys.stderr)
            return 1

        # The object key is just the basename; that is the name the
        # qr-verify API downloads by.
        model_name = os.path.basename(model_file)
        print(f"Uploading {model_file} to {bucket_name}/{model_name}...")

        try:
            auth = oss2.Auth(oss_ak, oss_sk)
            bucket = oss2.Bucket(auth, endpoint, bucket_name)
            bucket.put_object_from_file(model_name, model_file)
            print(f"Successfully uploaded {model_name} to OSS bucket {bucket_name}")
            print(f"The qr-verify API will be able to download it as: {model_name}")
        except Exception as e:
            print(f"Error uploading model: {e}", file=sys.stderr)
            return 1
|
|
|
|
class ListModelsCommand(SubCommand):
    """List the models available in the OSS bucket for the qr-verify API."""

    name = "list-models"
    want_argv = False
    help = "List models available in OSS bucket for qr-verify API"

    def setup_args(self, parser):
        pass

    def do(self, args, argv):
        # NOTE(security): OSS credentials are hard-coded (same as in
        # emblem5/ai/ossclient.py).  Environment variables override them so
        # keys can be rotated without editing this script.
        oss_ak = os.environ.get('OSS_ACCESS_KEY_ID',
                                'LTAI5tC2qXGxwHZUZP7DoD1A')
        oss_sk = os.environ.get('OSS_ACCESS_KEY_SECRET',
                                'qPo9O6ZvEfqo4t8oflGEm0DoxLHJhm')

        # OSS endpoint and bucket (same as in emblem5/ai/server.py download_model)
        endpoint = 'https://oss-rg-china-mainland.aliyuncs.com'
        bucket_name = 'emblem-models'

        try:
            auth = oss2.Auth(oss_ak, oss_sk)
            bucket = oss2.Bucket(auth, endpoint, bucket_name)

            print(f"Models in OSS bucket {bucket_name}:")
            print("-" * 60)

            # ObjectIterator transparently pages through the bucket listing.
            models = [obj.key for obj in oss2.ObjectIterator(bucket)]

            if not models:
                print("No models found in bucket.")
            else:
                models.sort()
                for i, model_name in enumerate(models, 1):
                    print(f"{i:3d}. {model_name}")
                print("-" * 60)
                print(f"Total: {len(models)} model(s)")
        except Exception as e:
            print(f"Error listing models: {e}", file=sys.stderr)
            return 1
|
|
|
|
class QrVerifyCommand(SubCommand):
    """POST a local image plus its QR code value to the qr-verify endpoint."""

    name = "qr-verify"
    want_argv = False
    help = "Post image to qr-verify endpoint"

    def setup_args(self, parser):
        parser.add_argument("--image", required=True, help="Path to image file")
        parser.add_argument("--qrcode", required=True, help="QR code value")

    def do(self, args, argv):
        # Guard clause: nothing to send without a readable image file.
        if not os.path.exists(args.image):
            print(f"Error: Image file not found: {args.image}", file=sys.stderr)
            return 1

        payload = {
            'qrcode': args.qrcode,
            'image_data_urls': [self._to_data_url(args.image)],
        }

        # Post to the qr-verify endpoint with auth + explicit content type.
        url = f'{self.get_server()}/api/v1/qr-verify/'
        headers = {
            'Content-Type': 'application/json',
        }
        headers.update(self.make_headers())

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            print(json.dumps(response.json(), indent=2))
        except Exception as e:
            print(f"Error posting to qr-verify: {e}", file=sys.stderr)
            if hasattr(e, 'response') and e.response is not None:
                print(f"Response: {e.response.text}", file=sys.stderr)
            return 1

    def _to_data_url(self, path):
        """Read *path* and return it as a base64 data URL.

        MIME type is guessed from the extension; unknown extensions fall
        back to image/jpeg.
        """
        with open(path, 'rb') as fp:
            raw = fp.read()

        extension = os.path.splitext(path)[1].lower()
        mime_type = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
        }.get(extension, 'image/jpeg')

        encoded = base64.b64encode(raw).decode('utf-8')
        return f"data:{mime_type};base64,{encoded}"
|
|
|
|
def global_args(parser):
    """Register the options shared by every subcommand."""
    parser.add_argument("--env", "-E", default="prod", help="Env")
    parser.add_argument(
        "-D", "--debug", action="store_true", help="Enable debug output")
|
|
|
|
def main():
    """CLI entry point: build the parser and dispatch to the subcommand.

    Returns the subcommand's return value (propagated as the process exit
    status by the __main__ guard), or 1 when no subcommand was given.
    """
    parser = argparse.ArgumentParser()
    global_args(parser)
    subparsers = parser.add_subparsers(title="subcommands")
    # Auto-register every SubCommand subclass as a subparser.
    for c in SubCommand.__subclasses__():
        cmd = c()
        p = subparsers.add_parser(cmd.name, aliases=cmd.aliases,
                                  help=cmd.help)
        cmd.setup_args(p)
        p.set_defaults(func=cmd.do, cmdobj=cmd, all=False)
    # parse_known_args: leftover argv is handed to commands with want_argv.
    args, argv = parser.parse_known_args()
    if not hasattr(args, "cmdobj"):
        # No subcommand selected.
        parser.print_usage()
        return 1
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    if argv and not args.cmdobj.want_argv:
        # was: raise Exception(...) — parser.error prints usage to stderr
        # and exits with status 2, the conventional argparse behavior.
        parser.error("Unrecognized arguments:\n" + argv[0])
    # Commands read the parsed namespace via self.args (see SubCommand).
    args.cmdobj.args = args
    return args.func(args, argv)
|
|
|
|
if __name__ == '__main__':
    # Propagate the command's return value as the process exit status.
    sys.exit(main())
|