diff --git a/.gitignore b/.gitignore
index 92c60a4..591b3f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ build
 /detection/model
 /api/db.sqlite3
 /emblemscanner-release
+/models
diff --git a/Makefile b/Makefile
index 7ea8b32..260f48e 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: FORCE emblemscanner-release fetch fetch-quick sbs-quick train-quick
+.PHONY: FORCE emblemscanner-release fetch fetch-quick sbs-quick train-quick verify
 
 DATA_DIR ?= $(HOME)/emblem
 
@@ -102,6 +102,47 @@ train: FORCE
 train-quick: FORCE
 	cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --quick --epochs 2
 
+browser: FORCE
+	cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR)
+
+verify: FORCE
+	@if [ -z "$(SBS_IMG)" ]; then \
+		echo "Error: SBS_IMG is required. Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt"; \
+		exit 1; \
+	fi
+	@if [ -z "$(MODEL)" ]; then \
+		echo "Error: MODEL is required. Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt"; \
+		exit 1; \
+	fi
+	@MODEL_PATH="$(MODEL)"; \
+	if [ "$${MODEL_PATH#/}" = "$$MODEL_PATH" ]; then \
+		MODEL_PATH="$(shell pwd)/$$MODEL_PATH"; \
+	fi; \
+	if [ ! -f "$$MODEL_PATH" ]; then \
+		echo "Error: Model file not found: $$MODEL_PATH"; \
+		exit 1; \
+	fi; \
+	cd emblem5 && uv run --with-requirements ../requirements.txt python3 ./ai/verify_image.py "$$MODEL_PATH" "$(SBS_IMG)"
+
+check-accuracy: FORCE
+	@if [ -z "$(MODEL)" ]; then \
+		echo "Error: MODEL is required. Usage: make check-accuracy MODEL=/path/to/model.pt [DATA_DIR=/path/to/data]"; \
+		exit 1; \
+	fi
+	@MODEL_PATH="$(MODEL)"; \
+	if [ "$${MODEL_PATH#/}" = "$$MODEL_PATH" ]; then \
+		MODEL_PATH="$(shell pwd)/$$MODEL_PATH"; \
+	fi; \
+	if [ ! -f "$$MODEL_PATH" ]; then \
+		echo "Error: Model file not found: $$MODEL_PATH"; \
+		exit 1; \
+	fi; \
+	DATA_DIR="$(DATA_DIR)"; \
+	if [ -z "$$DATA_DIR" ]; then \
+		DATA_DIR="$(HOME)/emblem"; \
+	fi; \
+	cd emblem5 && uv run --with-requirements ../requirements.txt python3 ./ai/check_accuracy.py "$$MODEL_PATH" "$$DATA_DIR"
+
 OPENCV_TAG := 4.9.0
 
 opencv/src/LICENSE:
 	rm -rf opencv/src opencv/contrib
diff --git a/emblem5/ai/browser.py b/emblem5/ai/browser.py
index ab82913..8e2923c 100644
--- a/emblem5/ai/browser.py
+++ b/emblem5/ai/browser.py
@@ -1,38 +1,118 @@
 #!/usr/bin/env python3
 import os
+import sys
+import argparse
 import streamlit as st
-import random
 import json
+import re
 from PIL import Image
 
-def get_mispredicted_scans():
-    with open('data/verify2.log', 'r') as f:
-        lines = f.readlines()
-    for line in lines:
-        fields = line.split()
-        if len(fields) != 6:
+def get_scan_ids(data_dir):
+    scans_dir = os.path.join(data_dir, 'scans')
+    if not os.path.exists(scans_dir):
+        return []
+    scan_ids = []
+    for item in os.listdir(scans_dir):
+        scan_path = os.path.join(scans_dir, item)
+        if os.path.isdir(scan_path):
+            try:
+                scan_ids.append(int(item))
+            except ValueError:
+                continue
+    return sorted(scan_ids, reverse=True)
+
+def parse_filter(filter_str):
+    """Parse filter string like '357193-358023 or 358024-358808' into list of ranges."""
+    if not filter_str or not filter_str.strip():
+        return []
+
+    ranges = []
+    # Split by 'or' (case insensitive)
+    parts = re.split(r'\s+or\s+', filter_str.strip(), flags=re.IGNORECASE)
+
+    for part in parts:
+        part = part.strip()
+        if not part:
             continue
-        if fields[1] != fields[2]:
-            yield fields[0]
+
+        # Try to parse as range (start-end)
+        match = re.match(r'(\d+)\s*-\s*(\d+)', part)
+        if match:
+            start = int(match.group(1))
+            end = int(match.group(2))
+            ranges.append((start, end))
+        else:
+            # Try to parse as single number
+            try:
+                num = int(part)
+                ranges.append((num, num))
+            except ValueError:
+                continue
+
+    return ranges
+
+def filter_scans(scan_ids, ranges):
+    """Filter scan IDs based on ranges."""
+    if not ranges:
+        return scan_ids
+
+    filtered = set()
+    for start, end in ranges:
+        for sid in scan_ids:
+            if start <= sid <= end:
+                filtered.add(sid)
+
+    return sorted(filtered, reverse=True)
 
 def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data-dir', type=str, default=os.path.expanduser('~/emblem'), help='Data directory')
+    args, unknown = parser.parse_known_args()
+    data_dir = args.data_dir
+
     st.title('Browser')
-    # scan_ids = os.listdir('data/scans')
-    # to_show = sorted(scan_ids, key=lambda x: int(x), reverse=True)[:100]
-    to_show = list(get_mispredicted_scans())
-    st.write(f'to show: {len(to_show)}')
-    for sid in to_show:
-        show_scan(sid)
+    scan_ids = get_scan_ids(data_dir)
+    total_scans = len(scan_ids)
+
+    if total_scans == 0:
+        st.write('No scans found')
+        return
 
-def show_scan(scan_id):
-    scan_dir = f'data/scans/{scan_id}'
-    mdfile = f'{scan_dir}/metadata.json'
-    md = json.load(open(mdfile))
+    st.write(f'Total scans: {total_scans}')
+
+    # Smart filter input
+    filter_str = st.text_input('Filter (e.g., "357193-358023 or 358024-358808")', value='')
+
+    if filter_str:
+        ranges = parse_filter(filter_str)
+        if ranges:
+            to_show = filter_scans(scan_ids, ranges)
+            st.write(f'Filter matched: {len(to_show)} scans')
+        else:
+            st.warning('No valid ranges found in filter')
+            return
+    else:
+        # Default: show last 500 scans
+        default_count = min(500, total_scans)
+        to_show = scan_ids[:default_count]
+        st.write(f'Showing last {len(to_show)} scans (default)')
+
+    for sid in to_show:
+        show_scan(str(sid), data_dir)
+
+def show_scan(scan_id, data_dir):
+    scan_dir = os.path.join(data_dir, 'scans', scan_id)
+    mdfile = os.path.join(scan_dir, 'metadata.json')
     if not os.path.exists(mdfile):
         return
-    sbs = Image.open(f'{scan_dir}/sbs.jpg')
-    st.write(f'{scan_id}: {md["labels"]}')
+    md = json.load(open(mdfile))
+    sbs_path = os.path.join(scan_dir, 'sbs.jpg')
+    if not os.path.exists(sbs_path):
+        return
+    sbs = Image.open(sbs_path)
+    st.write(f'{scan_id}: {md.get("labels", "N/A")}')
+    st.write(f'SBS: {sbs_path}')
     st.image(sbs.resize((512, 256)))
     st.divider()
diff --git a/emblem5/ai/check_accuracy.py b/emblem5/ai/check_accuracy.py
new file mode 100755
index 0000000..31d5410
--- /dev/null
+++ b/emblem5/ai/check_accuracy.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+"""
+Temporary script to check model accuracy on specific scan ranges
+Positive samples: 357193-358023
+Negative samples: 358024-358808
+"""
+import sys
+import os
+import json
+from PIL import Image
+from tqdm import tqdm
+
+# Add current directory to path for imports
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from common import load_model, predict
+
+def main():
+    if len(sys.argv) < 3:
+        print("Usage: python3 check_accuracy.py <model_path> <data_dir>")
+        print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem")
+        sys.exit(1)
+
+    model_path = sys.argv[1]
+    data_dir = sys.argv[2]
+
+    if not os.path.exists(model_path):
+        print(f"Error: Model file not found: {model_path}")
+        sys.exit(1)
+
+    # Define scan ranges
+    pos_start, pos_end = 357193, 358023
+    neg_start, neg_end = 358024, 358808
+
+    print(f"Loading model from {model_path}...")
+    model, transforms = load_model(model_path)
+    # Ensure model is not compiled
+    if hasattr(model, '_orig_mod'):
+        model = model._orig_mod
+    print("Model loaded successfully\n")
+
+    scans_dir = os.path.join(data_dir, 'scans')
+    if not os.path.exists(scans_dir):
+        print(f"Error: Scans directory not found: {scans_dir}")
+        sys.exit(1)
+
+    # Collect scan IDs in ranges
+    pos_scans = []
+    neg_scans = []
+
+    print("Collecting scan IDs...")
+    for scan_id in range(pos_start, pos_end + 1):
+        scan_dir = os.path.join(scans_dir, str(scan_id))
+        sbs_file = os.path.join(scan_dir, 'sbs.jpg')
+        if os.path.exists(sbs_file):
+            pos_scans.append(str(scan_id))
+
+    for scan_id in range(neg_start, neg_end + 1):
+        scan_dir = os.path.join(scans_dir, str(scan_id))
+        sbs_file = os.path.join(scan_dir, 'sbs.jpg')
+        if os.path.exists(sbs_file):
+            neg_scans.append(str(scan_id))
+
+    print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})")
+    print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n")
+
+    # Run predictions
+    pos_correct = 0
+    pos_total = 0
+    neg_correct = 0
+    neg_total = 0
+
+    print("Evaluating positive samples...")
+    for scan_id in tqdm(pos_scans, desc="Positive"):
+        scan_dir = os.path.join(scans_dir, scan_id)
+        sbs_file = os.path.join(scan_dir, 'sbs.jpg')
+
+        try:
+            image = Image.open(sbs_file).convert('RGB')
+            predicted_class, probabilities = predict(model, transforms, image)
+
+            pos_total += 1
+            if predicted_class == 1:
+                pos_correct += 1
+        except Exception as e:
+            print(f"\nError processing {scan_id}: {e}")
+            continue
+
+    print("\nEvaluating negative samples...")
+    neg_errors = []
+    for scan_id in tqdm(neg_scans, desc="Negative"):
+        scan_dir = os.path.join(scans_dir, scan_id)
+        sbs_file = os.path.join(scan_dir, 'sbs.jpg')
+
+        try:
+            image = Image.open(sbs_file).convert('RGB')
+            predicted_class, probabilities = predict(model, transforms, image)
+
+            neg_sum = sum([p[0] for p in probabilities])
+            pos_sum = sum([p[1] for p in probabilities])
+            total = neg_sum + pos_sum
+            neg_prob = neg_sum / total if total > 0 else 0
+            pos_prob = pos_sum / total if total > 0 else 0
+
+            neg_total += 1
+            if predicted_class == 0:
+                neg_correct += 1
+            else:
+                # Store first few errors for debugging
+                if len(neg_errors) < 5:
+                    neg_errors.append((scan_id, neg_prob, pos_prob))
+        except Exception as e:
+            print(f"\nError processing {scan_id}: {e}")
+            continue
+
+    # Calculate accuracies
+    pos_acc = pos_correct / pos_total if pos_total > 0 else 0
+    neg_acc = neg_correct / neg_total if neg_total > 0 else 0
+    total_correct = pos_correct + neg_correct
+    total_samples = pos_total + neg_total
+    overall_acc = total_correct / total_samples if total_samples > 0 else 0
+
+    # Print results
+    print("\n" + "="*60)
+    print("Accuracy Results:")
+    print("="*60)
+    print(f"Positive samples: {pos_correct}/{pos_total} correct ({pos_acc:.2%})")
+    print(f"Negative samples: {neg_correct}/{neg_total} correct ({neg_acc:.2%})")
+    print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
+    print("="*60)
+
+    if neg_errors:
+        print("\nSample negative misclassifications:")
+        for scan_id, neg_prob, pos_prob in neg_errors:
+            print(f"  Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}")
+
+if __name__ == '__main__':
+    main()
diff --git a/emblem5/ai/common.py b/emblem5/ai/common.py
index 7a314b6..3d4c036 100644
--- a/emblem5/ai/common.py
+++ b/emblem5/ai/common.py
@@ -94,7 +94,10 @@ def make_side_by_side_img_with_margins(frame_img, std_img):
     # Use a reasonable size - use the larger dimension or at least 512
     edge_length = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512)
 
-    margin_ratio = find_min_margin_ratio(std_img, std_corners)
+    std_margin_ratio = find_min_margin_ratio(std_img, std_corners)
+    frame_margin_ratio = find_min_margin_ratio(frame_img, frame_corners)
+    # Use the minimum margin ratio to ensure both QR areas are captured correctly
+    margin_ratio = min(std_margin_ratio, frame_margin_ratio, 0.05)
 
     std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
     frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)
diff --git a/emblem5/ai/make-sbs.py b/emblem5/ai/make-sbs.py
index 93a1a0b..291a38e 100755
--- a/emblem5/ai/make-sbs.py
+++ b/emblem5/ai/make-sbs.py
@@ -5,14 +5,18 @@ import sys
 import json
 import random
 import argparse
+from functools import partial
 
 from common import *
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--data-dir', required=True)
+    parser.add_argument('--scan-id', help='Process only this specific scan ID')
+    parser.add_argument('--scan-id-range', help='Process scan IDs in range (e.g., "358626-358630" or "358626-358630,359000-359010")')
+    parser.add_argument('--force', action='store_true', help='Overwrite existing sbs.jpg files')
    return parser.parse_args()
 
-def process_scan(scan_dir):
+def process_scan(scan_dir, force=False):
     if not os.path.isdir(scan_dir):
         return "scan_dir not found"
     frame_file = os.path.join(scan_dir, 'frame.jpg')
@@ -20,10 +24,8 @@ def process_scan(scan_dir):
     if not os.path.exists(frame_file) or not os.path.exists(std_file):
         return "frame.jpg or std.jpg not found"
     sbs_file = os.path.join(scan_dir, 'sbs.jpg')
-    frame_qr_file = os.path.join(scan_dir, 'frame-qr.jpg')
-    std_qr_file = os.path.join(scan_dir, 'std-qr.jpg')
     try:
-        if not os.path.exists(sbs_file):
+        if force or not os.path.exists(sbs_file):
             frame_img = Image.open(frame_file)
             std_img = Image.open(std_file)
             sbs_img = make_side_by_side_img_with_margins(frame_img, std_img)
@@ -39,10 +41,33 @@ def main():
     args = parse_args()
     data_dir = args.data_dir
     scans_dir = os.path.join(data_dir, 'scans')
-    pool = Pool(cpu_count())
     scan_ids = os.listdir(scans_dir)
+    if args.scan_id:
+        if args.scan_id not in scan_ids:
+            print(f"Error: scan-id '{args.scan_id}' not found")
+            sys.exit(1)
+        scan_ids = [args.scan_id]
+    elif args.scan_id_range:
+        ranges = parse_ranges(args.scan_id_range)
+        filtered_scan_ids = []
+        for scan_id in scan_ids:
+            try:
+                scan_id_int = int(scan_id)
+                for val_range in ranges:
+                    if in_range(scan_id_int, val_range):
+                        filtered_scan_ids.append(scan_id)
+                        break
+            except ValueError:
+                # Skip non-numeric scan IDs
+                continue
+        scan_ids = filtered_scan_ids
+        if not scan_ids:
+            print(f"Error: No scan IDs found in range '{args.scan_id_range}'")
+            sys.exit(1)
+    pool = Pool(cpu_count())
     counts = defaultdict(int)
-    for result in tqdm(pool.imap(process_scan, [os.path.join(scans_dir, scan_id) for scan_id in scan_ids]), total=len(scan_ids)):
+    process_scan_with_force = partial(process_scan, force=args.force)
+    for result in tqdm(pool.imap(process_scan_with_force, [os.path.join(scans_dir, scan_id) for scan_id in scan_ids]), total=len(scan_ids)):
         counts[result] += 1
     for k, v in counts.items():
         print(f"{k}: {v}")
diff --git a/emblem5/ai/train2.py b/emblem5/ai/train2.py
index 8186c42..227d096 100755
--- a/emblem5/ai/train2.py
+++ b/emblem5/ai/train2.py
@@ -233,45 +233,22 @@ def load_scan_data(data_dir):
 def split_train_val(scan_data):
     """Split data into train and validation sets"""
     scan_ids = list(scan_data.keys())
+    random.shuffle(scan_ids)
 
     print("Splitting data into train and validation sets...")
 
-    # Select 10% random scan_ids for initial validation set
-    val_size = int(len(scan_ids) * 0.1)
-    initial_val_scan_ids = random.sample(scan_ids, val_size)
+    # Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation
+    train_size = int(len(scan_ids) * 0.2)
+    val_size = int(len(scan_ids) * 0.4)
 
-    print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
+    train_scan_ids = scan_ids[:train_size]
+    val_scan_ids = scan_ids[train_size:train_size + val_size]
 
-    # Get all (code, labels) combinations that appear in the initial validation set
-    val_code_labels = set()
-    for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
-        if scan_id in scan_data:
-            code = scan_data[scan_id].get('code')
-            labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
-            if code:
-                val_code_labels.add((code, labels))
-
-    print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
-
-    # Find all scans in train that have matching (code, labels) combinations and move them to validation
-    additional_val_scan_ids = set()
-    train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
-
-    for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
-        if scan_id in scan_data:
-            code = scan_data[scan_id].get('code')
-            labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
-            # If (code, labels) combination matches, move to validation
-            if code and (code, labels) in val_code_labels:
-                additional_val_scan_ids.add(scan_id)
-
-    # Combine validation sets
-    all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
-    all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
+    print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)")
 
     # Create train and validation data dictionaries
-    train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
-    val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids}
+    train_data = {scan_id: scan_data[scan_id] for scan_id in train_scan_ids}
+    val_data = {scan_id: scan_data[scan_id] for scan_id in val_scan_ids}
 
     print(f"Final split:")
     print(f"  Total scans: {len(scan_data)}")
@@ -418,7 +395,8 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler
             best_val_acc = val_acc
             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
             mode_suffix = "_quick" if args.quick else ""
-            best_model_path = f'models/best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
+            models_dir = os.path.expanduser('~/emblem/models')
+            best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
             save_model(model, transform, best_model_path, train_metadata)
             print(f'  New best validation accuracy: {val_acc:.2f}%')
             print(f'  Best model saved: {best_model_path}')
@@ -445,7 +423,8 @@ def main():
     print(f'Using hue jitter: {args.hue_jitter}')
 
     # Create models directory if it doesn't exist
-    os.makedirs('models', exist_ok=True)
+    models_dir = os.path.expanduser('~/emblem/models')
+    os.makedirs(models_dir, exist_ok=True)
 
     # Load scan data
     print("Loading scan data...")
@@ -510,7 +489,11 @@ def main():
     # Create model
     print("Creating ResNet18 model...")
     # model = torchvision.models.resnet18(pretrained=True)
-    model, _ = load_model('models/gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
+    models_dir = os.path.expanduser('~/emblem/models')
+    model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt'))
+    # Unwrap compiled model if it exists (torch.compile can cause issues when loading)
+    if hasattr(model, '_orig_mod'):
+        model = model._orig_mod
     # Modify final layer for binary classification
     # model.fc = nn.Linear(model.fc.in_features, 2)
     model = model.to(device)
@@ -564,7 +547,8 @@ def main():
     # Save final model
     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
     mode_suffix = "_quick" if args.quick else ""
-    final_model_path = f'models/final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
+    models_dir = os.path.expanduser('~/emblem/models')
+    final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
     save_model(model, transform, final_model_path, train_metadata)
 
     print(f"Training completed! Final model saved: {final_model_path}")
diff --git a/emblem5/ai/verify_image.py b/emblem5/ai/verify_image.py
new file mode 100755
index 0000000..3d58324
--- /dev/null
+++ b/emblem5/ai/verify_image.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+"""
+Simple script to verify a single image using a trained model
+"""
+import sys
+import os
+
+# Add current directory to path for imports
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from PIL import Image
+from common import load_model, predict
+
+def main():
+    if len(sys.argv) < 3:
+        print("Usage: python3 verify_image.py <model_path> <image_path>")
+        print("Example: python3 verify_image.py models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem/scans/358626/sbs.jpg")
+        sys.exit(1)
+
+    model_path = sys.argv[1]
+    image_path = sys.argv[2]
+
+    if not os.path.exists(model_path):
+        print(f"Error: Model file not found: {model_path}")
+        sys.exit(1)
+
+    if not os.path.exists(image_path):
+        print(f"Error: Image file not found: {image_path}")
+        sys.exit(1)
+
+    print(f"Loading model from {model_path}...")
+    model, transforms = load_model(model_path)
+    # Ensure model is not compiled (torch.compile can cause issues when loading)
+    if hasattr(model, '_orig_mod'):
+        model = model._orig_mod
+    print("Model loaded successfully")
+
+    print(f"Loading image from {image_path}...")
+    image = Image.open(image_path).convert('RGB')
+    print(f"Image size: {image.size}")
+
+    print("Running prediction...")
+    predicted_class, probabilities = predict(model, transforms, image)
+
+    # probabilities is a list of [neg_prob, pos_prob] for each cell
+    # Sum up probabilities across all cells
+    neg_sum = sum([p[0] for p in probabilities])
+    pos_sum = sum([p[1] for p in probabilities])
+    total = neg_sum + pos_sum
+    neg_prob = neg_sum / total if total > 0 else 0
+    pos_prob = pos_sum / total if total > 0 else 0
+
+    print("\n" + "="*50)
+    print("Prediction Results:")
+    print("="*50)
+    print(f"Predicted class: {'POSITIVE' if predicted_class == 1 else 'NEGATIVE'}")
+    print(f"Negative probability: {neg_prob:.2%}")
+    print(f"Positive probability: {pos_prob:.2%}")
+    print(f"Number of cells evaluated: {len(probabilities)}")
+    print("="*50)
+
+if __name__ == '__main__':
+    main()