diff --git a/emblem5/ai/fetch-scans.py b/emblem5/ai/fetch-scans.py index 92ee9f2..9799431 100755 --- a/emblem5/ai/fetch-scans.py +++ b/emblem5/ai/fetch-scans.py @@ -44,16 +44,19 @@ class ScanDataFetcher(object): ret[md['id']] = md return ret - def fetch(self, sample_rate=None): + def fetch(self, sample_rate=None, scan_ids=None): local_scan_data = self.load_local_scan_data() logger.info(f'local_scan_data: {len(local_scan_data)}') url = 'https://themblem.com/api/v1/scan-data-labels/' r = requests.get(url, headers=self.make_headers()) data = r.json() fetch_backlog = [] + scan_ids_set = set(scan_ids) if scan_ids else None for item in data['items']: if 'code' not in item or 'id' not in item or not item.get('labels') or 'image' not in item or not item.get('image'): continue + if scan_ids_set and item['id'] not in scan_ids_set: + continue if item['id'] in local_scan_data: local_labels = local_scan_data[item['id']]['labels'] if local_labels == item['labels']: @@ -137,6 +140,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--data-dir', type=str, default='data') parser.add_argument('--sample-rate', '-r', type=float) + parser.add_argument('--scan-ids', type=str, help='Comma-separated list of scan IDs to fetch') return parser.parse_args() def main(): @@ -145,7 +149,8 @@ def main(): data_dir = args.data_dir fetcher = ScanDataFetcher() logger.info('fetch') - fetcher.fetch(args.sample_rate) + scan_ids = [int(x.strip()) for x in args.scan_ids.split(',')] if args.scan_ids else None + fetcher.fetch(args.sample_rate, scan_ids) if __name__ == "__main__": main()