首页 > 解决方案 > Java 中的 feed_dict 等价物

问题描述

我正在使用 Java 来提供使用 Python 学习的 Tensorflow 模型。该模型有两个输入。代码如下:

  def predict(float32InputShape: (Long, Long),
              float32Inputs: Seq[Seq[Float]],
              uint8InputShape: (Long, Long),
              uint8Inputs: Seq[Seq[Byte]]
             ): Array[Float] = {

    val float32Input = Tensor.create(
      Array(float32InputShape._1, float32InputShape._2),
      FloatBuffer.wrap(float32Inputs.flatten.toArray)
    )
    val uint8Input = Tensor.create(
      classOf[UInt8],
      Array(uint8InputShape._1, uint8InputShape._2),
      ByteBuffer.wrap(uint8Inputs.flatten.toArray)
    )

    val tfResult = session
      .runner()
      .feed("serving_default_float32_Input", float32Input)
      .feed("serving_default_uint8_Input", uint8Input)
      .fetch("PartitionedCall")
      .run()
      .get(0)
      .expect(classOf[java.lang.Float])

    tfResult
  }

我想做的是重构该方法,使其更通用,方法是通过 Python 中的 feed_dict 等输入。也就是说,类似:

    def predict2(inputs: Map[String, Seq[Seq[Float]]]): Array[Float] = {
      ...
      session
        .runner()
        .feed(inputs)
      ...
  }

地图的关键是inputs输入层的名称。feed除非我制作一个宏(我想避免),否则用该方法是不可能的。

有什么方法可以使用 Tensorflow 的 Java API(我使用的是 TF 2.0)来做到这一点?

编辑:我找到了解决方案(感谢@geometrikal 的回答),代码在Scala 中,但在Java 中应该不会太难。

    val runnerWithInputLayers = inputs.foldLeft(session.runner()) {
      case (sess, (layerName, array)) =>
        val tensor = createTensor(array)
        sess.feed(layerName, tensor)
    }

    val output = runnerWithInputLayers
      .fetch(outputLayer)
      .run()
      .get(0)
      .expect(Float.getClass)

这是可能的,因为该.feed方法返回 aSession.Runner并提供了输入层。

标签: javatensorflow

解决方案


你可以循环喂每个。如果对 java 脚本不太熟悉,但伪代码类似于

例如

val tfResult = session.runner()
for(key, value : inputs) {
    tfResult = tfResult(key, value)
}
tfResult = tfResult.fetch("PartitionedCall")
  .run()
  .get(0)
  .expect(classOf[java.lang.Float])

请记住,您可以在任何时候分解函数链,例如result = foo.bar().baz().qux()可以编写temp = foo.bar().baz(); result = temp.qux()


推荐阅读