TensorflowJS with earlyStopping and training logs does not work

Problem description

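TensorflowJS does not seem to work when an early-stopping callback and a training-log callback are defined together. The example below is taken from the TensorflowJS documentation; I only added an onTrainBegin callback, and training fails.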

const model = tf.sequential();
model.add(tf.layers.dense({
  units: 3,
  activation: 'softmax',
  kernelInitializer: 'ones',
  inputShape: [2]
}));
const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
model.compile(
    {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});

const onTrainBegin = function onTrainBegin(logs){
     console.log("onTrainBegin");
}


// Without the EarlyStopping callback, the val_acc value would be:
//   0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
const history = await model.fit(xs, ys, {
  epochs: 10,
  validationData: [xsVal, ysVal],
  callbacks: [onTrainBegin, tf.callbacks.earlyStopping({monitor: 'val_acc'})]
});

// Expect to see a length-2 array.
console.log(history.history.val_acc);

This code produces the following error message:

An error occurred: this.getMonitorValue is not a function

https://js.tensorflow.org/api/latest/#callbacks.earlyStopping
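The error message points at a lost this binding. One plausible mechanism, assuming that fit() treats a callbacks list whose first entry is not a callback object as a list of plain config objects and wraps every entry in a CustomCallback (this is a reading of the tfjs-layers source, not documented behavior, and it may differ between versions): the wrapper copies the EarlyStopping object's onEpochEnd method and later calls it with the wrapper as this, and the wrapper has no getMonitorValue. The plain-JavaScript sketch below reproduces that failure without TensorFlow.js; FakeEarlyStopping and FakeCustomCallback are hypothetical stand-ins, not tfjs classes.

```javascript
// Hypothetical stand-in for the object returned by tf.callbacks.earlyStopping():
// its onEpochEnd method relies on `this` being the early-stopping object.
class FakeEarlyStopping {
  constructor(monitor) {
    this.monitor = monitor;
  }
  getMonitorValue(logs) {
    return logs[this.monitor];
  }
  onEpochEnd(epoch, logs) {
    return this.getMonitorValue(logs);
  }
}

// Hypothetical stand-in for a wrapper that copies hook functions out of a
// config object and later calls them as if they were its own methods.
class FakeCustomCallback {
  constructor(args) {
    // The method is copied off the object; its original `this` is not kept.
    this.epochEnd = args.onEpochEnd;
  }
  onEpochEnd(epoch, logs) {
    if (this.epochEnd != null) {
      // `this` here is the wrapper, not the FakeEarlyStopping instance.
      return this.epochEnd(epoch, logs);
    }
  }
}

const earlyStopping = new FakeEarlyStopping("val_acc");
const logs = { val_acc: 0.5 };

// Called on the original object, the method works.
console.log(earlyStopping.onEpochEnd(0, logs)); // 0.5

// Called through the wrapper, `this` is lost and the lookup fails.
let errorMessage = null;
try {
  new FakeCustomCallback(earlyStopping).onEpochEnd(0, logs);
} catch (err) {
  errorMessage = err.message;
}
console.log(errorMessage); // this.getMonitorValue is not a function
```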

Tags: tensorflow.js

Solution


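You are mixing two different things. onTrainBegin names the point in training at which a callback function runs, whereas tf.callbacks.earlyStopping({monitor: 'val_acc'}) returns a ready-made callback object. A bare function is not a callback by itself, so wrap it in tf.CustomCallback under the hook you want (for example onTrainBegin) and pass that object next to the early-stopping callback.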

(async() => {
const model = tf.sequential();
model.add(tf.layers.dense({
  units: 3,
  activation: 'softmax',
  kernelInitializer: 'ones',
  inputShape: [2]
}));
const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
model.compile(
    {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});

const onTrainBegin = logs => {
  console.log("onTrainBegin");
};


// Without the EarlyStopping callback, the val_acc value would be:
//   0.5, 0.5, 0.5, 0.5, ...
// With val_acc being monitored, training should stop after the 2nd epoch.
const history = await model.fit(xs, ys, {
  epochs: 10,
  validationData: [xsVal, ysVal],
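  callbacks: [
    // Every entry is a callback object: earlyStopping() returns one, and
    // tf.CustomCallback turns the plain logging function into one.
    tf.callbacks.earlyStopping({monitor: 'val_acc'}),
    // Pass the function itself (not onTrainBegin()) under the onTrainBegin hook.
    new tf.CustomCallback({onTrainBegin: onTrainBegin}),
  ]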
});

// Expect to see a length-2 array.
console.log(history.history.val_acc);
})()
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>
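Why wrapping the function matters can be sketched with a simplified model of how fit() might normalize its callbacks list. The assumption, taken from a reading of the tfjs-layers source rather than from documented behavior, is that the list is used as-is when its first entry is a callback object, and that otherwise every entry is treated as a config object and wrapped. FakeBaseCallback, FakeEarlyStopping, FakeCustomCallback and standardizeCallbacks are hypothetical names for this sketch, not tfjs APIs.

```javascript
// Hypothetical base class standing in for a library callback type.
class FakeBaseCallback {
  onTrainBegin() {}
}

// Hypothetical stand-in for the early-stopping callback object.
class FakeEarlyStopping extends FakeBaseCallback {
  onTrainBegin() {
    return "early stopping reset";
  }
}

// Hypothetical wrapper that turns a config object into a callback object.
class FakeCustomCallback extends FakeBaseCallback {
  constructor(args) {
    super();
    this.trainBegin = args.onTrainBegin;
  }
  onTrainBegin() {
    return this.trainBegin != null ? this.trainBegin() : undefined;
  }
}

// Simplified, assumed normalization rule for the callbacks list.
function standardizeCallbacks(callbacks) {
  // If the first entry is already a callback object, trust the whole list.
  if (callbacks[0] instanceof FakeBaseCallback) {
    return callbacks;
  }
  // Otherwise treat every entry as a config object and wrap it.
  return callbacks.map((config) => new FakeCustomCallback(config));
}

const logged = [];
const onTrainBegin = () => {
  logged.push("onTrainBegin");
};

// The question's list: a bare function first, so every entry gets wrapped.
const broken = standardizeCallbacks([onTrainBegin, new FakeEarlyStopping()]);
broken.forEach((cb) => cb.onTrainBegin());
console.log(broken.map((cb) => cb.constructor.name).join(", ")); // FakeCustomCallback, FakeCustomCallback
console.log(logged.length); // 0

// The answer's list: the function is wrapped explicitly, so nothing is re-wrapped.
const fixed = standardizeCallbacks([new FakeEarlyStopping(), new FakeCustomCallback({ onTrainBegin })]);
fixed.forEach((cb) => cb.onTrainBegin());
console.log(fixed.map((cb) => cb.constructor.name).join(", ")); // FakeEarlyStopping, FakeCustomCallback
console.log(logged.length); // 1
```

Under this model the question's list fails twice over: the bare function has no onTrainBegin property, so it is silently ignored, and the early-stopping object is wrapped instead of being called directly. Wrapping the function yourself keeps every entry a callback object.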

