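java - How do I invoke a model from TensorFlow Java?
Question
The following Python code passes ["hello", "world"] to the Universal Sentence Encoder and returns an array of floats representing their encoded form.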
import tensorflow as tf
import tensorflow_hub as hub
module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
print("model: ", model(["hello", "world"]))
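This code works, but I now want to do the same thing using the Java API. I have successfully loaded the module, but I cannot pass input into the model and extract the output. Here is what I have so far: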
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
public final class NaiveBayesClassifier
{
public static void main(String[] args)
{
new NaiveBayesClassifier().run();
}
protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
{
return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
}
public void run()
{
try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
{
Graph graph = module.graph();
try (Session session = new Session(graph, ConfigProto.newBuilder().
setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
setAllowSoftPlacement(true).
build().toByteArray()))
{
Tensor<String> input = Tensors.create(new byte[][]
{
"hello".getBytes(StandardCharsets.UTF_8),
"world".getBytes(StandardCharsets.UTF_8)
});
List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
addTarget("???").run();
}
}
catch (IOException e)
{
e.printStackTrace();
}
}
}
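I used https://stackoverflow.com/a/51952478/14731 to scan the model for possible input/output nodes. I believe the input node is "serving_default_inputs", but I cannot work out the output node. More importantly, I do not have to specify any of these values when invoking the model from Python through Keras, so is there a way to do the same using the Java API?
Update: Thanks to roywei, I can now confirm that the input node is serving_default_input and the output node is StatefulPartitionedCall_1, but when I plug these names into the code above I get: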
2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
[[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
[[StatefulPartitionedCall_1/StatefulPartitionedCall]]
at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
at libtensorflow@1.15.0/org.tensorflow.Session.access$100(Session.java:48)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)
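In other words, I still cannot invoke the model. What am I missing?
Answer
I figured it out after roywei pointed me in the right direction.
- I needed to use SavedModelBundle.session() instead of constructing my own Session instance. This is because the loader initializes the graph variables.
- Instead of passing a ConfigProto to the Session constructor, I passed it to the SavedModelBundle loader.
- I needed to use fetch() instead of addTarget() to retrieve the output tensor.
Here is the working code: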
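import org.tensorflow.SavedModelBundle;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;
import org.tensorflow.framework.TensorShapeProto.Dim;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public final class NaiveBayesClassifier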
{
public static void main(String[] args)
{
new NaiveBayesClassifier().run();
}
public void run()
{
try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
{
try (Tensor<String> input = Tensors.create(new byte[][]
{
"hello".getBytes(StandardCharsets.UTF_8),
"world".getBytes(StandardCharsets.UTF_8)
}))
{
MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
Map<String, Shape> nameToInput = getInputToShape(metadata);
String firstInput = nameToInput.keySet().iterator().next();
Map<String, Shape> nameToOutput = getOutputToShape(metadata);
String firstOutput = nameToOutput.keySet().iterator().next();
System.out.println("input: " + firstInput);
System.out.println("output: " + firstOutput);
System.out.println();
List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
fetch(firstOutput).run();
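for (Tensor<?> tensor : result)
{
// Size the array from the tensor's actual shape, assumed to be [sentences, embedding size]
long[] shape = tensor.shape();
float[][] array = new float[(int) shape[0]][(int) shape[1]];
tensor.copyTo(array);
System.out.println(Arrays.deepToString(array));
// Release the native memory held by the output tensor
tensor.close();
}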
}
}
catch (IOException e)
{
e.printStackTrace();
}
}
/**
* Loads a saved model from a directory.
*
* @param source the directory containing the saved model
* @param tags the model variant(s) to load
* @return the loaded model bundle
* @throws NullPointerException if any of the arguments are null
* @throws IOException if an error occurs while reading the file
*/
protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
{
// https://stackoverflow.com/a/43526228/14731
try
{
return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
withTags(tags).
withConfigProto(ConfigProto.newBuilder().
setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
setAllowSoftPlacement(true).
build().toByteArray()).
load();
}
catch (TensorFlowException e)
{
throw new IOException(e);
}
}
/**
* @param metadata the graph metadata
* @return the first signature, or null
*/
private SignatureDef getFirstSignature(MetaGraphDef metadata)
{
Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
if (nameToSignature.isEmpty())
return null;
return nameToSignature.get(nameToSignature.keySet().iterator().next());
}
/**
* @param metadata the graph metadata
* @return the "serving_default" signature if present, otherwise the first signature
*/
private SignatureDef getServingSignature(MetaGraphDef metadata)
{
return metadata.getSignatureDefOrDefault("serving_default", getFirstSignature(metadata));
}
/**
* @param metadata the graph metadata
* @return a map from an output name to its shape
*/
protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
{
Map<String, Shape> result = new HashMap<>();
SignatureDef servingDefault = getServingSignature(metadata);
for (Map.Entry<String, TensorInfo> entry : servingDefault.getOutputsMap().entrySet())
{
TensorShapeProto shapeProto = entry.getValue().getTensorShape();
List<Dim> dimensions = shapeProto.getDimList();
long firstDimension = dimensions.get(0).getSize();
long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
Shape shape = Shape.make(firstDimension, remainingDimensions);
result.put(entry.getValue().getName(), shape);
}
return result;
}
/**
* @param metadata the graph metadata
* @return a map from an input name to its shape
*/
protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
{
Map<String, Shape> result = new HashMap<>();
SignatureDef servingDefault = getServingSignature(metadata);
for (Map.Entry<String, TensorInfo> entry : servingDefault.getInputsMap().entrySet())
{
TensorShapeProto shapeProto = entry.getValue().getTensorShape();
List<Dim> dimensions = shapeProto.getDimList();
long firstDimension = dimensions.get(0).getSize();
long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
Shape shape = Shape.make(firstDimension, remainingDimensions);
result.put(entry.getValue().getName(), shape);
}
return result;
}
}
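The working code above builds the input for Tensors.create() from a hand-written byte[][] with one UTF-8 encoded entry per sentence. For an arbitrary list of sentences, a small helper along the following lines could build that array. This is only a sketch using the standard library; the class name Utf8Batch and the method toUtf8 are hypothetical and do not come from the TensorFlow API.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical helper (not part of TensorFlow): encodes sentences as UTF-8
// so the result can be handed to Tensors.create(byte[][]) as in the code above.
final class Utf8Batch
{
    private Utf8Batch()
    {
    }

    /**
     * @param sentences the sentences to encode
     * @return one UTF-8 encoded byte array per sentence, in the same order
     */
    static byte[][] toUtf8(String... sentences)
    {
        byte[][] batch = new byte[sentences.length][];
        for (int i = 0; i < sentences.length; ++i)
            batch[i] = sentences[i].getBytes(StandardCharsets.UTF_8);
        return batch;
    }

    public static void main(String[] args)
    {
        byte[][] batch = toUtf8("hello", "world");
        System.out.println(batch.length + " sentences encoded");
    }
}
```

With this in place, the input tensor in run() could be created as Tensors.create(Utf8Batch.toUtf8("hello", "world")), assuming the same TensorFlow 1.15 Java API used in the answer.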