update train2
This commit is contained in:
parent
01e3543be9
commit
aa2d4041e3
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ build
|
||||
/detection/model
|
||||
/api/db.sqlite3
|
||||
/emblemscanner-release
|
||||
/models
|
||||
|
||||
43
Makefile
43
Makefile
@ -1,4 +1,4 @@
|
||||
.PHONY: FORCE emblemscanner-release fetch fetch-quick sbs-quick train-quick
|
||||
.PHONY: FORCE emblemscanner-release fetch fetch-quick sbs-quick train-quick verify
|
||||
|
||||
DATA_DIR ?= $(HOME)/emblem
|
||||
|
||||
@ -102,6 +102,47 @@ train: FORCE
|
||||
train-quick: FORCE
|
||||
cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --quick --epochs 2
|
||||
|
||||
browser: FORCE
|
||||
cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR)
|
||||
|
||||
# verify: classify a single side-by-side image with a trained model.
#   Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt
# Both MODEL and SBS_IMG may be given relative to the directory make was
# started in; they are turned into absolute paths *before* the recipe
# changes into emblem5/, otherwise a relative path would be resolved
# against the wrong directory.
verify: FORCE
	@if [ -z "$(SBS_IMG)" ]; then \
		echo "Error: SBS_IMG is required. Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt"; \
		exit 1; \
	fi
	@if [ -z "$(MODEL)" ]; then \
		echo "Error: MODEL is required. Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt"; \
		exit 1; \
	fi
	@MODEL_PATH="$(MODEL)"; \
	if [ "$${MODEL_PATH#/}" = "$$MODEL_PATH" ]; then \
		MODEL_PATH="$(shell pwd)/$$MODEL_PATH"; \
	fi; \
	if [ ! -f "$$MODEL_PATH" ]; then \
		echo "Error: Model file not found: $$MODEL_PATH"; \
		exit 1; \
	fi; \
	SBS_PATH="$(SBS_IMG)"; \
	if [ "$${SBS_PATH#/}" = "$$SBS_PATH" ]; then \
		SBS_PATH="$(shell pwd)/$$SBS_PATH"; \
	fi; \
	cd emblem5 && uv run --with-requirements ../requirements.txt python3 ./ai/verify_image.py "$$MODEL_PATH" "$$SBS_PATH"
|
||||
|
||||
# check-accuracy: run ai/check_accuracy.py for a trained model over the scans
# stored under DATA_DIR.
#   Usage: make check-accuracy MODEL=/path/to/model.pt [DATA_DIR=/path/to/data]
# MODEL and DATA_DIR may be relative to the directory make was started in;
# both are made absolute before the recipe changes into emblem5/.
check-accuracy: FORCE
	@if [ -z "$(MODEL)" ]; then \
		echo "Error: MODEL is required. Usage: make check-accuracy MODEL=/path/to/model.pt [DATA_DIR=/path/to/data]"; \
		exit 1; \
	fi
	@MODEL_PATH="$(MODEL)"; \
	if [ "$${MODEL_PATH#/}" = "$$MODEL_PATH" ]; then \
		MODEL_PATH="$(shell pwd)/$$MODEL_PATH"; \
	fi; \
	if [ ! -f "$$MODEL_PATH" ]; then \
		echo "Error: Model file not found: $$MODEL_PATH"; \
		exit 1; \
	fi; \
	DATA_DIR="$(DATA_DIR)"; \
	if [ -z "$$DATA_DIR" ]; then \
		DATA_DIR="$(HOME)/emblem"; \
	fi; \
	if [ "$${DATA_DIR#/}" = "$$DATA_DIR" ]; then \
		DATA_DIR="$(shell pwd)/$$DATA_DIR"; \
	fi; \
	cd emblem5 && uv run --with-requirements ../requirements.txt python3 ./ai/check_accuracy.py "$$MODEL_PATH" "$$DATA_DIR"
|
||||
|
||||
OPENCV_TAG := 4.9.0
|
||||
opencv/src/LICENSE:
|
||||
rm -rf opencv/src opencv/contrib
|
||||
|
||||
@ -1,38 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import streamlit as st
|
||||
import random
|
||||
import json
|
||||
import re
|
||||
from PIL import Image
|
||||
|
||||
def get_mispredicted_scans():
|
||||
with open('data/verify2.log', 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
fields = line.split()
|
||||
if len(fields) != 6:
|
||||
def get_scan_ids(data_dir):
    """Return the numeric scan IDs found under ``<data_dir>/scans``.

    Only sub-directories whose name parses as an integer are reported;
    plain files and non-numeric directory names are ignored.

    Args:
        data_dir: Root data directory that contains the ``scans`` folder.

    Returns:
        A list of ints sorted in descending order, or an empty list when
        ``<data_dir>/scans`` is missing or is not a directory.
    """
    scans_dir = os.path.join(data_dir, 'scans')
    # isdir rather than exists: a regular file at this path would otherwise
    # reach os.listdir() and raise NotADirectoryError.
    if not os.path.isdir(scans_dir):
        return []

    scan_ids = []
    for item in os.listdir(scans_dir):
        if not os.path.isdir(os.path.join(scans_dir, item)):
            continue
        try:
            scan_ids.append(int(item))
        except ValueError:
            # Directory name is not a scan number; skip it.
            continue
    return sorted(scan_ids, reverse=True)
|
||||
|
||||
def parse_filter(filter_str):
    """Parse filter string like '357193-358023 or 358024-358808' into list of ranges.

    Terms are separated by the word ``or`` (case-insensitive, surrounded by
    whitespace). Each term is either ``start-end`` or a single integer, which
    is treated as the one-element range ``(n, n)``. Terms that cannot be
    parsed are skipped.

    Args:
        filter_str: Raw filter text; may be empty or ``None``.

    Returns:
        A list of inclusive ``(low, high)`` integer tuples in input order.
        Reversed bounds such as ``'20-10'`` are normalised to ``(10, 20)``
        so they do not silently match nothing.
    """
    if not filter_str or not filter_str.strip():
        return []

    ranges = []
    # Split by 'or' (case insensitive)
    for part in re.split(r'\s+or\s+', filter_str.strip(), flags=re.IGNORECASE):
        part = part.strip()
        if not part:
            continue

        # Try to parse as range (start-end)
        match = re.match(r'(\d+)\s*-\s*(\d+)', part)
        if match:
            start = int(match.group(1))
            end = int(match.group(2))
            ranges.append((min(start, end), max(start, end)))
            continue

        # Otherwise try to parse as a single number
        try:
            num = int(part)
        except ValueError:
            continue
        ranges.append((num, num))

    return ranges
|
||||
|
||||
def filter_scans(scan_ids, ranges):
    """Keep only the scan IDs that fall inside at least one inclusive range.

    Args:
        scan_ids: Iterable of integer scan IDs.
        ranges: List of ``(start, end)`` tuples; both ends are inclusive.

    Returns:
        ``scan_ids`` itself (unchanged) when ``ranges`` is empty; otherwise a
        new list of the matching IDs, de-duplicated and sorted descending.
    """
    if not ranges:
        return scan_ids

    matching = {
        candidate
        for candidate in scan_ids
        if any(low <= candidate <= high for low, high in ranges)
    }
    ordered = list(matching)
    ordered.sort(reverse=True)
    return ordered
|
||||
|
||||
def main():
    """Streamlit entry point: list scans from the data directory.

    Reads ``--data-dir`` from the command line (default ``~/emblem``),
    collects the numeric scan IDs below it and renders either the scans
    selected by the filter text box or, when the filter is empty, the first
    500 IDs of the descending list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=os.path.expanduser('~/emblem'), help='Data directory')
    # parse_known_args tolerates extra arguments instead of exiting on them
    # (presumably ones injected by the streamlit launcher - TODO confirm).
    args, _unknown = parser.parse_known_args()
    data_dir = args.data_dir

    st.title('Browser')

    scan_ids = get_scan_ids(data_dir)
    total_scans = len(scan_ids)

    if total_scans == 0:
        st.write('No scans found')
        return

    st.write(f'Total scans: {total_scans}')

    # Smart filter input
    filter_str = st.text_input('Filter (e.g., "357193-358023 or 358024-358808")', value='')

    if filter_str:
        ranges = parse_filter(filter_str)
        if not ranges:
            st.warning('No valid ranges found in filter')
            return
        to_show = filter_scans(scan_ids, ranges)
        st.write(f'Filter matched: {len(to_show)} scans')
    else:
        # Default: show last 500 scans
        default_count = min(500, total_scans)
        to_show = scan_ids[:default_count]
        st.write(f'Showing last {len(to_show)} scans (default)')

    for sid in to_show:
        # show_scan builds filesystem paths, so it takes the ID as a string.
        show_scan(str(sid), data_dir)
|
||||
|
||||
def show_scan(scan_id, data_dir):
    """Render one scan (labels, image path and a preview) in the Streamlit page.

    Nothing is rendered when the scan has no ``metadata.json`` or no
    ``sbs.jpg``.

    Args:
        scan_id: Scan directory name as a string.
        data_dir: Root data directory that contains the ``scans`` folder.
    """
    scan_dir = os.path.join(data_dir, 'scans', scan_id)
    mdfile = os.path.join(scan_dir, 'metadata.json')
    if not os.path.exists(mdfile):
        return
    # Context manager so the metadata file handle is closed promptly.
    with open(mdfile) as f:
        md = json.load(f)

    sbs_path = os.path.join(scan_dir, 'sbs.jpg')
    if not os.path.exists(sbs_path):
        return
    # resize() returns a new in-memory image, so the source file can be
    # closed as soon as the preview has been produced.
    with Image.open(sbs_path) as sbs:
        preview = sbs.resize((512, 256))

    st.write(f'{scan_id}: {md.get("labels", "N/A")}')
    st.write(f'SBS: {sbs_path}')
    st.image(preview)
    st.divider()
|
||||
|
||||
|
||||
139
emblem5/ai/check_accuracy.py
Executable file
139
emblem5/ai/check_accuracy.py
Executable file
@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Temporary script to check model accuracy on specific scan ranges
|
||||
Positive samples: 357193-358023
|
||||
Negative samples: 358024-358808
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Add current directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from common import load_model, predict
|
||||
|
||||
def _collect_scan_ids(scans_dir, start, end):
    """Return the IDs in ``[start, end]`` (as strings) whose scan dir has an sbs.jpg."""
    return [
        str(scan_id)
        for scan_id in range(start, end + 1)
        if os.path.exists(os.path.join(scans_dir, str(scan_id), 'sbs.jpg'))
    ]


def _evaluate_scans(model, transforms, scans_dir, scan_ids, expected_class, desc, max_errors=0):
    """Predict every scan and count how many match ``expected_class``.

    Returns a ``(correct, total, errors)`` tuple where ``errors`` holds up to
    ``max_errors`` entries of ``(scan_id, neg_prob, pos_prob)`` for
    misclassified scans. Scans that fail to load or predict are reported on
    stdout and excluded from ``total``.
    """
    correct = 0
    total = 0
    errors = []
    for scan_id in tqdm(scan_ids, desc=desc):
        sbs_file = os.path.join(scans_dir, scan_id, 'sbs.jpg')
        try:
            # convert() yields an independent in-memory image, so the file
            # handle can be released immediately.
            with Image.open(sbs_file) as raw:
                image = raw.convert('RGB')
            predicted_class, probabilities = predict(model, transforms, image)

            misclassified = predicted_class != expected_class
            if misclassified and len(errors) < max_errors:
                # probabilities is indexed as [neg, pos] per entry; normalise
                # the summed scores into two fractions.
                neg_sum = sum(p[0] for p in probabilities)
                pos_sum = sum(p[1] for p in probabilities)
                score_sum = neg_sum + pos_sum
                neg_prob = neg_sum / score_sum if score_sum > 0 else 0
                pos_prob = pos_sum / score_sum if score_sum > 0 else 0
                errors.append((scan_id, neg_prob, pos_prob))
        except Exception as e:
            # Deliberately best-effort: one unreadable scan must not abort
            # the whole evaluation run.
            print(f"\nError processing {scan_id}: {e}")
            continue

        total += 1
        if not misclassified:
            correct += 1
    return correct, total, errors


def main():
    """Report model accuracy on the hard-coded positive and negative scan ranges.

    Command line: ``check_accuracy.py <model_path> <data_dir>``. Exits with
    status 1 when arguments are missing, the model file does not exist, or
    ``<data_dir>/scans`` does not exist. Prints per-class and overall
    accuracy plus up to five sample negative misclassifications.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 check_accuracy.py <model_path> <data_dir>")
        print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem")
        sys.exit(1)

    model_path = sys.argv[1]
    data_dir = sys.argv[2]

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)

    # Define scan ranges
    pos_start, pos_end = 357193, 358023
    neg_start, neg_end = 358024, 358808

    print(f"Loading model from {model_path}...")
    model, transforms = load_model(model_path)
    # Ensure model is not compiled
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    print("Model loaded successfully\n")

    scans_dir = os.path.join(data_dir, 'scans')
    if not os.path.exists(scans_dir):
        print(f"Error: Scans directory not found: {scans_dir}")
        sys.exit(1)

    # Collect scan IDs in ranges
    print("Collecting scan IDs...")
    pos_scans = _collect_scan_ids(scans_dir, pos_start, pos_end)
    neg_scans = _collect_scan_ids(scans_dir, neg_start, neg_end)

    print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})")
    print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n")

    # Run predictions
    print("Evaluating positive samples...")
    pos_correct, pos_total, _ = _evaluate_scans(
        model, transforms, scans_dir, pos_scans, expected_class=1, desc="Positive")

    print("\nEvaluating negative samples...")
    # Keep the first few negative errors for debugging.
    neg_correct, neg_total, neg_errors = _evaluate_scans(
        model, transforms, scans_dir, neg_scans, expected_class=0, desc="Negative", max_errors=5)

    # Calculate accuracies
    pos_acc = pos_correct / pos_total if pos_total > 0 else 0
    neg_acc = neg_correct / neg_total if neg_total > 0 else 0
    total_correct = pos_correct + neg_correct
    total_samples = pos_total + neg_total
    overall_acc = total_correct / total_samples if total_samples > 0 else 0

    # Print results
    print("\n" + "="*60)
    print("Accuracy Results:")
    print("="*60)
    print(f"Positive samples: {pos_correct}/{pos_total} correct ({pos_acc:.2%})")
    print(f"Negative samples: {neg_correct}/{neg_total} correct ({neg_acc:.2%})")
    print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
    print("="*60)

    if neg_errors:
        print("\nSample negative misclassifications:")
        for scan_id, neg_prob, pos_prob in neg_errors:
            print(f"  Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@ -94,7 +94,10 @@ def make_side_by_side_img_with_margins(frame_img, std_img):
|
||||
|
||||
# Use a reasonable size - use the larger dimension or at least 512
|
||||
edge_length = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512)
|
||||
margin_ratio = find_min_margin_ratio(std_img, std_corners)
|
||||
std_margin_ratio = find_min_margin_ratio(std_img, std_corners)
|
||||
frame_margin_ratio = find_min_margin_ratio(frame_img, frame_corners)
|
||||
# Use the minimum margin ratio to ensure both QR areas are captured correctly
|
||||
margin_ratio = min(std_margin_ratio, frame_margin_ratio, 0.05)
|
||||
std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
|
||||
frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)
|
||||
|
||||
|
||||
@ -5,14 +5,18 @@ import sys
|
||||
import json
|
||||
import random
|
||||
import argparse
|
||||
from functools import partial
|
||||
from common import *
|
||||
|
||||
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Options: ``--data-dir`` (required), ``--scan-id``, ``--scan-id-range``
    and the boolean ``--force`` flag.
    """
    option_specs = (
        ('--data-dir', {'required': True}),
        ('--scan-id', {'help': 'Process only this specific scan ID'}),
        ('--scan-id-range', {'help': 'Process scan IDs in range (e.g., "358626-358630" or "358626-358630,359000-359010")'}),
        ('--force', {'action': 'store_true', 'help': 'Overwrite existing sbs.jpg files'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
||||
|
||||
def process_scan(scan_dir):
|
||||
def process_scan(scan_dir, force=False):
|
||||
if not os.path.isdir(scan_dir):
|
||||
return "scan_dir not found"
|
||||
frame_file = os.path.join(scan_dir, 'frame.jpg')
|
||||
@ -20,10 +24,8 @@ def process_scan(scan_dir):
|
||||
if not os.path.exists(frame_file) or not os.path.exists(std_file):
|
||||
return "frame.jpg or std.jpg not found"
|
||||
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
|
||||
frame_qr_file = os.path.join(scan_dir, 'frame-qr.jpg')
|
||||
std_qr_file = os.path.join(scan_dir, 'std-qr.jpg')
|
||||
try:
|
||||
if not os.path.exists(sbs_file):
|
||||
if force or not os.path.exists(sbs_file):
|
||||
frame_img = Image.open(frame_file)
|
||||
std_img = Image.open(std_file)
|
||||
sbs_img = make_side_by_side_img_with_margins(frame_img, std_img)
|
||||
@ -39,10 +41,33 @@ def main():
|
||||
args = parse_args()
|
||||
data_dir = args.data_dir
|
||||
scans_dir = os.path.join(data_dir, 'scans')
|
||||
pool = Pool(cpu_count())
|
||||
scan_ids = os.listdir(scans_dir)
|
||||
if args.scan_id:
|
||||
if args.scan_id not in scan_ids:
|
||||
print(f"Error: scan-id '{args.scan_id}' not found")
|
||||
sys.exit(1)
|
||||
scan_ids = [args.scan_id]
|
||||
elif args.scan_id_range:
|
||||
ranges = parse_ranges(args.scan_id_range)
|
||||
filtered_scan_ids = []
|
||||
for scan_id in scan_ids:
|
||||
try:
|
||||
scan_id_int = int(scan_id)
|
||||
for val_range in ranges:
|
||||
if in_range(scan_id_int, val_range):
|
||||
filtered_scan_ids.append(scan_id)
|
||||
break
|
||||
except ValueError:
|
||||
# Skip non-numeric scan IDs
|
||||
continue
|
||||
scan_ids = filtered_scan_ids
|
||||
if not scan_ids:
|
||||
print(f"Error: No scan IDs found in range '{args.scan_id_range}'")
|
||||
sys.exit(1)
|
||||
pool = Pool(cpu_count())
|
||||
counts = defaultdict(int)
|
||||
for result in tqdm(pool.imap(process_scan, [os.path.join(scans_dir, scan_id) for scan_id in scan_ids]), total=len(scan_ids)):
|
||||
process_scan_with_force = partial(process_scan, force=args.force)
|
||||
for result in tqdm(pool.imap(process_scan_with_force, [os.path.join(scans_dir, scan_id) for scan_id in scan_ids]), total=len(scan_ids)):
|
||||
counts[result] += 1
|
||||
for k, v in counts.items():
|
||||
print(f"{k}: {v}")
|
||||
|
||||
@ -233,45 +233,22 @@ def load_scan_data(data_dir):
|
||||
def split_train_val(scan_data):
|
||||
"""Split data into train and validation sets"""
|
||||
scan_ids = list(scan_data.keys())
|
||||
random.shuffle(scan_ids)
|
||||
|
||||
print("Splitting data into train and validation sets...")
|
||||
|
||||
# Select 10% random scan_ids for initial validation set
|
||||
val_size = int(len(scan_ids) * 0.1)
|
||||
initial_val_scan_ids = random.sample(scan_ids, val_size)
|
||||
# Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation
|
||||
train_size = int(len(scan_ids) * 0.2)
|
||||
val_size = int(len(scan_ids) * 0.4)
|
||||
|
||||
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation")
|
||||
train_scan_ids = scan_ids[:train_size]
|
||||
val_scan_ids = scan_ids[train_size:train_size + val_size]
|
||||
|
||||
# Get all (code, labels) combinations that appear in the initial validation set
|
||||
val_code_labels = set()
|
||||
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
|
||||
if scan_id in scan_data:
|
||||
code = scan_data[scan_id].get('code')
|
||||
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
||||
if code:
|
||||
val_code_labels.add((code, labels))
|
||||
|
||||
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
|
||||
|
||||
# Find all scans in train that have matching (code, labels) combinations and move them to validation
|
||||
additional_val_scan_ids = set()
|
||||
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
|
||||
|
||||
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
|
||||
if scan_id in scan_data:
|
||||
code = scan_data[scan_id].get('code')
|
||||
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
|
||||
# If (code, labels) combination matches, move to validation
|
||||
if code and (code, labels) in val_code_labels:
|
||||
additional_val_scan_ids.add(scan_id)
|
||||
|
||||
# Combine validation sets
|
||||
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
|
||||
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
|
||||
print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)")
|
||||
|
||||
# Create train and validation data dictionaries
|
||||
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids}
|
||||
val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids}
|
||||
train_data = {scan_id: scan_data[scan_id] for scan_id in train_scan_ids}
|
||||
val_data = {scan_id: scan_data[scan_id] for scan_id in val_scan_ids}
|
||||
|
||||
print(f"Final split:")
|
||||
print(f" Total scans: {len(scan_data)}")
|
||||
@ -418,7 +395,8 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler
|
||||
best_val_acc = val_acc
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
mode_suffix = "_quick" if args.quick else ""
|
||||
best_model_path = f'models/best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
|
||||
models_dir = os.path.expanduser('~/emblem/models')
|
||||
best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
|
||||
save_model(model, transform, best_model_path, train_metadata)
|
||||
print(f' New best validation accuracy: {val_acc:.2f}%')
|
||||
print(f' Best model saved: {best_model_path}')
|
||||
@ -445,7 +423,8 @@ def main():
|
||||
print(f'Using hue jitter: {args.hue_jitter}')
|
||||
|
||||
# Create models directory if it doesn't exist
|
||||
os.makedirs('models', exist_ok=True)
|
||||
models_dir = os.path.expanduser('~/emblem/models')
|
||||
os.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
# Load scan data
|
||||
print("Loading scan data...")
|
||||
@ -510,7 +489,11 @@ def main():
|
||||
# Create model
|
||||
print("Creating ResNet18 model...")
|
||||
# model = torchvision.models.resnet18(pretrained=True)
|
||||
model, _ = load_model('models/gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt')
|
||||
models_dir = os.path.expanduser('~/emblem/models')
|
||||
model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt'))
|
||||
# Unwrap compiled model if it exists (torch.compile can cause issues when loading)
|
||||
if hasattr(model, '_orig_mod'):
|
||||
model = model._orig_mod
|
||||
# Modify final layer for binary classification
|
||||
# model.fc = nn.Linear(model.fc.in_features, 2)
|
||||
model = model.to(device)
|
||||
@ -564,7 +547,8 @@ def main():
|
||||
# Save final model
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
mode_suffix = "_quick" if args.quick else ""
|
||||
final_model_path = f'models/final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt'
|
||||
models_dir = os.path.expanduser('~/emblem/models')
|
||||
final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
|
||||
save_model(model, transform, final_model_path, train_metadata)
|
||||
print(f"Training completed! Final model saved: {final_model_path}")
|
||||
|
||||
|
||||
64
emblem5/ai/verify_image.py
Executable file
64
emblem5/ai/verify_image.py
Executable file
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple script to verify a single image using a trained model
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add current directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from PIL import Image
|
||||
from common import load_model, predict
|
||||
|
||||
def _normalized_probabilities(probabilities):
    """Sum per-cell ``[neg, pos]`` scores and return them as two fractions.

    Returns ``(0, 0)`` when the summed scores are not positive, so callers
    never divide by zero.
    """
    neg_sum = sum(p[0] for p in probabilities)
    pos_sum = sum(p[1] for p in probabilities)
    total = neg_sum + pos_sum
    if total > 0:
        return neg_sum / total, pos_sum / total
    return 0, 0


def main():
    """Classify a single image with a trained model and print the result.

    Command line: ``verify_image.py <model_path> <image_path>``. Exits with
    status 1 when arguments are missing or either file does not exist.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 verify_image.py <model_path> <image_path>")
        print("Example: python3 verify_image.py models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem/scans/358626/sbs.jpg")
        sys.exit(1)

    model_path = sys.argv[1]
    image_path = sys.argv[2]

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)

    if not os.path.exists(image_path):
        print(f"Error: Image file not found: {image_path}")
        sys.exit(1)

    print(f"Loading model from {model_path}...")
    model, transforms = load_model(model_path)
    # Ensure model is not compiled (torch.compile can cause issues when loading)
    if hasattr(model, '_orig_mod'):
        model = model._orig_mod
    print("Model loaded successfully")

    print(f"Loading image from {image_path}...")
    # convert() returns an independent in-memory copy, so the file handle is
    # released as soon as the with-block ends.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    print(f"Image size: {image.size}")

    print("Running prediction...")
    predicted_class, probabilities = predict(model, transforms, image)

    # probabilities is a list of [neg_prob, pos_prob] for each cell;
    # aggregate across all cells into one normalised pair.
    neg_prob, pos_prob = _normalized_probabilities(probabilities)

    print("\n" + "="*50)
    print("Prediction Results:")
    print("="*50)
    print(f"Predicted class: {'POSITIVE' if predicted_class == 1 else 'NEGATIVE'}")
    print(f"Negative probability: {neg_prob:.2%}")
    print(f"Positive probability: {pos_prob:.2%}")
    print(f"Number of cells evaluated: {len(probabilities)}")
    print("="*50)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user