java - Passing data to Tensorflow model in Java
问题描述
I'm trying to use a Tensorflow model that I trained in python to score data in Scala (using TF Java API). For the model, I've used thisregression example, with the only change being that I dropped asText=True
from export_savedmodel
.
My snippet of Scala:
val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
val s = b.session()
// output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
val input = "0.5,1,ax01,bx02"
val inputTensor = Tensor.create(input.getBytes("UTF-8"))
val result = s.runner()
.feed("csv_rows", inputTensor)
.fetch("dnn/logits/BiasAdd")
.run()
.get(0)
When I run, I get the following error:
Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
[[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
I figure that there's a problem with how I've prepared my input Tensor, but I'm stuck on how to best debug this.
解决方案
错误消息表明某些操作中输入张量的形状不是预期的。
查看您链接到的 Python 笔记本(特别是第 8a 和 8c 节),似乎输入张量应该是字符串张量的“批次”,而不是单个字符串张量。
您可以通过比较 Scala 和 Python 程序中张量的形状(inputTensor.shape()
在 scala中与 Python notebook中csv_rows
提供的形状)来观察这一点。predict_fn
由此看来,您想要的是inputTensor
一个字符串向量,而不是单个标量字符串。为此,您需要执行以下操作:
val input = Array("0.5,1,ax01,bx02")
val inputTensor = Tensor.create(input.map(x => x.getBytes("UTF-8"))
希望有帮助
推荐阅读
- java - 在 arraylist 对象中使用 split() 方法代码不完整
- c++ - boost io服务对象
- batch-file - 如何使它不显示执行代码的 bat 文件的 cmd 窗口的第一行?
- c# - Assert.Throws 方法没有捕捉到预期的异常
- c++ - 类中的运算符(“C++ 需要所有声明的类型说明符”)
- ssl - .NET Core 处理 HTTPS 证书时出现未知错误
- gridview - Image.network 中像素抖动的底部溢出
- javascript - 在 NodeJS 中处理来自 POST 请求的数据
- django - 如何授权使用 django rest 框架和 CreateAPIView 创建对象?
- python - 寻找更好的解决方案来用beautifulsoup 抓取多个网页