263 lines
8.4 KiB
Python
Executable File
263 lines
8.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import shutil
|
|
import numpy as np
|
|
import requests
|
|
import json
|
|
import subprocess
|
|
import cv2
|
|
from collections import defaultdict
|
|
from multiprocessing import Pool
|
|
|
|
class RoiResearch(object):
    """Minimal client for the themblem.com scan-data API.

    Holds an API token (sent as ``Authorization: Token <token>``) and pages
    through scan-data records.
    """

    # Server root; the pagination links returned in meta['next'] are paths
    # that get resolved against this.
    BASE_URL = "https://themblem.com"
    # Seconds to wait on the server before giving up, so a stalled
    # connection cannot hang the download forever.
    REQUEST_TIMEOUT = 120

    def __init__(self, token=None):
        # May be None until login() stores a fresh token.
        self.token = token

    def login(self, username, password):
        """Log in with *username*/*password*; store and return the API token.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            KeyError: if the response JSON has no 'token' field.
        """
        url = self.BASE_URL + "/api/v1/login/"
        data = {
            "username": username,
            "password": password,
        }
        response = requests.post(url, json=data, timeout=self.REQUEST_TIMEOUT)
        # Surface auth/server errors directly instead of a confusing
        # KeyError or JSON decode error further down.
        response.raise_for_status()
        token = response.json()['token']
        self.token = token
        return token

    def fetch_scan_data(self, last_id=None):
        """Yield scan-data records from the API, 1000 per page.

        Follows ``meta['next']`` until it is None. When *last_id* is given,
        iteration stops at the first record whose id is <= *last_id*.
        NOTE(review): that early stop assumes the API returns records in
        descending id order — confirm against the server.

        Raises:
            requests.HTTPError: if any page request fails.
        """
        path = "/api/v1/scan-data/?limit=1000"
        headers = {"Authorization": f"Token {self.token}"}
        while path is not None:
            url = self.BASE_URL + path
            print("fetching", url)
            response = requests.get(url, headers=headers, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            page = response.json()
            # Yield this page BEFORE checking for a next link: the last page
            # has next=None but still carries records (the old loop broke
            # out first and silently dropped them).
            for sd in page['objects']:
                if last_id is not None and sd['id'] <= last_id:
                    return
                yield sd
            path = page['meta']['next']
|
|
|
|
def get_last_id(list_file):
    """Return the largest 'id' among the JSON-lines records in *list_file*.

    Each non-blank line must be a JSON object with an 'id' key; blank lines
    are ignored.

    Returns:
        The maximum id, or None when the file does not exist or holds no
        records. An id of 0 is a valid result (the old truthiness test
        treated it as "no id yet").
    """
    if not os.path.exists(list_file):
        return None
    with open(list_file, "r") as f:
        # The generator is fully consumed by max() before the file closes.
        ids = (json.loads(line)['id'] for line in f if line.strip())
        return max(ids, default=None)
|
|
|
|
def get_all_ids(list_file):
    """Return the 'id' of every JSON-lines record in *list_file*, in file order.

    Blank lines are skipped (as get_last_id does), so a stray or trailing
    empty line no longer makes json.loads raise.

    Raises:
        FileNotFoundError: if *list_file* does not exist.
    """
    with open(list_file, "r") as f:
        return [json.loads(line)['id'] for line in f if line.strip()]
|
|
|
|
def get_roi_data(args):
    """Download new scan-data records and their sample files.

    Steps:
      1. Log in when both ``args.username`` and ``args.password`` are set
         (the returned token is printed).
      2. Append to data/roi/list.txt, as JSON lines, every record newer than
         the highest id already listed. Records lacking an image or labels
         are skipped.
      3. For every listed id, run ``../scripts/emcli get-scan-data`` into
         data/roi/samples/<id>, unless <id>.json already exists there.

    Args:
        args: parsed CLI namespace providing token, username and password.
    """
    rr = RoiResearch(args.token)
    if args.username and args.password:
        print(rr.login(args.username, args.password))
    list_file = "data/roi/list.txt"
    # Opening in append mode below fails if the directory is missing.
    os.makedirs(os.path.dirname(list_file), exist_ok=True)
    last_id = get_last_id(list_file)
    print("last_id", last_id)
    with open(list_file, "a") as f:
        for sd in rr.fetch_scan_data(last_id):
            if not sd['image'] or not sd['labels']:
                print(f"skipping {sd['id']}, no image or labels: image={sd['image']}, labels={sd['labels']}")
                continue
            print("new id", sd['id'])
            line = json.dumps({
                'id': sd['id'],
                'image': sd['image'],
                'labels': sd['labels'],
            })
            f.write(line + "\n")
    for sample_id in get_all_ids(list_file):
        outd = f"data/roi/samples/{sample_id}"
        if os.path.exists(f"{outd}/{sample_id}.json"):
            print(f"skipping {sample_id}, json already exists")
            continue
        # argv list rather than a shell string: the id comes from server
        # JSON and must not be interpreted by a shell.
        cmd = ['../scripts/emcli', 'get-scan-data', str(sample_id), '-o', outd]
        print(" ".join(cmd))
        ret = subprocess.call(cmd)
        if ret != 0:
            print(f"emcli failed for {sample_id} with exit code {ret}")
|
|
|
|
def preprocess(img, low_proportion=70, high_proportion=80):
    """Binarize a BGR image by keeping one band of its brightness histogram.

    The image is converted to grayscale, and two gray levels are picked from
    the cumulative histogram:

    - ``low_thres`` is the first level at which the cumulative share of
      pixels exceeds *low_proportion* percent.
    - ``high_thres`` is the first level exceeding *high_proportion* percent.

    Pixels with ``low_thres <= value <= high_thres`` become 255; all others
    become 0.

    Args:
        img: BGR image, e.g. as returned by ``cv2.imread``.
        low_proportion: lower cumulative percentage (default 70).
        high_proportion: upper cumulative percentage (default 80).

    Returns:
        A new uint8 single-channel array of 0/255 values with the same
        height and width as *img*.

    Raises:
        Exception: "no threshold found" if no level exceeds *low_proportion*.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    histo = cv2.calcHist([gray], [0], None, [256], [0, 256])
    total = gray.shape[0] * gray.shape[1]
    # ratio[x] = percentage of pixels whose gray value is <= x
    ratio = np.cumsum(histo.ravel()) / total * 100
    low_levels = np.flatnonzero(ratio > low_proportion)
    if low_levels.size == 0:
        raise Exception("no threshold found")
    low_thres = int(low_levels[0])
    high_levels = np.flatnonzero(ratio > high_proportion)
    # -1 is the "not found" sentinel; the kept band is then empty.
    high_thres = int(high_levels[0]) if high_levels.size else -1
    print(low_thres, high_thres)
    # Vectorized form of the former per-pixel loop: same 0/255 result.
    if high_thres < 0:
        band = np.zeros(gray.shape, dtype=bool)
    else:
        band = (gray >= low_thres) & (gray <= high_thres)
    return np.where(band, 255, 0).astype(np.uint8)
|
|
|
|
# Vertical position (pixels) for the next debug window; show_img advances it
# so successive windows stack downward instead of overlapping.
window_y = 0


def show_img(img, name):
    """Show *img* in an OpenCV window titled *name* (debug helper).

    The window is placed at x=640 and the current ``window_y``; afterwards
    ``window_y`` grows by 200 for the next call.
    """
    global window_y
    cv2.imshow(name, img)
    x_pos, y_pos = 640, window_y
    cv2.moveWindow(name, x_pos, y_pos)
    window_y = y_pos + 200
|
|
|
|
def find_point(roi_processed, i, j, tolerance=2):
    """Report whether any pixel near (i, j) in a 2-D mask is positive.

    The neighbourhood is the (2*tolerance + 1)-square centred on row *i*,
    column *j*, clipped to the array bounds.

    Args:
        roi_processed: 2-D array (e.g. a 0/255 mask from preprocess).
        i: row index.
        j: column index.
        tolerance: neighbourhood radius in pixels (default 2).

    Returns:
        True if some neighbour is > 0, else False. The previous version
        fell off the end and returned None in that case.
    """
    # Clamp both slice ends at 0: a negative start would wrap around to the
    # far side of the array instead of clipping.
    row_lo, row_hi = max(i - tolerance, 0), max(i + tolerance + 1, 0)
    col_lo, col_hi = max(j - tolerance, 0), max(j + tolerance + 1, 0)
    window = roi_processed[row_lo:row_hi, col_lo:col_hi]
    return bool(np.any(window > 0))
|
|
|
|
def _dilate_mask(mask, tolerance):
    """Return a bool mask that is True where *mask* has a True pixel within *tolerance* (square neighbourhood, clipped at the borders)."""
    height, width = mask.shape
    padded = np.pad(mask, tolerance, mode="constant", constant_values=False)
    grown = np.zeros((height, width), dtype=bool)
    # OR together every shift in [-tolerance, tolerance]^2.
    for di in range(2 * tolerance + 1):
        for dj in range(2 * tolerance + 1):
            grown |= padded[di:di + height, dj:dj + width]
    return grown


def calc_sim(frame_processed, roi_processed, tolerance=2):
    """Return a symmetric tolerance-based overlap score of two 2-D masks.

    Every nonzero pixel of either mask counts towards the total. It is a
    "hit" when the other mask has a positive pixel within *tolerance* of
    the same position. The score is hits / total, in [0, 1].

    This is a vectorized equivalent of checking each pixel with
    find_point. The former pure-Python double loop did about 25 Python
    operations per pixel per mask.

    Args:
        frame_processed: 2-D mask (e.g. from preprocess).
        roi_processed: 2-D mask with the same shape.
        tolerance: neighbourhood radius in pixels (default 2).

    Returns:
        The score as a float; 0.0 when both masks are all zero (previously
        a ZeroDivisionError).

    Raises:
        ValueError: if the masks differ in shape.
    """
    if frame_processed.shape != roi_processed.shape:
        raise ValueError("size mismatch")
    frame_on = frame_processed != 0
    roi_on = roi_processed != 0
    total = int(np.count_nonzero(frame_on)) + int(np.count_nonzero(roi_on))
    if total == 0:
        return 0.0
    frame_near = _dilate_mask(frame_processed > 0, tolerance)
    roi_near = _dilate_mask(roi_processed > 0, tolerance)
    hits = int(np.count_nonzero(roi_on & frame_near)) + int(np.count_nonzero(frame_on & roi_near))
    return hits / total
|
|
|
|
def roi_sim(frame_roi_img, roi_img, debug=False):
    """Return the similarity of two equally sized BGR images.

    Both images are binarized with preprocess and scored with calc_sim.

    Args:
        frame_roi_img: ROI re-extracted from the camera frame.
        roi_img: stored ROI image of the same shape.
        debug: when True, show the four images (raw and binarized) in
            OpenCV windows and block until a key is pressed. This replaces
            the former dead ``if False:`` branch.

    Raises:
        ValueError: "size mismatch" if the shapes differ. ValueError
            subclasses Exception, so existing ``except Exception``
            handlers still catch it.
    """
    if frame_roi_img.shape != roi_img.shape:
        raise ValueError("size mismatch")
    frame_processed = preprocess(frame_roi_img)
    roi_processed = preprocess(roi_img)
    siml = calc_sim(frame_processed, roi_processed)
    if debug:
        show_img(frame_processed, "frame_processed")
        show_img(roi_processed, "roi_processed")
        show_img(frame_roi_img, "frame_roi_img")
        show_img(roi_img, "roi_img")
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    return siml
|
|
|
|
def frame_roi(frame_file):
    """Return the ROI crop of a camera frame as a BGR image.

    The crop is cached beside the frame as ``<frame_file>.roi.jpg``. When
    that file is missing, ``./qrtool frame_roi <frame_file>`` is run from
    the ``../alg`` directory relative to this script, and its output file is
    read back.

    Args:
        frame_file: path to the frame image. Callers pass an absolute
            path; since qrtool runs with a different cwd, a relative path
            would presumably be resolved against ../alg — avoid it.

    Raises:
        subprocess.CalledProcessError: if qrtool exits non-zero.
        FileNotFoundError: if the ROI image cannot be read afterwards
            (cv2.imread returns None rather than raising).
    """
    frame_roi_file = frame_file + ".roi.jpg"
    if not os.path.exists(frame_roi_file):
        alg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "alg")
        # argv list: no shell, so paths with spaces or shell metacharacters
        # are passed through intact.
        subprocess.check_call(["./qrtool", "frame_roi", frame_file], cwd=alg_dir)
    frame_roi_img = cv2.imread(frame_roi_file)
    if frame_roi_img is None:
        raise FileNotFoundError(f"could not read ROI image {frame_roi_file}")
    return frame_roi_img
|
|
|
|
def process_roi_data(id):
    """Print ``<id>: <similarity>`` for one downloaded sample.

    Compares the stored ROI (data/roi/samples/<id>/<id>-roi.jpg) with the
    ROI extracted from that sample's frame. Both are resized to 128x128 and
    scored with roi_sim.

    Raises:
        FileNotFoundError: if the stored ROI image is missing or unreadable
            (cv2.imread returns None instead of raising).
    """
    rd = f"data/roi/samples/{id}"
    frame_file = os.path.abspath(f"{rd}/{id}-frame.jpg")
    roi_file = os.path.abspath(f"{rd}/{id}-roi.jpg")
    roi_img = cv2.imread(roi_file)
    if roi_img is None:
        raise FileNotFoundError(f"could not read ROI image {roi_file}")
    frame_roi_img = frame_roi(frame_file)
    # cv2.resize takes (width, height); same size as prepare_to_train uses.
    size = (128, 128)
    frame_roi_img = cv2.resize(frame_roi_img, size)
    roi_img = cv2.resize(roi_img, size)
    sim = roi_sim(frame_roi_img, roi_img)
    print(f"{id}: {sim}")
|
|
|
|
def parse_args():
    """Parse the command-line flags of the ROI research tool.

    The API token defaults to the THEMBLEM_TOKEN environment variable,
    falling back to the historical built-in token.
    """
    parser = argparse.ArgumentParser(
        description="ROI research: download scan data, build the training set, score samples")
    # NOTE(review): a real API token is hard-coded as the fallback below.
    # Anyone with the source can read it; it should be rotated and removed
    # once users pass --token or set THEMBLEM_TOKEN.
    default_token = os.environ.get("THEMBLEM_TOKEN", '3ebd8c33-f46e-4b06-bda8-4c0f5f5eb530')
    parser.add_argument("--token", '-t', default=default_token, type=str,
                        help="API token (default: $THEMBLEM_TOKEN or built-in)")
    parser.add_argument("--username", "-u", type=str, help="log in with this username")
    parser.add_argument("--password", "-p", type=str, help="password for --username")
    parser.add_argument("--download", "-d", action='store_true',
                        help="fetch new scan data and sample files")
    parser.add_argument("--preprocess", "-P", action='store_true',
                        help="rebuild data/roi/train from all samples")
    parser.add_argument("--id", "-i", type=int, action='append',
                        help="score this sample id (repeatable)")
    return parser.parse_args()
|
|
|
|
def all_sample_ids():
    """Return the entry names in data/roi/samples, in os.listdir order.

    Each name is a string; get_roi_data creates one sub-directory per
    sample id there.
    """
    samples_dir = "data/roi/samples"
    entries = os.listdir(samples_dir)
    return entries
|
|
|
|
def _write_side_by_side(rd, sample_id, side_by_side_file):
    """Write the 256x128 frame-ROI | stored-ROI image for a sample; return False if it cannot be built."""
    frame_file = os.path.abspath(f"{rd}/{sample_id}-frame.jpg")
    roi_file = os.path.abspath(f"{rd}/{sample_id}-roi.jpg")
    if not os.path.exists(frame_file) or not os.path.exists(roi_file):
        print(f"skipping {sample_id}, no frame or roi")
        return False
    roi_img = cv2.imread(roi_file)
    # cv2.imread signals an unreadable file by returning None.
    if roi_img is None:
        print(f"skipping {sample_id}, unreadable roi image")
        return False
    try:
        frame_roi_img = frame_roi(frame_file)
    except Exception as e:
        # Deliberate best-effort: one bad sample must not abort the pool.
        print(f"failed to get frame_roi for {sample_id}: {e}")
        return False
    frame_roi_img = cv2.resize(frame_roi_img, (128, 128))
    roi_img = cv2.resize(roi_img, (128, 128))
    side_by_side = np.concatenate((frame_roi_img, roi_img), axis=1)
    cv2.imwrite(side_by_side_file, side_by_side)
    return True


def prepare_to_train(id):
    """Add one sample to the training set in data/roi/train.

    Builds (or reuses) data/roi/samples/<id>/<id>-side-by-side.jpg. It then
    reads the sample's labels from <id>.json: 'pos' -> 1, else 'neg' -> 0.
    The image is copied to data/roi/train/<id>.jpg and ``<id>.jpg <label>``
    is appended to data/roi/train/labels.txt.

    Intended to run inside a multiprocessing Pool, so every per-sample
    problem (missing files, unreadable images, no json, no label) is
    printed and skipped rather than raised.

    Returns:
        The label (1 or 0), or None if the sample was skipped.
    """
    rd = f"data/roi/samples/{id}"
    side_by_side_file = os.path.join(rd, f"{id}-side-by-side.jpg")
    if not os.path.exists(side_by_side_file) and not _write_side_by_side(rd, id, side_by_side_file):
        return None
    json_file = os.path.abspath(f"{rd}/{id}.json")
    if not os.path.exists(json_file):
        print(f"skipping {id}, no json")
        return None
    with open(json_file, "r") as f:
        data = json.load(f)
    # A missing/null 'labels' field is treated like "no label".
    labels = data.get('labels') or []
    if 'pos' in labels:
        label = 1
    elif 'neg' in labels:
        label = 0
    else:
        print(f"no label found for {id}")
        return None
    train_dir = os.path.abspath("data/roi/train")
    shutil.copy(side_by_side_file, os.path.join(train_dir, f"{id}.jpg"))
    with open(os.path.join(train_dir, "labels.txt"), "a") as f:
        f.write(f"{id}.jpg {label}\n")
    return label
|
|
|
|
def main():
    """Command-line entry point (flags are defined in parse_args).

    - ``--download`` fetches new scan data first.
    - ``--preprocess`` rebuilds data/roi/train from every sample using 10
      worker processes, then exits.
    - Otherwise every ``--id`` is scored with process_roi_data.
    """
    args = parse_args()
    if args.download:
        get_roi_data(args)
    if args.preprocess:
        train_dir = os.path.abspath("data/roi/train")
        # Start clean: prepare_to_train appends to labels.txt, so stale
        # entries would otherwise accumulate. The existence check lets the
        # very first run (no directory yet) work; rmtree used to raise.
        if os.path.exists(train_dir):
            shutil.rmtree(train_dir)
        os.makedirs(train_dir)
        with Pool(processes=10) as pool:
            pool.map(prepare_to_train, all_sample_ids())
        return
    # --id uses action='append', so it is None when the flag is absent.
    for sample_id in args.id or []:
        process_roi_data(sample_id)
|
|
|
|
# Run the CLI only when this file is executed directly, not when it is
# imported. With a spawn-based multiprocessing start method the Pool workers
# re-import this module, and the guard keeps them from re-running main().
if __name__ == "__main__":
    main()
|