首页 > 解决方案 > Passing data to Tensorflow model in Java

问题描述

I'm trying to use a Tensorflow model that I trained in python to score data in Scala (using TF Java API). For the model, I've used thisregression example, with the only change being that I dropped asText=True from export_savedmodel.

My snippet of Scala:

  val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
  val s = b.session()

  // output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
  val input = "0.5,1,ax01,bx02"

  val inputTensor = Tensor.create(input.getBytes("UTF-8"))

  val result = s.runner()
    .feed("csv_rows", inputTensor)
    .fetch("dnn/logits/BiasAdd")
    .run()
    .get(0)

When I run, I get the following error:

Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
 [[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)

I figure that there's a problem with how I've prepared my input Tensor, but I'm stuck on how to best debug this.

标签: javatensorflow

解决方案


错误消息表明某些操作中输入张量的形状不是预期的。

查看您链接到的 Python 笔记本(特别是第 8a 和 8c 节),似乎输入张量应该是字符串张量的“批次”,而不是单个字符串张量。

您可以通过比较 Scala 和 Python 程序中张量的形状(inputTensor.shape()在 scala中与 Python notebook中csv_rows提供的形状)来观察这一点。predict_fn

由此看来,您想要的是inputTensor一个字符串向量,而不是单个标量字符串。为此,您需要执行以下操作:

val input = Array("0.5,1,ax01,bx02")
val inputTensor = Tensor.create(input.map(x => x.getBytes("UTF-8"))

希望有帮助


推荐阅读