diff --git a/scripts/emcli b/scripts/emcli
index f916fb5..8439b77 100755
--- a/scripts/emcli
+++ b/scripts/emcli
@@ -5,6 +5,8 @@ import sys
 import json
 import requests
 import argparse
+from multiprocessing import Pool
+from tqdm import tqdm
 
 class SubCommand(object):
     """ Base class of subcommand"""
@@ -70,6 +72,41 @@ class ActivateCommand(SubCommand):
         r = r.json()
         return r['objects'][0]['id']
 
+def _download_scan_data_worker(args_tuple):
+    """Worker function for multiprocessing"""
+    scan_data, server, token, output_dir = args_tuple
+    scan_id = scan_data['id']
+
+    try:
+        # Save JSON
+        with open(f"{output_dir}/{scan_id}.json", "w") as f:
+            f.write(json.dumps(scan_data, indent=2) + "\n")
+
+        # Save message
+        if scan_data.get('message'):
+            with open(f"{output_dir}/{scan_id}.txt", "w") as f:
+                f.write(scan_data['message'] + "\n")
+
+        # Download image
+        if scan_data.get('image'):
+            url = f'{server}/api/v1/oss-image/?token={token}&name={scan_data["image"]}'
+            r = requests.get(url)
+            r.raise_for_status()
+            with open(f"{output_dir}/{scan_id}-frame.jpg", "wb") as f:
+                f.write(r.content)
+
+        # Download ROI if available
+        if scan_data.get('code'):
+            url = f'{server}/api/v1/code-feature-roi/?token={token}&code={scan_data["code"]}'
+            r = requests.get(url)
+            if r.status_code == 200:
+                with open(f"{output_dir}/{scan_id}-roi.jpg", "wb") as f:
+                    f.write(r.content)
+
+        return scan_id, None
+    except Exception as e:
+        return scan_id, str(e)
+
 class GetScanDataCommand(SubCommand):
     name = "get-scan-data"
     want_argv = True
@@ -77,22 +114,109 @@
     def setup_args(self, parser):
         parser.add_argument("--output", "-o", required=True)
+        parser.add_argument("--include-labels", help="Comma-separated list of labels to include")
+        parser.add_argument("--exclude-labels", help="Comma-separated list of labels to exclude")
+        parser.add_argument("--limit", type=int, help="Limit number of results")
+        parser.add_argument("--offset", type=int, help="Offset for pagination")
+        parser.add_argument("--query", "-q", help="Search query")
+        parser.add_argument("--workers", "-w", type=int, default=8, help="Number of parallel workers (default: 8)")
 
     def do(self, args, argv):
         os.makedirs(args.output, exist_ok=True)
-        for i in argv:
-            sd = self.get_scan_data(i)
-            print(sd)
-            with open(f"{args.output}/{i}.json", "w") as f:
-                f.write(json.dumps(sd, indent=2) + "\n")
-            with open(f"{args.output}/{i}.txt", "w") as f:
-                f.write(sd['message'] + "\n")
-            with open(f"{args.output}/{i}-frame.jpg", "wb") as f:
-                f.write(self.get_image(sd['image']))
-            roi = self.get_roi(sd['code'])
-            if roi:
-                with open(f"{args.output}/{i}-roi.jpg", "wb") as f:
-                    f.write(roi)
+
+        server = self.get_server()
+        token = self.get_env()['token']
+
+        if argv:
+            # If IDs are provided, fetch them individually
+            scan_data_list = []
+            for i in argv:
+                sd = self.get_scan_data(i)
+                scan_data_list.append(sd)
+        else:
+            # If no IDs, query the list endpoint with filters
+            scan_data_list = self.query_scan_data(args)
+
+        if not scan_data_list:
+            print("No scan data found.")
+            return
+
+        print(f"Found {len(scan_data_list)} scan data records")
+
+        # Prepare arguments for worker function
+        worker_args = [(sd, server, token, args.output) for sd in scan_data_list]
+
+        # Use multiprocessing with tqdm progress bar
+        with Pool(processes=args.workers) as pool:
+            results = list(tqdm(
+                pool.imap(_download_scan_data_worker, worker_args),
+                total=len(scan_data_list),
+                desc="Downloading"
+            ))
+
+        # Check for errors
+        errors = [r for r in results if r[1] is not None]
+        if errors:
+            print(f"\nErrors occurred for {len(errors)} records:", file=sys.stderr)
+            for scan_id, error in errors:
+                print(f"  {scan_id}: {error}", file=sys.stderr)
+
+    def query_scan_data(self, args):
+        """Query scan data list endpoint with filters, handling pagination"""
+        server = self.get_server()
+        url = f'{server}/api/v1/scan-data/'
+        base_params = {}
+
+        if args.include_labels:
+            base_params['include_labels'] = args.include_labels
+        if args.exclude_labels:
+            base_params['exclude_labels'] = args.exclude_labels
+        if args.query:
+            base_params['q'] = args.query
+
+        all_objects = []
+        offset = args.offset if args.offset else 0
+        limit = args.limit if args.limit else None
+
+        # If limit is set, respect it and don't paginate
+        if limit:
+            params = base_params.copy()
+            params['limit'] = limit
+            params['offset'] = offset
+            print(f"Querying: {url} with params: {params}")
+            r = requests.get(url, headers=self.make_headers(), params=params)
+            r.raise_for_status()
+            data = r.json()
+            return data.get('objects', [])
+
+        # Otherwise, fetch all pages
+        while True:
+            params = base_params.copy()
+            params['offset'] = offset
+            # Don't set limit to use API default
+            print(f"Querying: {url} with params: {params}")
+            r = requests.get(url, headers=self.make_headers(), params=params)
+            r.raise_for_status()
+            data = r.json()
+            objects = data.get('objects', [])
+            meta = data.get('meta', {})
+
+            if not objects:
+                break
+
+            all_objects.extend(objects)
+
+            # Check if we've fetched all pages
+            total_count = meta.get('total_count', 0)
+            current_limit = meta.get('limit', len(objects))
+            offset += len(objects)
+
+            print(f"Fetched {len(all_objects)}/{total_count} records...")
+
+            if offset >= total_count or len(objects) < current_limit:
+                break
+
+        return all_objects
 
     def get_scan_data(self, i):
         server = self.get_server()
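
Two notes on the worker design in this patch. `_download_scan_data_worker` is a module-level function rather than a method because `multiprocessing.Pool` pickles the callable and its arguments when dispatching work to child processes, and a plain function avoids having to pickle the command object itself. `pool.imap` is used instead of `pool.map` because it yields results lazily, so the tqdm bar advances as downloads complete rather than jumping to 100% at the end. A minimal standalone sketch of the pattern (illustrative names only, not part of emcli):

    from multiprocessing import Pool
    from tqdm import tqdm

    def work(n):
        # Stand-in for _download_scan_data_worker: returns (id, error_or_None)
        # so one failed item is reported at the end instead of aborting the pool.
        return n, None

    if __name__ == "__main__":
        items = list(range(100))
        with Pool(processes=8) as pool:
            # imap is lazy and yields in input order, so tqdm can tick as
            # results arrive; pool.map would block until the whole batch is done.
            results = list(tqdm(pool.imap(work, items), total=len(items)))
        errors = [(i, e) for i, e in results if e is not None]
        print(f"{len(errors)} failures")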