tensorflow - 如何为 csv 数据集创建模型并使用 TensorFlow JS 计算预测结果
问题描述
我是 TensorFlow JS 的新手。我按照 TensorFlow JS 文档创建模型并训练它从创建的模型计算预测结果。
但我不知道如何为 CSV 文件训练创建的模型并计算 CSV 文件中两列或多列的预测结果。
有人可以指导我使用 CSV 文件创建、训练模型并计算预测结果的样本吗?
const csvUrl = 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
function save(model) {
return model.save('downloads://boston_model');
}
function load() {
return tf.loadModel('indexeddb://boston_model');
}
async function run() {
// We want to predict the column "medv", which represents a median value of a
// home (in $1000s), so we mark it as a label.
const csvDataset = tf.data.csv(
csvUrl, {
columnConfigs: {
medv: {
isLabel: true
}
}
});
// Number of features is the number of column names minus one for the label
// column.
const numOfFeatures = (await csvDataset.columnNames()).length - 1;
// Prepare the Dataset for training.
const flattenedDataset =
csvDataset
.map(([rawFeatures, rawLabel]) =>
// Convert rows from object form (keyed by column name) to array form.
[Object.values(rawFeatures), Object.values(rawLabel)])
.batch(10);
// Define the model.
const model = tf.sequential();
model.add(tf.layers.dense({
inputShape: [numOfFeatures],
units: 1
}));
model.compile({
optimizer: tf.train.sgd(0.000001),
loss: 'meanSquaredError'
});
// Fit the model using the prepared Dataset
model.fitDataset(flattenedDataset, {
epochs: 10,
callbacks: {
onEpochEnd: async (epoch, logs) => {
console.log(epoch, logs.loss);
}
}
});
const savedModel=save(model);
}
run().then(() => console.log('Done'));
解决方案
使用tf.data.csv,您可以使用 csv 文件训练模型。
但是浏览器不能直接读取文件。因此,您必须在本地服务器上提供 csv 文件
更新
您的模型仅使用一个感知器。使用多个感知器可以帮助提高模型的准确性,即添加多个层。你可以在这里看看它是如何完成的。
推荐阅读
- python - 在 Python tkinter 中上传 excel 文件并打印为数据框
- java - 如何解析这个类似于 JSON 的字符串?
- asp.net - 如何在基于令牌的身份验证 OWIN 中使令牌无效
- visual-studio - 为什么 Visual Studios 不构建新代码?
- azure-data-factory - 在 Azure 数据工厂中解压缩 gzip 文件
- vue.js - vuex:状态字段“foo”被“foo”处的同名模块覆盖
- c# - 将对象插入容器中,其中构造对象的类在类型参数的数量上有所不同
- android - 单击后退按钮时无法完成活动
- android - kotlin DSL 从其他文件中检索密钥
- angular - 离开页面前的警告 - 但如果路线匹配则不会