javascript - 标量输出 TensorFlow JS
问题描述
使用顺序模型,如何获取二维输入数组(三维输入)并让模型对每个二维输入进行预测,以生成标量?输入形状(板):[ 153, 8, 8 ]。输出形状(结果):[153]。
模型:
const model = tf.sequential();
model.add(tf.layers.dense({units: 8, inputShape: [8]}));
model.add(tf.layers.dense({units:1, activation: 'sigmoid'}));
// Prepare the model for training: Specify the loss and the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
const xs = tf.tensor(boards);
const ys = tf.tensor(results);
model.fit(xs, ys, {batchSize: 8, epoch: 10000}).then(() => {
});
model.predict(tf.tensor(brokenFen)).print();
console.log(JSON.stringify(model.outputs[0].shape));
输出:
Tensor
[[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1]]
[null,1]
期望的输出:
Tensor
[1]
如果您还有任何问题,请联系lmk。
解决方案
有点理论...
预测的数量是赋予该predict
方法的特征的批量大小。为了更好地理解什么是特征和标签,让我们考虑以下内容。
| feature 1 | feature 2 | ... | feature n | Label |
|----------:|----------:|----:|-----------|------:|
| x11 | x12 | ... | x1n | y1 |
| x21 | x22 | ... | x2n | y2 |
| ... | ... | .. | ... | ... |
在上图中,数据有 n 个特征,对应于 n 个维度,而标签只有一个维度——为了简单起见并适合问题。模型的输入(第一层)应该与特征的维度相匹配,输出(最后一层)应该与标签的维度相匹配。在训练和预测时,我们给模型提供了一堆不同的样本 n1、n2。给出的样本数量对应于批量大小。该模型将返回与标签尺寸相同数量的形状元素。
该模型具有以下inputShape: [8]
,这表明我们有 8 个特征。最后一层units:1
暗示标签的大小为 1。当我们预测值时会发生什么?
const model = tf.sequential();
// first layer
model.add(tf.layers.dense({units: 8, inputShape: [8], activation: 'sigmoid' }));
// second layer
model.add(tf.layers.dense({units: 1}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
console.log(model.predict(tf.ones([3, 8])).shape) // [3, 1] 3 elements predicted
console.log(model.predict(tf.ones([1, 8])).shape) // [1, 1] single element predicted
<html>
<head>
<!-- Load TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.15.1"> </script>
</head>
<body>
</body>
</html>
正如问题所建议的那样,标签是从三维值预测的。在这种情况下, inputShape 将是 3 而不是 8。如果tf.tensor(brokenFen)
有 shape [b, ...inputShape]
,就会有b
结果值。如果您想要单个值,请考虑通过扩展尺寸tf.expandims
或使用tf.reshape
单个形状元素将 b 设置为 1 inputShape
- 在本例中为 3。
推荐阅读
- python - 在 SQLAlchemy 核心表中创建计算列
- python - 将加权平均函数应用于熊猫 groupby 对象中的列,但权重总和为零
- django - Tinymce 在我的应用程序中运行良好,但在管理员中运行良好 - Django
- java - 使用来自 Kafka 主题的消息时的反序列化问题
- swift - 使用 Swift 和 Apple 的 Scripting Bridge 将击键发送到邮件应用程序
- python - 如何在 Python3 中访问 c_void_p_Array_8 对象的各个字节?
- jpa - 使用 Eclipse 链接从存储过程中获取输出参数
- java - JPA 查询 - 使用参数的实体名称作为相关 JPA 查询的参数
- php - 我的问题是关于 magento 2.3 中的 Cron 错误
- symfony - 渲染模板异常