javascript - matMul 中的错误:形状为 684,1 和 2,1 且 transposeA=false 和 transposeB=false 的张量的内部形状 (1) 和 (2) 必须匹配
问题描述
我是 AI 和 tensorflow.js 的完整初学者。目前正在学习 Stephen Grider 的机器学习课程。我应该在下面的代码之后得到一个输出,但我得到了错误。请帮忙:
代码:线性回归.js:
const tf = require('@tensorflow/tfjs');
class LinearRegression {
constructor(features, labels, options) {
this.features = tf.tensor(features);
this.labels = tf.tensor(labels);
this.features = tf.ones([this.features.shape[0], 1]).concat(this.features) //generates the column of one for the horse power
this.options = Object.assign(
{ learningRate: 0.1, iterations: 1000 },
options
); //default value is 0.1, if the learning rate is provided, the value is overrided... iteration no. of times gradient decent runs
this.weights = tf.zeros([2, 1]); // intial tensor of both m and b are zeros
}
gradientDescent() {
const currentGuesses = this.features.matMul(this.weights); //matMul is matrix multiplication which is features * weights
const differences = currentGuesses.sub(this.labels); //(features * weights) - labels
const slopes = this.features
.transpose()
.matMul(differences)
.div(features.shape[0]); // slope of MSE with respect to both m and b. features * ((features * weights) - labels) / total no. of features.
this.weights = this.weights.sub(slopes.mul(this.options.learningRate));
}
train() {
for (let i=0; i < this.options.iterations; i++) {
this.gradientDescent();
}
/*test(testFeatures, testLabels) {
testFeatures = tf.tensor(testFeatures);
testLabels = tf.tensor(testLabels);
} */
}
}
module.exports = LinearRegression;
index.js:
require('@tensorflow/tfjs-node');
const tf = require('@tensorflow/tfjs');
const loadCSV = require('./load-csv');
const LinearRegression = require('./linear-regression');
let { features, labels, testFeatures, testLabels } =loadCSV('./cars.csv', {
shuffle: true,
splitTest: 50,
dataColumns: ['horsepower'],
labelColumns: ['mpg']
});
const regression = new LinearRegression(features, labels, {
learningRate: 0.002,
iterations: 100
});
regression.train();
console.log(
'Updated M is:',
regression.weights.get(1, 0),
'Updated B is:',
regression.weights.get(0, 0)
);
错误:
D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\ops\operation.js:32
throw ex;
^
Error: Error in matMul: inner shapes (1) and (2) of Tensors with shapes 684,1 and 2,1 and transposeA=false and transposeB=false must match.
at Object.assert (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\util.js:36:15)
at matMul_ (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\ops\matmul.js:25:10)
at Object.matMul (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\ops\operation.js:23:29)
at Tensor.matMul (D:\Application Development\MLKits-master\MLKits-master\regressions\node_modules\@tensorflow\tfjs-core\dist\tensor.js:315:26)
at LinearRegression.gradientDescent (D:\Application Development\MLKits-master\MLKits-master\regressions\linear-regression.js:19:46)
at LinearRegression.train (D:\Application Development\MLKits-master\MLKits-master\regressions\linear-regression.js:34:18)
at Object.<anonymous> (D:\Application Development\MLKits-master\MLKits-master\regressions\index.js:18:12)
at Module._compile (internal/modules/cjs/loader.js:1063:30)
at Object.Module._extensions..js (internal/modules/cjs/loader.js:1092:10)
at Module.load (internal/modules/cjs/loader.js:928:32)
解决方案
错误是由
this.features.matMul(this.weights)
this.features
形状[684, 1]
和this.weights
形状之间存在矩阵乘法[2, 1]
。为了能够将矩阵 A (shape [a, b]
) 与 B (shape [c, d]
) 相乘,b
并且c
应该匹配这里不是这种情况。
要解决这里的问题,this.weights
应该换位
this.features.matMul(this.weights, false, true)
推荐阅读
- java - 无法实例化视图模型,因为表为空
- netlogo - 最小预期输入是一个列表,但得到了数字......而不是
- javascript - Nextjs:如何注册 quill-blot-formatter 以仅在客户端渲染上动态导入 react-quill?
- r - 如何在 ggplot2 中制作堆叠密度图?
- assembly - 为什么在 Rust 的泛型中接受 `fn(..)` 而不是 `Fn(...)`?
- sql - SQL按日期查找条目总和,包括前一个日期
- jwt - Grafana:如何使用 JWT 身份验证?
- git - 我们可以从 Azure Devops 管道提出和合并拉取请求吗?
- c# - 在 ABP C# 中捕获连接错误异常
- c# - 在两个实例化对象之间均匀实例化块