alg: Add roi_verify cmd

This commit is contained in:
Fam Zheng 2025-03-23 09:07:42 -07:00
parent 4e3937b471
commit bcc06bf255

View File

@ -14,6 +14,8 @@
#include <stdlib.h>
#include <unistd.h>
#include "opencv2/objdetect.hpp"
#include <opencv2/dnn.hpp>
#include "mq_worker.h"
#if ENABLE_GRPC
#include "fileprocess.h"
@ -607,6 +609,31 @@ void usage(const char *name, vector<string> &cmds)
}
}
static
int roi_verify_cmd(char **argv, int argc)
{
char *model_path = argv[0];
char *input_file = argv[1];
Mat input_img = imread(input_file);
cv::dnn::Net net = cv::dnn::readNetFromONNX(model_path);
if (net.empty()) {
std::cerr << "Failed to load ONNX model!" << std::endl;
return -1;
}
cv::Mat blob;
cv::resize(input_img, input_img, cv::Size(128, 64)); // 调整图像大小
cv::dnn::blobFromImage(input_img, blob, 1.0 / 255.0, cv::Size(64, 128), cv::Scalar(0.485, 0.456, 0.406), true, false);
blob = (blob - cv::Scalar(0.485, 0.456, 0.406)) / cv::Scalar(0.229, 0.224, 0.225); // 归一化
net.setInput(blob);
cv::Mat output = net.forward();
float* data = (float*)output.data;
int class_id = std::max_element(data, data + output.total()) - data; // 找到最大概率的类别
float confidence = data[class_id];
std::cout << "Predicted class: " << class_id << ", Confidence: " << confidence << std::endl;
return 0;
}
#ifdef QRTOOL_MAIN
int main(int argc, char *argv[])
{
@ -640,6 +667,7 @@ int main(int argc, char *argv[])
#endif
add_cmd(http_server, 1);
add_cmd(verify, 2);
add_cmd(roi_verify, 2);
usage(argv[0], cmds);
return 1;
}