javascript - ValueError:logits 和标签必须具有相同的形状,但形状为 [2] 和 [2,1]
问题描述
请帮助我理解我在 TensorFlow.js 代码中的错误。试图击败二元分类和 fitDataset。
简化示例https://jsfiddle.net/9w8hx21o/4/。
在示例中,我有 4 个 4 x 7 的观察值,并且有四个标签。在训练开始时,我收到错误“logits and labels must have the same shape, but got shapes [2] and [2,1]”。
const xs = [
[
[1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
],
[
[2, 2, 2, 2, 2, 2, 2],
[2, 2, 2, 2, 2, 2, 2],
[2, 2, 2, 2, 2, 2, 2],
[2, 2, 2, 2, 2, 2, 2],
],
[
[3, 3, 3, 3, 3, 3, 3],
[3, 3, 3, 3, 3, 3, 3],
[3, 3, 3, 3, 3, 3, 3],
[3, 3, 3, 3, 3, 3, 3],
],
[
[4, 4, 4, 4, 4, 4, 4],
[4, 4, 4, 4, 4, 4, 4],
[4, 4, 4, 4, 4, 4, 4],
[4, 4, 4, 4, 4, 4, 4],
]
]
const ys = [0, 1, 0, 1]
const model = tf.sequential()
model.add(tf.layers.inputLayer({
inputShape: [4, 7]
}))
model.add(tf.layers.conv1d({
filters: 16,
kernelSize: 2,
activation: 'relu',
}))
model.add(tf.layers.flatten())
model.add(tf.layers.dense({
units: 1,
activation: 'sigmoid'
}))
model.summary()
model.compile({
optimizer: 'adam',
loss: 'binaryCrossentropy',
metrics: ['accuracy']
})
const xDataset = tf.data.array(xs);
const yDataset = tf.data.array(ys);
const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset}).batch(2).shuffle(2)
const print_xyDataset = async () => {
await xyDataset.forEachAsync(e => {
console.log('\n');
for (let key in e) {
console.log(key + ':');
console.log('Shape ' + e[key].shape)
e[key].print();
}
})
}
print_xyDataset()
const train = async () => {
await model.fitDataset(xyDataset, {
epochs: 4,
callbacks: {
onEpochEnd: async (epoch, logs) => {
console.log(`EPOCH (${epoch + 1}): Train Accuracy: ${(logs.acc * 100).toFixed(2)}\n`);
},
}
})
}
train().catch(e => console.log(e))
解决方案
您可能正在运行新版本的 TF。如果 true 和 pred 缺少额外的暗淡,旧的 TF 将创建数学上等效但内部意外的行为。做这个
const ys = [[0], [1], [0], [1]]
看看是否可以解决它。
推荐阅读
- django - Django缓存简单的身份验证令牌
- elasticsearch - 如何通过聚合在elasticsearch中的另一个字段上对数组字段执行联合操作
- java - Java 流:收集到映射为每个流元素创建两个键
- azure - 有没有办法以编程方式更新 Azure 函数槽中的应用程序设置?
- flutter - Flutter 应用程序不是从 apk 或 appbundle 安装的,在模拟器上运行良好
- c++ - 在 C++ 中,我可以在不触及 main() 的情况下将 cout 显示为双打吗?
- sql - SQL - 字符串直到第一个空格
- javascript - React Native Context vs Redux vs AsyncStorage
- c# - ASP.NET Core MVC 中的业务逻辑
- c# - 尽管 try/catch 块 C# 崩溃