java - Java 中的 feed_dict 等价物
问题描述
我正在使用 Java 来提供使用 Python 学习的 Tensorflow 模型。该模型有两个输入。代码如下:
def predict(float32InputShape: (Long, Long),
float32Inputs: Seq[Seq[Float]],
uint8InputShape: (Long, Long),
uint8Inputs: Seq[Seq[Byte]]
): Array[Float] = {
val float32Input = Tensor.create(
Array(float32InputShape._1, float32InputShape._2),
FloatBuffer.wrap(float32Inputs.flatten.toArray)
)
val uint8Input = Tensor.create(
classOf[UInt8],
Array(uint8InputShape._1, uint8InputShape._2),
ByteBuffer.wrap(uint8Inputs.flatten.toArray)
)
val tfResult = session
.runner()
.feed("serving_default_float32_Input", float32Input)
.feed("serving_default_uint8_Input", uint8Input)
.fetch("PartitionedCall")
.run()
.get(0)
.expect(classOf[java.lang.Float])
tfResult
}
我想做的是重构该方法,使其更通用,方法是通过 Python 中的 feed_dict 等输入。也就是说,类似:
def predict2(inputs: Map[String, Seq[Seq[Float]]]): Array[Float] = {
...
session
.runner()
.feed(inputs)
...
}
地图的关键是inputs
输入层的名称。feed
除非我制作一个宏(我想避免),否则用该方法是不可能的。
有什么方法可以使用 Tensorflow 的 Java API(我使用的是 TF 2.0)来做到这一点?
编辑:我找到了解决方案(感谢@geometrikal 的回答),代码在Scala 中,但在Java 中应该不会太难。
val runnerWithInputLayers = inputs.foldLeft(session.runner()) {
case (sess, (layerName, array)) =>
val tensor = createTensor(array)
sess.feed(layerName, tensor)
}
val output = runnerWithInputLayers
.fetch(outputLayer)
.run()
.get(0)
.expect(Float.getClass)
这是可能的,因为该.feed
方法返回 aSession.Runner
并提供了输入层。
解决方案
你可以循环喂每个。如果对 java 脚本不太熟悉,但伪代码类似于
例如
val tfResult = session.runner()
for(key, value : inputs) {
tfResult = tfResult(key, value)
}
tfResult = tfResult.fetch("PartitionedCall")
.run()
.get(0)
.expect(classOf[java.lang.Float])
请记住,您可以在任何时候分解函数链,例如result = foo.bar().baz().qux()
可以编写temp = foo.bar().baz(); result = temp.qux()
推荐阅读
- python - Django - 如果没有匹配的 URL,则重定向到主页
- c# - 将文本写入文件 System.IO.IOException
- three.js - THREE.Geometry 上的视频纹理
- java - 课堂上的@param javadoc
- android - Camera.open() 获取“访问权限
已被限制” - node.js - 使用 SSL 连接到 Oracle
- mysql - MySQL 8 和 PostgreSQL 中的“过程分析”类似物
- objective-c - 突出显示时创建的 TableViewCell 边框消失
- postgresql - 行级安全性不适用于表所有者
- clojure - Project Euler 问题 54 错误答案。可能做错了什么