tensorflow - TensorflowJS 模型无法正确预测多类数据
问题描述
作为一个初学者,我尝试在 tensorflowJS 中构建一个非常简单的多类分类器,它可以预测我的视线方向。
第 1 步:我在浏览器中创建了数据集来训练我的模型,我将网络摄像头渲染的眼睛图像存储在 HTML5 画布上。我使用箭头键将图像标记为 0=left、1=normal 和 2=right。为了训练模型,我在传递给方法之前使用 tf.onHot() 转换这些标签。
// data collection
let imageArray = [];
let labelArray = [];
let collectData = (label) => {
const img = tf.tidy(() => {
const captureImg = getImage();
//console.log(captureImg.shape)
return captureImg;
})
imageArray.push(img)
labelArray.push(label) //--- labels are 0,1,2
}
// label conversion
let labelSet = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3);
第 2 步:我没有加载任何经过训练的模型,而是使用了我使用 tensorflowJS 构建的自定义模型。
let createModel = () => {
const model = tf.sequential();
let config_one = {
kernelSize: 3,
filters: 40,
strides: 1,
activation: 'relu',
inputShape: [imageHeight, imageWidth, imageChannels]
}
model.add(tf.layers.conv2d(config_one));
let config_two = {
poolSize: [2, 2],
strides: [2, 2],
}
model.add(tf.layers.maxPooling2d(config_two));
model.add(tf.layers.flatten());
model.add(tf.layers.dropout(0.2));
// Two output values x and y
let congfig_output = {
units: 3,
activation: 'tanh',
}
model.add(tf.layers.dense(congfig_output));
// Use ADAM optimizer with learning rate of 0.0005 and MSE loss
let config_compile = {
optimizer: tf.train.adam(0.00005),
loss: 'categoricalCrossentropy',
}
model.compile(config_compile);
tf.memory()
return model;
}
问题:我现在面临几个问题。
当我使用 meanSquared 作为损失函数和 adam 学习率 0.000005 时,我的模型开始预测,但它只预测眼睛的正常状态和左/右中的两个因此进行多类分类,我将损失函数更改为 categoricalCrossentropy 但结果仍然是相同或有时最差。
我尝试了其他超参数组合,但没有运气。我遇到的最糟糕的情况是我的损失函数只重复显示三个常数值。
在某些情况下,我的浏览器会崩溃——如果——我传递了太多数据或在编译配置中使用了其他类型的优化器,例如 sgd 或其他任何东西。当我在谷歌上进行快速搜索时,我发现我可以使用 tf.memory() 检查任何可能导致浏览器崩溃的内存泄漏,但该行没有在控制台中记录任何内容。
我正在调整代码中的各种值和参数并训练模型,使其有时、部分地工作,甚至大部分时间都不起作用。这一切都受到了打击和考验。最终我了解了用于编译方法中的损失函数和 con2d 输入层中的激活函数的参数,但其他内容仍然令人困惑,例如 - 时期数、批量大小、adam 中的学习率等。
我理解或我认为我理解这些 - 内核大小、过滤器、步幅、输入形状,但仍然不知道如何确定各种超参数等的层数。
编辑- 这是我根据建议更新代码后得到的。我仍然没有正确分类。我正在使用至少 1000 多张图像进行训练。
A. 我仍然得到固定 valeus 反复出现的损失
B. 精度也在重复 1、0.5 和 0
function getImage() {
return tf.tidy(function () {
const image = tf.browser.fromPixels($('#eyes')[0]);
const batchedImage = image.expandDims(0);
const norm = batchedImage.toFloat().div(tf.scalar(255)).sub(tf.scalar(1));
return norm;
});
}
这是控制台输出
示例图像 -
解决方案
对我来说最明显的错误是你的输出层的激活函数,你应该使用tanh
你应该使用的地方softmax
。接下来,你的学习率是很低的尝试设置它0.001
是一个很好的默认值。
您也可能不需要 dropout,因为您没有得到任何结果来证明模型过度拟合。你也可以添加更多的卷积层,试试下面的例子。
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 5,
filters: 8,
strides: 1,
activation: 'relu',
}));
model.add(tf.layers.maxPooling2d({
poolSize: [2, 2],
strides: [2, 2],
}));
model.add(tf.layers.conv2d({
kernelSize: 5,
filters: 16,
strides: 1,
activation: 'relu',
}));
model.add(tf.layers.maxPooling2d({
poolSize: [2, 2],
strides: [2, 2],
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({
units: 3,
activation: 'softmax',
}));
const LEARNING_RATE = 0.001;
const optimizer = tf.train.adam(LEARNING_RATE);
model.compile({
optimizer: optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
推荐阅读
- mysql - 这个 bash 脚本试图用 mysql 做什么?
- matlab - 如何使用 LaTeX 解释器在 MATLAB 图形中使颜色条的刻度变为粗体?
- angularjs - 不允许超过 md-maxlength="200"
- c++ - 为什么 std::function 不能从 lambda 移动构造?
- c# - Npgsql.NpgsqlException: '没有提供密码
- java - 一段时间后从后台恢复时,Android应用布局组件为空
- java - MySQLNonTransientConnectionException:连接关闭后不允许任何操作。
- python - Python中的XOR RGB图像解密
- html - 将鼠标悬停在父 div 上时 CSS 淡化子元素,然后将鼠标悬停在子元素上时再次更改
- vue.js - 从另一个模块访问 vuex 模块状态