tensorflow.js - 带有 earlyStopping 和训练日志的 TensorflowJS 不起作用
问题描述
当我们同时定义提前停止和训练日志功能时,TensorflowJS 似乎不起作用。上面的例子取自 TensorflowJS 文档,我刚刚添加了 onTrainBegin 回调——但它失败了。
const model = tf.sequential();
model.add(tf.layers.dense({
units: 3,
activation: 'softmax',
kernelInitializer: 'ones',
inputShape: [2]
}));
const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
model.compile(
{loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
const onTrainBegin = function onTrainBegin(logs){
console.log("onTrainBegin");
}
// Without the EarlyStopping callback, the val_acc value would be:
// 0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
const history = await model.fit(xs, ys, {
epochs: 10,
validationData: [xsVal, ysVal],
callbacks: [onTrainBegin, tf.callbacks.earlyStopping({monitor: 'val_acc'})]
});
// Expect to see a length-2 array.
console.log(history.history.val_acc);
此代码产生错误消息:
发生错误 this.getMonitorValue 不是函数
https://js.tensorflow.org/api/latest/#callbacks.earlyStopping
解决方案
你正在混合不同的东西。OntrainBegin
指定何时执行回调函数并且 tf.callbacks.earlyStopping({monitor: 'val_acc'})
是一个函数
(async() => {
const model = tf.sequential();
model.add(tf.layers.dense({
units: 3,
activation: 'softmax',
kernelInitializer: 'ones',
inputShape: [2]
}));
const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
model.compile(
{loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
const onTrainBegin = logs => {
console.log("onTrainBegin");
}
// Without the EarlyStopping callback, the val_acc value would be:
// 0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
const history = await model.fit(xs, ys, {
epochs: 10,
validationData: [xsVal, ysVal],
callbacks: [
tf.callbacks.earlyStopping({monitor: 'val_acc'}), new tf.CustomCallback({
onEpochEnd: onTrainBegin()}),
]
});
// Expect to see a length-2 array.
console.log(history.history.val_acc);
})()
<html>
<head>
<!-- Load TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
</head>
<body>
</body>
</html>
推荐阅读
- c# - 文档中的项目数组项
- python - 编写 MongoDB 查询以根据引用的属性进行过滤?
- solr - 从 Solr 6.3 升级后,Solr 7.5 无法索引 pdf 文件
- javascript - Angular Highcharts - 如何动态克隆图表
- xml - 通过 UnMarshal 和 MarshalIndent 的往返 xml
- sql - 在一对多表中查找重复项
- text-to-speech - Watson TTS 语音选择
- apache - Apache 2.4 Require all denied 不使用 RewriteRule
- javascript - 每个浏览器如何选择wordpress中的图片url?
- javascript - 如何使用 Javascript 将字符串转换为 AST 对象?