themblem/research/roi_bench.py
2025-03-25 13:29:58 +08:00

57 lines
1.7 KiB
Python
Executable File

#!/usr/bin/env python3
from PIL import Image
import os
import argparse
import random
from roi_lib import *
def parse_args(argv=None):
    """Parse the command-line options for the ROI benchmark.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list makes the parser usable
            from tests or other scripts.

    Returns:
        argparse.Namespace with ``model`` (model path handed to
        ``load_model``) and ``label`` (path to the label file).
    """
    parser = argparse.ArgumentParser(description='ROI prediction')
    parser.add_argument('--model', type=str, required=True, help='model path')
    parser.add_argument('--label', type=str, required=True, help='label file')
    return parser.parse_args(argv)
def _load_testset(label_path):
    """Read ``label_path`` and return a list of ``(image_name, label)`` pairs.

    Each usable line holds exactly two whitespace-separated fields: an image
    name (resolved by the caller relative to the label file's directory) and
    its label string. Blank lines and lines with any other number of fields
    are skipped.
    """
    testset = []
    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            # str.split() with no argument also strips, so blank lines yield
            # zero fields and fall through the length check.
            fields = line.split()
            if len(fields) == 2:
                testset.append((fields[0], fields[1]))
    return testset


def main():
    """Benchmark an ROI classifier against a labelled image list.

    Loads the model given by ``--model``, reads the label file given by
    ``--label``, runs ``predict`` on every listed image in shuffled order and
    prints one line per image with the running accuracy. It then prints a
    summary with the total accuracy and the error counts, split into
    mispredicted label '0' samples and mispredicted samples of any other
    label.

    Returns:
        Process exit status: 0 on success, 1 if the label file contains no
        usable ``<image> <label>`` lines.
    """
    args = parse_args()
    # Load the model once; it is reused for every image.
    model = load_model(args.model)
    testset = _load_testset(args.label)
    if not testset:
        # With no samples the accuracy is undefined (and the summary below
        # would have nothing to report), so fail explicitly.
        print(f'no valid "<image> <label>" lines found in {args.label}')
        return 1
    # Presumably shuffled so the running accuracy is not skewed by the
    # order of the label file.
    random.shuffle(testset)
    # Image names in the label file are relative to the file's directory.
    label_dir = os.path.dirname(args.label)

    correct = 0
    zero_to_one = 0  # label '0' that was mispredicted
    one_to_zero = 0  # any other label (presumably '1') that was mispredicted
    for done, (image_name, label) in enumerate(testset, start=1):
        image_path = os.path.abspath(os.path.join(label_dir, image_name))
        image_tensor = preprocess_image(image_path)
        predicted_class, probabilities = predict(model, image_tensor)
        # Labels are read as text, so compare the prediction's string form.
        if str(predicted_class) == label:
            correct += 1
        elif label == '0':
            zero_to_one += 1
        else:
            one_to_zero += 1
        accu_so_far = correct / done
        print(f'{image_path} label={label} predicted={predicted_class} prob={probabilities} accuracy so far: {accu_so_far:.3f}')

    total = len(testset)
    print(f'total={total} correct={correct} accuracy={correct / total:.3f}')
    print(f'zero_to_one={zero_to_one} one_to_zero={one_to_zero}')
    return 0
if __name__ == '__main__':
    # Raise SystemExit directly (equivalent to sys.exit) so the entry point
    # does not depend on ``sys`` being in scope, which here could only come
    # indirectly through ``from roi_lib import *``.
    raise SystemExit(main())