javascript - Tensorflow JS 损失没有改变
问题描述
我是神经网络的相对初学者,对 tensorflow js 完全陌生。我正在尝试训练一个简单的模型来预测房产的价格。(https://www.kaggle.com/shree1992/housedata)。但是,当我训练模型时,损失永远不会改变,并且预测会大幅偏离。我在这里做错了什么?
这些是我用来预测价格的功能:
{
bedrooms: 2,
bathrooms: 2.5,
sqft_living: 1590,
sqft_lot: 2656,
floors: 2,
waterfront: 0,
view: 0,
condition: 3
}
{ price: 305000 }
========
Code
========
const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');
// const csvUrl =
// 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
const csvUrl = "file://./data.csv"
async function run() {
// We want to predict the column "medv", which represents a median value of
// a home (in $1000s), so we mark it as a label.
const csvDataset = tf.data.csv(
csvUrl, {
columnConfigs: {
price: {
isLabel: true
}
}
});
// Number of features is the number of column names minus one for the label
// column.
const numOfFeatures = (await csvDataset.columnNames()).length - 1;
console.log(csvDataset.columnNames());
// Prepare the Dataset for training.
const flattenedDataset =
csvDataset
.map(({xs, ys}) =>
{
// Convert xs(features) and ys(labels) from object form (keyed by
// column name) to array form.
// console.log(xs);
// console.log(ys);
return {xs:Object.values(xs), ys:Object.values(ys)};
})
.batch(10);
// Define the model.
const model = tf.sequential();
model.add(tf.layers.dense({
inputShape: [numOfFeatures],
units: 1,
}));
// model.add(tf.layers.dense({
// units:1,
// activation: 'softmax'
// }));
model.compile({
optimizer: 'sgd',
loss: 'binaryCrossentropy',
lr:1
});
// Fit the model using the prepared Dataset
await model.fitDataset(flattenedDataset, {
epochs: 10,
callbacks: {
onEpochEnd: async (epoch, logs) => {
console.log(epoch + ':' + logs.loss);
}
}
});
// let test = tf.tensor2d([[0.26169, 0, 9.9, 0, 0.544, 6.023, 90.4, 2.834, 4, 304, 18.4, 11.72]])
let test = tf.tensor2d([[3,1.5,1340,7912,1.5,0,0,3],[5,2.5,3650,9050,2,0,4,5]])
model.predict(test).print();
}
run();
解决方案
推荐阅读
- amazon-web-services - 如何启用 AmazonS3 静态网站托管?
- c# - 如何在 Visual Studio 调试器中检查 ac#(不安全)指针变量?
- c - 仅使用一个循环在数组中打印 String 的最后一个单词
- exasolution - 在 Exasol DB 中使用 alter 命令添加多列
- python - 相同的集合 == False 有人知道为什么吗?
- javascript - 如何在不使用服务器端语言的情况下从 HTML 表单存储/发送数据?
- java - Rxjava 可观察到的不兼容类型
- django - 如何在 django URL 上为导入的视图设置身份验证和权限
- r - 如何在R中滞后数据框的特定列
- algorithm - 用给定函数形式逼近未知函数的算法?