Python TensorFlow 训练好的模型怎么在 C++ 中使用?
展开全部
// 导入之前已经保存好的模型
// 本程序来自tensorflow/c/c_api_test.cc
// 如果不明白,就看这个测试脚本就行了
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
TF_Buffer* metagraph = TF_NewBuffer();
TF_Status* s = TF_NewStatus();
const char* tags[] = {tensorflow::kSavedModelTagServe};
TF_Graph* graph = TF_NewGraph();
TF_Session* session = TF_LoadSessionFromSavedModel(
opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
TF_DeleteBuffer(run_options);
TF_DeleteSessionOptions(opt);
tensorflow::MetaGraphDef metagraph_def;
metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
TF_DeleteBuffer(metagraph);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
CSession csession(session);
// Retrieve the regression signature from meta graph def.
const auto signature_def_map = metagraph_def.signature_def();
const auto signature_def = signature_def_map.at("regress_x_to_y");
const string input_name =
signature_def.inputs().at(tensorflow::kRegressInputs).name();
const string output_name =
signature_def.outputs().at(tensorflow::kRegressOutputs).name();
// Write {0, 1, 2, 3} as tensorflow::Example inputs.
Tensor input(tensorflow::DT_STRING, TensorShape({4}));
for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
tensorflow::Example example;
auto* feature_map = example.mutable_features()->mutable_feature();
(*feature_map)["x"].mutable_float_list()->add_value(i);
input.flat<string>()(i) = example.SerializeAsString();
}
const tensorflow::string input_op_name =
tensorflow::ParseTensorName(input_name).first.ToString();
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_Tensor_EncodeStrings(input)}});
const tensorflow::string output_op_name =
tensorflow::ParseTensorName(output_name).first.ToString();
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
csession.SetOutputs({output_op});
csession.Run(s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
EXPECT_EQ(2, TF_NumDims(out));
EXPECT_EQ(4, TF_Dim(out, 0));
EXPECT_EQ(1, TF_Dim(out, 1));
float* values = static_cast<float*>(TF_TensorData(out));
// These values are defined to be (input / 2) + 2.
EXPECT_EQ(2, values[0]);
EXPECT_EQ(2.5, values[1]);
EXPECT_EQ(3, values[2]);
EXPECT_EQ(3.5, values[3]);
csession.CloseAndDelete(s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
推荐律师服务:
若未解决您的问题,请您详细描述您的问题,通过百度律临进行免费专业咨询