python tensorflow训练好的模型怎么在c++用

 我来答
enochwills
2017-08-17 · TA获得超过4793个赞
知道大有可为答主
回答量:2031
采纳率:96%
帮助的人:1643万
展开全部
  // 导入之前已经保存好的模型
  // 本程序来自tensorflow/c/c_api_test.cc
  // 如果不明白,就看这个测试脚本就行了
  const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
  const string saved_model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
  TF_Buffer* metagraph = TF_NewBuffer();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
  TF_DeleteBuffer(run_options);
  TF_DeleteSessionOptions(opt);
  tensorflow::MetaGraphDef metagraph_def;
  metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
  TF_DeleteBuffer(metagraph);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  CSession csession(session);
  // Retrieve the regression signature from meta graph def.
  const auto signature_def_map = metagraph_def.signature_def();
  const auto signature_def = signature_def_map.at("regress_x_to_y");
  const string input_name =
      signature_def.inputs().at(tensorflow::kRegressInputs).name();
  const string output_name =
      signature_def.outputs().at(tensorflow::kRegressOutputs).name();
  // Write {0, 1, 2, 3} as tensorflow::Example inputs.
  Tensor input(tensorflow::DT_STRING, TensorShape({4}));
  for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
    tensorflow::Example example;
    auto* feature_map = example.mutable_features()->mutable_feature();
    (*feature_map)["x"].mutable_float_list()->add_value(i);
    input.flat<string>()(i) = example.SerializeAsString();
  }
  const tensorflow::string input_op_name =
      tensorflow::ParseTensorName(input_name).first.ToString();
  TF_Operation* input_op =
      TF_GraphOperationByName(graph, input_op_name.c_str());
  ASSERT_TRUE(input_op != nullptr);
  csession.SetInputs({{input_op, TF_Tensor_EncodeStrings(input)}});
  const tensorflow::string output_op_name =
      tensorflow::ParseTensorName(output_name).first.ToString();
  TF_Operation* output_op =
      TF_GraphOperationByName(graph, output_op_name.c_str());
  ASSERT_TRUE(output_op != nullptr);
  csession.SetOutputs({output_op});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
  EXPECT_EQ(2, TF_NumDims(out));
  EXPECT_EQ(4, TF_Dim(out, 0));
  EXPECT_EQ(1, TF_Dim(out, 1));
  float* values = static_cast<float*>(TF_TensorData(out));
  // These values are defined to be (input / 2) + 2.
  EXPECT_EQ(2, values[0]);
  EXPECT_EQ(2.5, values[1]);
  EXPECT_EQ(3, values[2]);
  EXPECT_EQ(3.5, values[3]);
  csession.CloseAndDelete(s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
mafangsan
2017-08-16 · TA获得超过1.2万个赞
知道大有可为答主
回答量:1万
采纳率:71%
帮助的人:2596万
展开全部
省心的方法是用进程调用。
已赞过 已踩过<
你对这个回答的评价是?
评论 收起
古朴且成功的饼子03
2021-01-11 · 贡献了超过235个回答
知道答主
回答量:235
采纳率:0%
帮助的人:17.8万
展开全部

Python使用Tensorflow读取CSV数据训练DNN深度学习模型

已赞过 已踩过<
你对这个回答的评价是?
评论 收起
收起 更多回答(1)
推荐律师服务: 若未解决您的问题,请您详细描述您的问题,通过百度律临进行免费专业咨询

为你推荐:

下载百度知道APP,抢鲜体验
使用百度知道APP,立即抢鲜体验。你的手机镜头里或许有别人想知道的答案。
扫描二维码下载
×

类别

我们会通过消息、邮箱等方式尽快将举报结果通知您。

说明

0/200

提交
取消

辅 助

模 式