首页 > 解决方案 > 在 tensorflowjs 中分类

问题描述

按照本教程,我想在 tensorflowjs 中加载和使用模型,然后使用分类方法对输入进行分类。

我像这样加载并执行模型:

const model = await window.tf.loadGraphModel(MODEL_URL);

const threshold = 0.9;
const labelsToInclude = ["test1"];

model.load(threshold, labelsToInclude).then(model2 => {
    model2.classify(["test sentence"])
      .then(predictions => {
    console.log('prediction: ' + predictions);
    return true;
  })
});

但我收到错误:

TypeError:model2.classify 不是 App.js:23 的函数

如何正确使用 tensorflowjs 中的分类方法?

标签: tensorflowtensorflow.js

解决方案


本教程使用特定模型(毒性)。它loadclassify函数不是 Tensorflow.js 模型本身的特性,而是由这个特定模型实现的。

查看API以查看模型支持的一般功能。如果您加载 GraphModel,您希望使用model.predict(or execute) 函数来执行模型。

因此,您的代码应如下所示:

const model = await window.tf.loadGraphModel(MODEL_URL);
const input = tf.tensor(/* ... */); // whatever a valid tensor looks like for your model
const predictions = model.predict([input]);
console.log('prediction: ' + predictions);

推荐阅读