javascript - 试图学习 tensorflow.js,但需要一个更简单的例子,比如 Brain.js
问题描述
Brain.js 的示例让我能够很好地理解该软件 - 以及机器学习的各个方面。
现在我正在尝试学习 tensorflow.js,很难复制相同类型的代码。
例如,TensorFlow 等价于下面的大脑代码是什么?
var net = new brain.NeuralNetwork();
net.train([{input: { r: 0.03, g: 0.7, b: 0.5 }, output: { black: 1 }},
{input: { r: 0.16, g: 0.09, b: 0.2 }, output: { white: 1 }},
{input: { r: 0.5, g: 0.5, b: 1.0 }, output: { white: 1 }}]);
var output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.99, black: 0.002 }
解决方案
这将是您提供的示例的一种简化版本:
const net = tf.sequential();
net.add(tf.layers.dense({
units: 2,
inputShape: [3],
activation: 'sigmoid'
}));
net.compile({
loss: 'meanSquaredError',
optimizer: 'sgd'
});
const xs = tf.tensor2d([
[0.03, 0.7, 0.5],
[0.16, 0.09, 0.2],
[0.5, 0.5, 1.0]
]);
const ys = tf.tensor2d([
[1, 0],
[0, 1],
[0, 1]
]);
net.fit(xs, ys).then(() => {
const xPredict = tf.tensor2d([
[1.0, 0.4, 0.0]
]);
const prediction = net.predict(xPredict);
prediction.print();
});
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.6">
</script>
但正如您所看到的,您必须更具体地了解神经网络的实际作用。我只是使用了您可以使用的基本选项。正如您所看到的,网络因此性能非常差,但我希望您了解它是如何工作的。所以让我试着解释一下这个片段的作用:
const
通常用于tf.tensor
s,因为它们位于 GPU 上,因此无论如何都无法更改。
tf.sequential()
创建一个空的前馈网络。(如果您不知道那是什么,请先尝试在没有实现的情况下学习神经网络)
tf.layers.dense()
创建一个全连接层。
units:2
定义层的输出形状。在这种情况下,一个具有两个值的向量。
inputShape: [3]
可以在任何非第一层上忽略,因为它可以由前一层推断并定义输入张量的形状
activation: 'sigmoid'
是应用于层返回值的激活函数,高度依赖于您要解决的问题。
.compile()
使用给定选项编译网络,并且非常可定制
xs
是训练ys
数据集。注意:它们有一个额外的维度来表示多个 xy 对,因此它们可以批量传递给训练函数。
.fit()
是训练方法,训练网络的内部权重。注意:这是一个异步函数,所以你必须等到它完成才能使用模型。
xPredict
是测试数据,也比网络的返回形状高一维。
.predict()
根据给定的输入预测网络的输出。
.print()
在控制台中输出张量。(如果它太大,它会被裁剪)
我强烈建议你在尝试实现它们(通过复制)之前先了解更多关于神经网络的知识,因为它们会变得非常复杂和令人困惑。然后,您可以阅读文档以了解可能的情况。
推荐阅读
- c++ - 如何通过C ++消除系列中的某些数字?
- c++ - 对象向量抽象函数调用的分段错误
- django - Django - models - number of fields not known in advance
- mongodb - Get values as array of elements after $lookup
- csv - SSIS Flat File Connection - How does it determine string column DataType?
- php - 文件存储问题
- django - 在 Django 上构建排名系统
- reactjs - 调度动作调用不正确的减速器
- jupyter-notebook - 如何为 Julia 的早期版本添加 Jupyter Notebook 内核?
- c# - 如何在两个位置之间连续移动游戏对象?