首页 > 解决方案 > 如何从 TensorFlow Java 调用模型?

问题描述

以下 Python 代码将 ["hello", "world"] 传递给通用句子编码器(Universal Sentence Encoder),并返回一个浮点数组,表示这两个句子的编码表示。

import tensorflow as tf
import tensorflow_hub as hub

module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
print("model: ", model(["hello", "world"]))

此代码有效,但我现在想使用 Java API 做同样的事情。我已成功加载模块,但无法将输入传递到模型并提取输出。这是我到目前为止所得到的:

import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// NOTE(review): this is the question's original, NON-WORKING attempt. It is kept as-is
// to illustrate the failure mode; the accepted answer below explains the fixes.
public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    /**
     * Loads a SavedModel bundle from disk.
     *
     * @param source the directory containing the saved model
     * @param tags   the model variant(s) to load (e.g. "serve")
     * @return the loaded bundle
     * @throws IOException if an error occurs while reading the model
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            Graph graph = module.graph();
            // PROBLEM: building a new Session here bypasses the one created by the loader.
            // The loader's session is what initializes the graph variables, which is why
            // this approach later fails with "Table not initialized" (see the error below).
            try (Session session = new Session(graph, ConfigProto.newBuilder().
                setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                setAllowSoftPlacement(true).
                build().toByteArray()))
            {
                // Strings are fed to the TF 1.x Java API as UTF-8 byte arrays.
                Tensor<String> input = Tensors.create(new byte[][]
                    {
                        "hello".getBytes(StandardCharsets.UTF_8),
                        "world".getBytes(StandardCharsets.UTF_8)
                    });
                // PROBLEM: addTarget() only runs a node for its side effects; fetch() is
                // what retrieves an output tensor (per the accepted answer). The output
                // node name ("???") was also unknown at this point.
                List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
                    addTarget("???").run();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }
}

我使用https://stackoverflow.com/a/51952478/14731扫描模型以查找可能的输入/输出节点。我相信输入节点是“serving_default_inputs”,但我不知道输出节点。更重要的是,通过 Keras 在 python 中调用代码时,我不必指定任何这些值,那么有没有办法使用 Java API 来做同样的事情?

更新:感谢 roywei,我现在可以确认输入节点是 serving_default_inputs,输出节点是 StatefulPartitionedCall_1,但是当我把这些名称代入上述代码时,我得到:

2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
     [[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
     [[StatefulPartitionedCall_1/StatefulPartitionedCall]]
    at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
    at libtensorflow@1.15.0/org.tensorflow.Session.access$100(Session.java:48)
    at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)

意思是,我仍然无法调用模型。我错过了什么?

标签: java, python, keras, tensorflow2.0

解决方案


在 roywei 为我指明正确方向之后,我终于弄明白了。

  • 我需要使用 SavedModelBundle.session() 而不是构建自己的 Session 实例。这是因为加载器已经初始化了图中的变量。
  • 我不再将 ConfigProto 传递给 Session 构造函数,而是将其传递给 SavedModelBundle 加载器。
  • 我需要使用 fetch() 而不是 addTarget() 来检索输出张量。

这是工作代码:

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    /**
     * Loads the universal-sentence-encoder model, encodes two sample sentences and prints
     * the detected input/output tensor names along with the resulting embeddings.
     */
    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            // Strings are fed to the TF 1.x Java API as UTF-8 byte arrays.
            try (Tensor<String> input = Tensors.create(new byte[][]
                {
                    "hello".getBytes(StandardCharsets.UTF_8),
                    "world".getBytes(StandardCharsets.UTF_8)
                }))
            {
                MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
                Map<String, Shape> nameToInput = getInputToShape(metadata);
                String firstInput = nameToInput.keySet().iterator().next();

                Map<String, Shape> nameToOutput = getOutputToShape(metadata);
                String firstOutput = nameToOutput.keySet().iterator().next();

                System.out.println("input: " + firstInput);
                System.out.println("output: " + firstOutput);
                System.out.println();

                // Use the bundle's own session: the loader has already initialized the
                // graph variables, so creating a new Session would leave them uninitialized
                // ("Table not initialized" / "variable was uninitialized" errors).
                List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
                    fetch(firstOutput).run();
                for (Tensor<?> tensor : result)
                {
                    // Close each output tensor once consumed to release its native memory.
                    try (Tensor<?> output = tensor)
                    {
                        // The encoder returns a 2D [batchSize, embeddingSize] tensor.
                        // Size the array from the actual shape; deriving it from
                        // numElements() / numDimensions() only works when batchSize
                        // happens to equal the rank.
                        long[] shape = output.shape();
                        float[][] array = new float[(int) shape[0]][(int) shape[1]];
                        output.copyTo(array);
                        System.out.println(Arrays.deepToString(array));
                    }
                }
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Loads a SavedModel bundle from disk.
     *
     * @param source the directory containing the saved model
     * @param tags   the model variant(s) to load (e.g. "serve")
     * @return the loaded bundle
     * @throws NullPointerException if any of the arguments are null
     * @throws IOException          if an error occurs while reading the model
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        // https://stackoverflow.com/a/43526228/14731
        try
        {
            // Pass the ConfigProto to the loader (not to a Session constructor) so the
            // loader-managed session is the one configured.
            return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
                withTags(tags).
                withConfigProto(ConfigProto.newBuilder().
                    setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                    setAllowSoftPlacement(true).
                    build().toByteArray()).
                load();
        }
        catch (TensorFlowException e)
        {
            throw new IOException(e);
        }
    }

    /**
     * @param metadata the graph metadata
     * @return the first signature, or null if the model declares none
     */
    private SignatureDef getFirstSignature(MetaGraphDef metadata)
    {
        Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
        if (nameToSignature.isEmpty())
            return null;
        return nameToSignature.get(nameToSignature.keySet().iterator().next());
    }

    /**
     * @param metadata the graph metadata
     * @return the "serving_default" signature, falling back to the first available one
     * @throws IllegalStateException if the model does not declare any signatures
     */
    private SignatureDef getServingSignature(MetaGraphDef metadata)
    {
        SignatureDef result = metadata.getSignatureDefOrDefault("serving_default",
            getFirstSignature(metadata));
        if (result == null)
            throw new IllegalStateException("The model does not contain any signatures");
        return result;
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an output tensor name (usable with fetch()) to its shape
     */
    protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
    {
        return toNameToShape(getServingSignature(metadata).getOutputsMap());
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an input tensor name (usable with feed()) to its shape
     */
    protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
    {
        return toNameToShape(getServingSignature(metadata).getInputsMap());
    }

    /**
     * Converts signature tensor metadata into a name-to-shape map.
     *
     * @param signatureTensors a map from a signature key to its tensor metadata
     * @return a map from the tensor's graph name to its shape
     */
    private Map<String, Shape> toNameToShape(Map<String, TensorInfo> signatureTensors)
    {
        Map<String, Shape> result = new HashMap<>();
        for (Map.Entry<String, TensorInfo> entry : signatureTensors.entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            Shape shape;
            if (dimensions.isEmpty())
            {
                // A rank-0 tensor has no dimensions; Shape.make() requires at least one.
                shape = Shape.scalar();
            }
            else
            {
                long firstDimension = dimensions.get(0).getSize();
                long[] remainingDimensions = dimensions.stream().skip(1).
                    mapToLong(Dim::getSize).toArray();
                shape = Shape.make(firstDimension, remainingDimensions);
            }
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }
}

推荐阅读