python - 我有具有二进制特征的机器学习数据。如何强制自动编码器返回二进制数据?
问题描述
我有以下形式的数据集:对 N 维数据的一系列 M 观察。为了从这些数据中获取潜在因素,我希望制作一个在这些数据上训练的隐藏层自动编码器。单个观察的每个维度都是 0 或 1。但 keras 模型返回浮点数。有没有办法添加一个层来强制输出 0 或 1?
我尝试使用一个简单的 keras 模型来解决这个问题。它声称数据的准确性很高,但在查看原始数据时,它正确地预测了 0,并且通常完全忽略了 1。
n_nodes = 50
input_1 = tf.keras.layers.Input(shape=(x_train.shape[1],))
x = tf.keras.layers.Dense(n_nodes, activation='relu')(input_1)
output_1 = tf.keras.layers.Dense(x_train.shape[1], activation='sigmoid')(x)
model = tf.keras.models.Model(input_1, output_1)
my_optimizer = tf.keras.optimizers.RMSprop()
my_optimizer.lr = 0.002
model.compile(optimizer=my_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10000)
predictions = model.predict(x_test)
然后,我通过查看所有实验并查看是否为 1 的元素返回了大 (>0.1) 值来验证这些观察结果。1 的性能非常差。
我已经看到损失在 10000 个 epoch 左右收敛。但是,自动编码器无法正确预测数据集中几乎所有的 1。即使将隐藏层的宽度设置为与数据的维数相同n_nodes = x_train.shape[1]
(
解决方案
[0, 1]
输出通常应该四舍五入,以便在输出最终预测时 >=0.5 舍入为 1,<0.5 舍入为 0。但是,您的标签应该是{0.0, 1.0}
损失函数的浮点值(我希望它们已经是)。您可以通过对输出进行四舍五入并与二进制标签进行比较来计算准确度,以计算 的错误{0, 1}
,但它们必须采用连续形式[0.0, 1.0]
才能使损失和梯度计算起作用。
如果您正在执行所有这些操作(并且您的代码中的设置确实正确),性能不佳可能有多种原因:
你密集的“收缩”层应该比你的输入要小得多。在使其更小时,您迫使自动编码器学习可用于产生输出的输入的代表性形式。这种代表形式很可能很好地概括。如果你增加隐藏层的大小,网络将有更多的容量来记忆输入。
您的值可能
0
比值多得多1
,如果是这种情况,那么在没有实际学习的情况下,网络可能会卡住,只是将 0 预测为“最佳猜测”,因为这“通常是正确的”。这是一个更难解决的问题。您可以考虑将损失乘以 的向量labels * eta + 1
,这将有效地提高标签的学习率。示例:您的标签是[0, 1, 0]
,eta 是一个 >1 的超参数值,假设 eta=2.0。它通过仅增加'slabels * eta = [1.0, 3.0, 1.0]
的损失来为 1 值放大梯度信号。1
这不是增加 's 类重要性的防弹方法1
,但它很容易尝试。如果它有任何改进,那么更详细地跟进这条推理。您有 1 个隐藏层,这意味着您仅限于线性关系,您可以尝试 3 个隐藏层来添加一点非线性。您的中心层应该相当小,尝试 5 或 10 个神经元,它应该需要将数据压缩到一个相当紧凑的收缩点以提取通用表示。
推荐阅读
- flutter - 当我运行(F5)Flutter Web 应用程序(我使用 VScode)时,chrome 上出现空白屏幕
- typescript - 使用没有 DOM 类型的 TypeScript(对于节点)
- c++ - 如何为 pcl::BoxClipper3D 设置输入云
- python - 如果表有多个类,如何单击,这是 gmail 搜索部分,我已附上截图,请检查
- javascript - HttpMessageNotWritableException:无法写入 JSON:EL1025E:集合包含“0”个元素,索引“0”无效
- android - PagerAdapter.getItem() 崩溃并出现 IllegalStateException:已添加片段
- angular - 使用 Twilio Media Streams 和 Google Speech-to-Text 编写实时转录电话时出错
- python - 是否可以在另一个views.py方法中访问通过HTML表单提交的、在预览页面中呈现的值
- angular - 如何在 Jasmine 中模拟 Angular 9 中的大量依赖项的依赖模块?
- unity3d - 如何让宇宙飞船掉头?