#include <algorithm>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <MNN/Interpreter.hpp>
#include <MNN/ImageProcess.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/Expr.hpp>
#include <opencv2/opencv.hpp>
class SemanticSegmentation {
public:
SemanticSegmentation(const std::string& modelPath) {
// 创建MNN网络
m_interpreter = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(modelPath.c_str()));
// 创建MNN会话
MNN::ScheduleConfig config;
m_session = m_interpreter->createSession(config);
// 获取输入tensor和输出tensor的信息
m_inputInfo = m_interpreter->getInputInfo(0);
m_outputInfo = m_interpreter->getOutputInfo(0);
}
~SemanticSegmentation() {
// 释放资源
m_interpreter->releaseModel();
m_interpreter->releaseSession(m_session);
}
std::vector<float> inference(const std::string& imagePath) {
cv::Mat image = cv::imread(imagePath);
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
// 创建一个临时的MNN::CV::ImageProcess对象
MNN::CV::ImageProcess process;
process.setResizeDims({image.cols, image.rows});
process.setNormalizationParam(1.0f / 255, {0.485f, 0.456f, 0.406f}, {0.229f, 0.224f, 0.225f});
process.process(image.data, image.cols, image.rows);
// 创建输入Tensor,并将数据拷贝到输入Tensor
MNN::Tensor* inputTensor = m_interpreter->getSessionInput(m_session, nullptr);
process.convertToTensor(inputTensor);
// 运行推理
m_interpreter->runSession(m_session);
// 获取输出Tensor的数据指针和shape信息
const MNN::Tensor* outputTensor = m_interpreter->getSessionOutput(m_session, nullptr);
const std::vector<int> outputShape = outputTensor->shape();
const int outputSize = outputShape[1] * outputShape[2];
std::vector<float> outputData(outputSize, 0.0f);
// 将数据拷贝到输出vector
memcpy(outputData.data(), outputTensor->host<float>(), outputSize * sizeof(float));
return outputData;
}
private:
std::shared_ptr<MNN::Interpreter> m_interpreter;
MNN::Session* m_session;
const MNN::Tensor* m_inputTensor;
const MNN::Tensor* m_outputTensor;
};
/// Demo entry point: runs the segmentation model on one image.
///
/// Usage: <program> [model.mnn] [image]. The placeholder paths below are used when the
/// corresponding argument is omitted. Returns 1 (after printing the reason) on any failure.
int main(int argc, char** argv) {
    const std::string modelPath = argc > 1 ? argv[1] : "path_to_model.mnn";
    const std::string imagePath = argc > 2 ? argv[2] : "path_to_input_image.jpg";
    try {
        SemanticSegmentation segmentation(modelPath);
        const std::vector<float> output = segmentation.inference(imagePath);
        std::cout << "output elements: " << output.size() << '\n';
        // Post-processing (e.g. turning the output into a 0/1 mask) depends on the
        // specific model's output format and is left to the caller.
    } catch (const std::exception& e) {
        // Covers std::runtime_error as well as cv::Exception, which derives from std::exception.
        std::cerr << "segmentation failed: " << e.what() << '\n';
        return 1;
    }
    return 0;
}
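// 版权声明:本文为u012483097原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
// (Copyright notice: this is an original article by u012483097, licensed under CC 4.0 BY-SA;
//  any reproduction must include a link to the original source and this notice.)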