diff --git a/alg/qrtool.cpp b/alg/qrtool.cpp index 3e51845..093dfd5 100644 --- a/alg/qrtool.cpp +++ b/alg/qrtool.cpp @@ -14,6 +14,8 @@ #include #include #include "opencv2/objdetect.hpp" +#include + #include "mq_worker.h" #if ENABLE_GRPC #include "fileprocess.h" @@ -607,6 +609,31 @@ void usage(const char *name, vector &cmds) } } +static +int roi_verify_cmd(char **argv, int argc) +{ + char *model_path = argv[0]; + char *input_file = argv[1]; + Mat input_img = imread(input_file); + cv::dnn::Net net = cv::dnn::readNetFromONNX(model_path); + if (net.empty()) { + std::cerr << "Failed to load ONNX model!" << std::endl; + return -1; + } + cv::Mat blob; + cv::resize(input_img, input_img, cv::Size(128, 64)); // 调整图像大小 + cv::dnn::blobFromImage(input_img, blob, 1.0 / 255.0, cv::Size(64, 128), cv::Scalar(0.485, 0.456, 0.406), true, false); + blob = (blob - cv::Scalar(0.485, 0.456, 0.406)) / cv::Scalar(0.229, 0.224, 0.225); // 归一化 + net.setInput(blob); + cv::Mat output = net.forward(); + float* data = (float*)output.data; + int class_id = std::max_element(data, data + output.total()) - data; // 找到最大概率的类别 + float confidence = data[class_id]; + std::cout << "Predicted class: " << class_id << ", Confidence: " << confidence << std::endl; + + return 0; +} + #ifdef QRTOOL_MAIN int main(int argc, char *argv[]) { @@ -640,6 +667,7 @@ int main(int argc, char *argv[]) #endif add_cmd(http_server, 1); add_cmd(verify, 2); + add_cmd(roi_verify, 2); usage(argv[0], cmds); return 1; }