tensorflow - 计算预测的置信度
问题描述
我正在制作一个简单的程序,试图从一组值中预测一个值:
require('@tensorflow/tfjs-node');
const {layers, setBackend, sequential, train} = require('@tensorflow/tfjs');
setBackend('tensorflow');
const {xs, ys} = require('./data.v2');
const [samples] = xs.shape;
const MAX_SAMPLES = samples - 3; // Leave 3 for predictions
const trainXs = xs.slice([0], MAX_SAMPLES);
const trainYs = ys.slice([0], MAX_SAMPLES);
const predict = xs.slice([MAX_SAMPLES]);
const expect = ys.slice([MAX_SAMPLES]);
const LEARNING_RATE = 0.01;
const BATCH_SIZE = 2;
const EPOCHS = 1000;
// Define a model for linear regression.
const model = sequential();
// First layer must have an input shape defined
model.add(layers.dense({units: 16, inputShape: [37,]}));
// Afterwards, TF.js does automatic shape inference
model.add(layers.dense({units: 4}));
model.add(layers.dense({units: 1}));
// https://js.tensorflow.org/api/0.10.0/#Training-Optimizers
const optimizer = train.adam(LEARNING_RATE);
// NOTE: For classification we would use a cross entropy loss fn,
// but for regression, we prefer mean squared error
// https://stackoverflow.com/a/36516373/1092007
// Prepare the model for training: Specify the loss and the optimizer.
model.compile({
optimizer,
loss: 'meanSquaredError'
});
fit(trainXs, trainYs, EPOCHS, BATCH_SIZE)
.then(() => {
console.log('Done training');
// Use the model to do inference on a data point the model hasn't seen before:
const item = predict.slice([0], 1);
console.log(item.dataSync());
model.predict(item, true)
.print();
expect.slice([0], 1)
.print();
});
async function fit(xs, ys, epochs, batchSize) {
// Train the model using the data
const history = await model.fit(xs, ys, {
batchSize,
epochs,
shuffle: true,
validationSplit: 0.3,
callbacks: {
onEpochEnd(...args) {
const [epoch, history] = args;
const {loss} = history;
console.log(`Loss after epoch ${epoch}: ${loss}`);
}
}
});
}
输入数据是一个形状张量[53, 37]
:
[[1, 2, 3, ...], [4, 5, 6, ...], ...]
输出是向量(形状[53]
):
[3, 4, ...]
但我想知道如何计算输出的置信度.predict()
?
注意:我将tensorflow.js与 Node 一起使用,因此 API 可能与 Python API 有点不同。
解决方案
推荐阅读
- reactjs - 反应包,无效的钩子调用
- node.js - 从嵌套在多维数组中的字典中读取特定属性
- java - 子类可以与超类具有相同的列名吗?
- java - 获取 android.database.CursorWindowAllocationException 频繁崩溃
- asp.net - ASP。net core httpresponsemessage json 内容
- angular - 在我的库中使用浏览器动画模块
- c - 仅打印所有 char 作为字符串数组的单词的正确条件
- android - 来自 MediaSessionCompat 的两个通知
- javascript - 如何在html中使用div作为背景
- javascript - 导入模块上当前未启用对实验语法“jsx”的支持错误