update train2

This commit is contained in:
Fam Zheng 2025-12-26 13:55:52 +00:00
parent 01e3543be9
commit aa2d4041e3
8 changed files with 402 additions and 65 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ build
/detection/model /detection/model
/api/db.sqlite3 /api/db.sqlite3
/emblemscanner-release /emblemscanner-release
/models

View File

@ -1,4 +1,4 @@
.PHONY: FORCE emblemscanner-release fetch fetch-quick sbs-quick train-quick .PHONY: FORCE emblemscanner-release fetch fetch-quick sbs-quick train-quick verify
DATA_DIR ?= $(HOME)/emblem DATA_DIR ?= $(HOME)/emblem
@ -102,6 +102,47 @@ train: FORCE
train-quick: FORCE train-quick: FORCE
cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --quick --epochs 2 cd emblem5 && uv run --with-requirements ../requirements.txt ./ai/train2.py --data-dir $(DATA_DIR) --quick --epochs 2
browser: FORCE
cd emblem5 && uv run --with-requirements ../requirements.txt streamlit run ./ai/browser.py -- --data-dir $(DATA_DIR)
verify: FORCE
@if [ -z "$(SBS_IMG)" ]; then \
echo "Error: SBS_IMG is required. Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt"; \
exit 1; \
fi
@if [ -z "$(MODEL)" ]; then \
echo "Error: MODEL is required. Usage: make verify SBS_IMG=/path/to/sbs.jpg MODEL=/path/to/model.pt"; \
exit 1; \
fi
@MODEL_PATH="$(MODEL)"; \
if [ "$${MODEL_PATH#/}" = "$$MODEL_PATH" ]; then \
MODEL_PATH="$(shell pwd)/$$MODEL_PATH"; \
fi; \
if [ ! -f "$$MODEL_PATH" ]; then \
echo "Error: Model file not found: $$MODEL_PATH"; \
exit 1; \
fi; \
cd emblem5 && uv run --with-requirements ../requirements.txt python3 ./ai/verify_image.py "$$MODEL_PATH" "$(SBS_IMG)"
check-accuracy: FORCE
@if [ -z "$(MODEL)" ]; then \
echo "Error: MODEL is required. Usage: make check-accuracy MODEL=/path/to/model.pt [DATA_DIR=/path/to/data]"; \
exit 1; \
fi
@MODEL_PATH="$(MODEL)"; \
if [ "$${MODEL_PATH#/}" = "$$MODEL_PATH" ]; then \
MODEL_PATH="$(shell pwd)/$$MODEL_PATH"; \
fi; \
if [ ! -f "$$MODEL_PATH" ]; then \
echo "Error: Model file not found: $$MODEL_PATH"; \
exit 1; \
fi; \
DATA_DIR="$(DATA_DIR)"; \
if [ -z "$$DATA_DIR" ]; then \
DATA_DIR="$(HOME)/emblem"; \
fi; \
cd emblem5 && uv run --with-requirements ../requirements.txt python3 ./ai/check_accuracy.py "$$MODEL_PATH" "$$DATA_DIR"
OPENCV_TAG := 4.9.0 OPENCV_TAG := 4.9.0
opencv/src/LICENSE: opencv/src/LICENSE:
rm -rf opencv/src opencv/contrib rm -rf opencv/src opencv/contrib

View File

@ -1,38 +1,118 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import sys
import argparse
import streamlit as st import streamlit as st
import random
import json import json
import re
from PIL import Image from PIL import Image
def get_mispredicted_scans(): def get_scan_ids(data_dir):
with open('data/verify2.log', 'r') as f: scans_dir = os.path.join(data_dir, 'scans')
lines = f.readlines() if not os.path.exists(scans_dir):
for line in lines: return []
fields = line.split() scan_ids = []
if len(fields) != 6: for item in os.listdir(scans_dir):
scan_path = os.path.join(scans_dir, item)
if os.path.isdir(scan_path):
try:
scan_ids.append(int(item))
except ValueError:
continue continue
if fields[1] != fields[2]: return sorted(scan_ids, reverse=True)
yield fields[0]
def parse_filter(filter_str):
"""Parse filter string like '357193-358023 or 358024-358808' into list of ranges."""
if not filter_str or not filter_str.strip():
return []
ranges = []
# Split by 'or' (case insensitive)
parts = re.split(r'\s+or\s+', filter_str.strip(), flags=re.IGNORECASE)
for part in parts:
part = part.strip()
if not part:
continue
# Try to parse as range (start-end)
match = re.match(r'(\d+)\s*-\s*(\d+)', part)
if match:
start = int(match.group(1))
end = int(match.group(2))
ranges.append((start, end))
else:
# Try to parse as single number
try:
num = int(part)
ranges.append((num, num))
except ValueError:
continue
return ranges
def filter_scans(scan_ids, ranges):
"""Filter scan IDs based on ranges."""
if not ranges:
return scan_ids
filtered = set()
for start, end in ranges:
for sid in scan_ids:
if start <= sid <= end:
filtered.add(sid)
return sorted(filtered, reverse=True)
def main(): def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', type=str, default=os.path.expanduser('~/emblem'), help='Data directory')
args, unknown = parser.parse_known_args()
data_dir = args.data_dir
st.title('Browser') st.title('Browser')
# scan_ids = os.listdir('data/scans') scan_ids = get_scan_ids(data_dir)
# to_show = sorted(scan_ids, key=lambda x: int(x), reverse=True)[:100] total_scans = len(scan_ids)
to_show = list(get_mispredicted_scans())
st.write(f'to show: {len(to_show)}')
for sid in to_show:
show_scan(sid)
def show_scan(scan_id): if total_scans == 0:
scan_dir = f'data/scans/{scan_id}' st.write('No scans found')
mdfile = f'{scan_dir}/metadata.json' return
md = json.load(open(mdfile))
st.write(f'Total scans: {total_scans}')
# Smart filter input
filter_str = st.text_input('Filter (e.g., "357193-358023 or 358024-358808")', value='')
if filter_str:
ranges = parse_filter(filter_str)
if ranges:
to_show = filter_scans(scan_ids, ranges)
st.write(f'Filter matched: {len(to_show)} scans')
else:
st.warning('No valid ranges found in filter')
return
else:
# Default: show last 500 scans
default_count = min(500, total_scans)
to_show = scan_ids[:default_count]
st.write(f'Showing last {len(to_show)} scans (default)')
for sid in to_show:
show_scan(str(sid), data_dir)
def show_scan(scan_id, data_dir):
scan_dir = os.path.join(data_dir, 'scans', scan_id)
mdfile = os.path.join(scan_dir, 'metadata.json')
if not os.path.exists(mdfile): if not os.path.exists(mdfile):
return return
sbs = Image.open(f'{scan_dir}/sbs.jpg') md = json.load(open(mdfile))
st.write(f'{scan_id}: {md["labels"]}') sbs_path = os.path.join(scan_dir, 'sbs.jpg')
if not os.path.exists(sbs_path):
return
sbs = Image.open(sbs_path)
st.write(f'{scan_id}: {md.get("labels", "N/A")}')
st.write(f'SBS: {sbs_path}')
st.image(sbs.resize((512, 256))) st.image(sbs.resize((512, 256)))
st.divider() st.divider()

139
emblem5/ai/check_accuracy.py Executable file
View File

@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Temporary script to check model accuracy on specific scan ranges
Positive samples: 357193-358023
Negative samples: 358024-358808
"""
import sys
import os
import json
from PIL import Image
from tqdm import tqdm
# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import load_model, predict
def main():
if len(sys.argv) < 3:
print("Usage: python3 check_accuracy.py <model_path> <data_dir>")
print("Example: python3 check_accuracy.py ../models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem")
sys.exit(1)
model_path = sys.argv[1]
data_dir = sys.argv[2]
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
sys.exit(1)
# Define scan ranges
pos_start, pos_end = 357193, 358023
neg_start, neg_end = 358024, 358808
print(f"Loading model from {model_path}...")
model, transforms = load_model(model_path)
# Ensure model is not compiled
if hasattr(model, '_orig_mod'):
model = model._orig_mod
print("Model loaded successfully\n")
scans_dir = os.path.join(data_dir, 'scans')
if not os.path.exists(scans_dir):
print(f"Error: Scans directory not found: {scans_dir}")
sys.exit(1)
# Collect scan IDs in ranges
pos_scans = []
neg_scans = []
print("Collecting scan IDs...")
for scan_id in range(pos_start, pos_end + 1):
scan_dir = os.path.join(scans_dir, str(scan_id))
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
if os.path.exists(sbs_file):
pos_scans.append(str(scan_id))
for scan_id in range(neg_start, neg_end + 1):
scan_dir = os.path.join(scans_dir, str(scan_id))
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
if os.path.exists(sbs_file):
neg_scans.append(str(scan_id))
print(f"Found {len(pos_scans)} positive scans (expected {pos_end - pos_start + 1})")
print(f"Found {len(neg_scans)} negative scans (expected {neg_end - neg_start + 1})\n")
# Run predictions
pos_correct = 0
pos_total = 0
neg_correct = 0
neg_total = 0
print("Evaluating positive samples...")
for scan_id in tqdm(pos_scans, desc="Positive"):
scan_dir = os.path.join(scans_dir, scan_id)
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
try:
image = Image.open(sbs_file).convert('RGB')
predicted_class, probabilities = predict(model, transforms, image)
pos_total += 1
if predicted_class == 1:
pos_correct += 1
except Exception as e:
print(f"\nError processing {scan_id}: {e}")
continue
print("\nEvaluating negative samples...")
neg_errors = []
for scan_id in tqdm(neg_scans, desc="Negative"):
scan_dir = os.path.join(scans_dir, scan_id)
sbs_file = os.path.join(scan_dir, 'sbs.jpg')
try:
image = Image.open(sbs_file).convert('RGB')
predicted_class, probabilities = predict(model, transforms, image)
neg_sum = sum([p[0] for p in probabilities])
pos_sum = sum([p[1] for p in probabilities])
total = neg_sum + pos_sum
neg_prob = neg_sum / total if total > 0 else 0
pos_prob = pos_sum / total if total > 0 else 0
neg_total += 1
if predicted_class == 0:
neg_correct += 1
else:
# Store first few errors for debugging
if len(neg_errors) < 5:
neg_errors.append((scan_id, neg_prob, pos_prob))
except Exception as e:
print(f"\nError processing {scan_id}: {e}")
continue
# Calculate accuracies
pos_acc = pos_correct / pos_total if pos_total > 0 else 0
neg_acc = neg_correct / neg_total if neg_total > 0 else 0
total_correct = pos_correct + neg_correct
total_samples = pos_total + neg_total
overall_acc = total_correct / total_samples if total_samples > 0 else 0
# Print results
print("\n" + "="*60)
print("Accuracy Results:")
print("="*60)
print(f"Positive samples: {pos_correct}/{pos_total} correct ({pos_acc:.2%})")
print(f"Negative samples: {neg_correct}/{neg_total} correct ({neg_acc:.2%})")
print(f"Overall accuracy: {total_correct}/{total_samples} correct ({overall_acc:.2%})")
print("="*60)
if neg_errors:
print("\nSample negative misclassifications:")
for scan_id, neg_prob, pos_prob in neg_errors:
print(f" Scan {scan_id}: neg={neg_prob:.2%}, pos={pos_prob:.2%}")
if __name__ == '__main__':
main()

View File

@ -94,7 +94,10 @@ def make_side_by_side_img_with_margins(frame_img, std_img):
# Use a reasonable size - use the larger dimension or at least 512 # Use a reasonable size - use the larger dimension or at least 512
edge_length = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512) edge_length = max(std_img.width, std_img.height, frame_img.width, frame_img.height, 512)
margin_ratio = find_min_margin_ratio(std_img, std_corners) std_margin_ratio = find_min_margin_ratio(std_img, std_corners)
frame_margin_ratio = find_min_margin_ratio(frame_img, frame_corners)
# Use the minimum margin ratio to ensure both QR areas are captured correctly
margin_ratio = min(std_margin_ratio, frame_margin_ratio, 0.05)
std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio) std_warped = warp_with_margin_ratio(std_img, edge_length, std_corners, margin_ratio)
frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio) frame_warped = warp_with_margin_ratio(frame_img, edge_length, frame_corners, margin_ratio)

View File

@ -5,14 +5,18 @@ import sys
import json import json
import random import random
import argparse import argparse
from functools import partial
from common import * from common import *
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', required=True) parser.add_argument('--data-dir', required=True)
parser.add_argument('--scan-id', help='Process only this specific scan ID')
parser.add_argument('--scan-id-range', help='Process scan IDs in range (e.g., "358626-358630" or "358626-358630,359000-359010")')
parser.add_argument('--force', action='store_true', help='Overwrite existing sbs.jpg files')
return parser.parse_args() return parser.parse_args()
def process_scan(scan_dir): def process_scan(scan_dir, force=False):
if not os.path.isdir(scan_dir): if not os.path.isdir(scan_dir):
return "scan_dir not found" return "scan_dir not found"
frame_file = os.path.join(scan_dir, 'frame.jpg') frame_file = os.path.join(scan_dir, 'frame.jpg')
@ -20,10 +24,8 @@ def process_scan(scan_dir):
if not os.path.exists(frame_file) or not os.path.exists(std_file): if not os.path.exists(frame_file) or not os.path.exists(std_file):
return "frame.jpg or std.jpg not found" return "frame.jpg or std.jpg not found"
sbs_file = os.path.join(scan_dir, 'sbs.jpg') sbs_file = os.path.join(scan_dir, 'sbs.jpg')
frame_qr_file = os.path.join(scan_dir, 'frame-qr.jpg')
std_qr_file = os.path.join(scan_dir, 'std-qr.jpg')
try: try:
if not os.path.exists(sbs_file): if force or not os.path.exists(sbs_file):
frame_img = Image.open(frame_file) frame_img = Image.open(frame_file)
std_img = Image.open(std_file) std_img = Image.open(std_file)
sbs_img = make_side_by_side_img_with_margins(frame_img, std_img) sbs_img = make_side_by_side_img_with_margins(frame_img, std_img)
@ -39,10 +41,33 @@ def main():
args = parse_args() args = parse_args()
data_dir = args.data_dir data_dir = args.data_dir
scans_dir = os.path.join(data_dir, 'scans') scans_dir = os.path.join(data_dir, 'scans')
pool = Pool(cpu_count())
scan_ids = os.listdir(scans_dir) scan_ids = os.listdir(scans_dir)
if args.scan_id:
if args.scan_id not in scan_ids:
print(f"Error: scan-id '{args.scan_id}' not found")
sys.exit(1)
scan_ids = [args.scan_id]
elif args.scan_id_range:
ranges = parse_ranges(args.scan_id_range)
filtered_scan_ids = []
for scan_id in scan_ids:
try:
scan_id_int = int(scan_id)
for val_range in ranges:
if in_range(scan_id_int, val_range):
filtered_scan_ids.append(scan_id)
break
except ValueError:
# Skip non-numeric scan IDs
continue
scan_ids = filtered_scan_ids
if not scan_ids:
print(f"Error: No scan IDs found in range '{args.scan_id_range}'")
sys.exit(1)
pool = Pool(cpu_count())
counts = defaultdict(int) counts = defaultdict(int)
for result in tqdm(pool.imap(process_scan, [os.path.join(scans_dir, scan_id) for scan_id in scan_ids]), total=len(scan_ids)): process_scan_with_force = partial(process_scan, force=args.force)
for result in tqdm(pool.imap(process_scan_with_force, [os.path.join(scans_dir, scan_id) for scan_id in scan_ids]), total=len(scan_ids)):
counts[result] += 1 counts[result] += 1
for k, v in counts.items(): for k, v in counts.items():
print(f"{k}: {v}") print(f"{k}: {v}")

View File

@ -233,45 +233,22 @@ def load_scan_data(data_dir):
def split_train_val(scan_data): def split_train_val(scan_data):
"""Split data into train and validation sets""" """Split data into train and validation sets"""
scan_ids = list(scan_data.keys()) scan_ids = list(scan_data.keys())
random.shuffle(scan_ids)
print("Splitting data into train and validation sets...") print("Splitting data into train and validation sets...")
# Select 10% random scan_ids for initial validation set # Select random 1/5 (20%) for train and random 2/5 (40%) for test/validation
val_size = int(len(scan_ids) * 0.1) train_size = int(len(scan_ids) * 0.2)
initial_val_scan_ids = random.sample(scan_ids, val_size) val_size = int(len(scan_ids) * 0.4)
print(f"Initial random split: {len(scan_ids) - val_size} train, {val_size} validation") train_scan_ids = scan_ids[:train_size]
val_scan_ids = scan_ids[train_size:train_size + val_size]
# Get all (code, labels) combinations that appear in the initial validation set print(f"Random split: {len(train_scan_ids)} train (1/5), {len(val_scan_ids)} validation (2/5)")
val_code_labels = set()
for scan_id in tqdm(initial_val_scan_ids, desc="Collecting validation code-label combinations"):
if scan_id in scan_data:
code = scan_data[scan_id].get('code')
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
if code:
val_code_labels.add((code, labels))
print(f"Found {len(val_code_labels)} unique (code, labels) combinations in validation set")
# Find all scans in train that have matching (code, labels) combinations and move them to validation
additional_val_scan_ids = set()
train_scan_ids = set(scan_ids) - set(initial_val_scan_ids)
for scan_id in tqdm(train_scan_ids, desc="Finding scans with matching code-label combinations"):
if scan_id in scan_data:
code = scan_data[scan_id].get('code')
labels = tuple(sorted(scan_data[scan_id].get('labels', [])))
# If (code, labels) combination matches, move to validation
if code and (code, labels) in val_code_labels:
additional_val_scan_ids.add(scan_id)
# Combine validation sets
all_val_scan_ids = set(initial_val_scan_ids) | additional_val_scan_ids
all_train_scan_ids = set(scan_data.keys()) - all_val_scan_ids
# Create train and validation data dictionaries # Create train and validation data dictionaries
train_data = {scan_id: scan_data[scan_id] for scan_id in all_train_scan_ids} train_data = {scan_id: scan_data[scan_id] for scan_id in train_scan_ids}
val_data = {scan_id: scan_data[scan_id] for scan_id in all_val_scan_ids} val_data = {scan_id: scan_data[scan_id] for scan_id in val_scan_ids}
print(f"Final split:") print(f"Final split:")
print(f" Total scans: {len(scan_data)}") print(f" Total scans: {len(scan_data)}")
@ -418,7 +395,8 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler
best_val_acc = val_acc best_val_acc = val_acc
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
mode_suffix = "_quick" if args.quick else "" mode_suffix = "_quick" if args.quick else ""
best_model_path = f'models/best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt' models_dir = os.path.expanduser('~/emblem/models')
best_model_path = os.path.join(models_dir, f'best_model_ep{epoch+1}_pos{val_pos_acc:.2f}_neg{val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
save_model(model, transform, best_model_path, train_metadata) save_model(model, transform, best_model_path, train_metadata)
print(f' New best validation accuracy: {val_acc:.2f}%') print(f' New best validation accuracy: {val_acc:.2f}%')
print(f' Best model saved: {best_model_path}') print(f' Best model saved: {best_model_path}')
@ -445,7 +423,8 @@ def main():
print(f'Using hue jitter: {args.hue_jitter}') print(f'Using hue jitter: {args.hue_jitter}')
# Create models directory if it doesn't exist # Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True) models_dir = os.path.expanduser('~/emblem/models')
os.makedirs(models_dir, exist_ok=True)
# Load scan data # Load scan data
print("Loading scan data...") print("Loading scan data...")
@ -510,7 +489,11 @@ def main():
# Create model # Create model
print("Creating ResNet18 model...") print("Creating ResNet18 model...")
# model = torchvision.models.resnet18(pretrained=True) # model = torchvision.models.resnet18(pretrained=True)
model, _ = load_model('models/gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt') models_dir = os.path.expanduser('~/emblem/models')
model, _ = load_model(os.path.join(models_dir, 'gridcrop-resnet-ep24-pos97.29-neg75.10-20250509_051300.pt'))
# Unwrap compiled model if it exists (torch.compile can cause issues when loading)
if hasattr(model, '_orig_mod'):
model = model._orig_mod
# Modify final layer for binary classification # Modify final layer for binary classification
# model.fc = nn.Linear(model.fc.in_features, 2) # model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device) model = model.to(device)
@ -564,7 +547,8 @@ def main():
# Save final model # Save final model
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
mode_suffix = "_quick" if args.quick else "" mode_suffix = "_quick" if args.quick else ""
final_model_path = f'models/final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt' models_dir = os.path.expanduser('~/emblem/models')
final_model_path = os.path.join(models_dir, f'final_model_ep{args.epochs}_pos{final_val_pos_acc:.2f}_neg{final_val_neg_acc:.2f}{mode_suffix}_{timestamp}.pt')
save_model(model, transform, final_model_path, train_metadata) save_model(model, transform, final_model_path, train_metadata)
print(f"Training completed! Final model saved: {final_model_path}") print(f"Training completed! Final model saved: {final_model_path}")

64
emblem5/ai/verify_image.py Executable file
View File

@ -0,0 +1,64 @@
#!/usr/bin/env python3
"""
Simple script to verify a single image using a trained model
"""
import sys
import os
# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from PIL import Image
from common import load_model, predict
def main():
if len(sys.argv) < 3:
print("Usage: python3 verify_image.py <model_path> <image_path>")
print("Example: python3 verify_image.py models/best_model_ep30_pos99.35_neg94.75_20251125_182545.pt /home/fam/emblem/scans/358626/sbs.jpg")
sys.exit(1)
model_path = sys.argv[1]
image_path = sys.argv[2]
if not os.path.exists(model_path):
print(f"Error: Model file not found: {model_path}")
sys.exit(1)
if not os.path.exists(image_path):
print(f"Error: Image file not found: {image_path}")
sys.exit(1)
print(f"Loading model from {model_path}...")
model, transforms = load_model(model_path)
# Ensure model is not compiled (torch.compile can cause issues when loading)
if hasattr(model, '_orig_mod'):
model = model._orig_mod
print("Model loaded successfully")
print(f"Loading image from {image_path}...")
image = Image.open(image_path).convert('RGB')
print(f"Image size: {image.size}")
print("Running prediction...")
predicted_class, probabilities = predict(model, transforms, image)
# probabilities is a list of [neg_prob, pos_prob] for each cell
# Sum up probabilities across all cells
neg_sum = sum([p[0] for p in probabilities])
pos_sum = sum([p[1] for p in probabilities])
total = neg_sum + pos_sum
neg_prob = neg_sum / total if total > 0 else 0
pos_prob = pos_sum / total if total > 0 else 0
print("\n" + "="*50)
print("Prediction Results:")
print("="*50)
print(f"Predicted class: {'POSITIVE' if predicted_class == 1 else 'NEGATIVE'}")
print(f"Negative probability: {neg_prob:.2%}")
print(f"Positive probability: {pos_prob:.2%}")
print(f"Number of cells evaluated: {len(probabilities)}")
print("="*50)
if __name__ == '__main__':
main()