machine-learning - 结合具有不同特征但目标相同的神经网络
问题描述
我有一个关于如何创建神经网络(或类似)架构的问题,该架构执行以下操作:
例子
假设我有一个使用特征 1 和特征 2 来预测目标的模型。该模型表现不佳,因为我受限于填充了特征 1 和特征 2 的训练示例的数量。
如果我要拥有另一个具有特征 3 和特征 4 的神经网络,并且我的目标是预测相同的目标,那么我该如何结合两个模型的学习来做出相同的目标预测。
对于具有不同特征但具有共同目标的其他几个类似数据集,这种情况继续存在
解释
我这样做只是因为并非每个训练示例都具有特征 1、2、3 和 4,因此不能将其合并到单个模型中。但唯一的共同点是模型试图预测相同的目标。
问题
哪种机器学习策略(不仅仅是神经网络)最适合此类问题?
解决方案
您描述的模型由 2 个核心子模型构建而成。
许多依赖于特征的编码器,每个特征集一个。特征 1 和 2 可以由模型的一部分组合成一些隐藏的表示。特征 3 和 4 将被转换为相同的隐藏表示,但将具有不同的子模型,并具有不同的参数集来拟合。
隐藏表示之上的单个与特征无关的解码器,用于预测您的目标。
在拟合模型时,每个编码器只能使用所需特征集可用的数据。它正在拟合这些特征的表示,因此它需要看到它们。但是解码器可以用于您的所有数据。这将捕获目标的分布,这很常见,因为您的目标很常见。
当您认为存在有意义的隐藏表示时,这种模型是合适的。也就是说,您认为您的功能集正在以不同的方式测量相似的事物。
这使您可以保持编码器很小,因为它正在从一种测量方式到另一种测量方式进行小的转换。从测量值转换到目标可能仍然很困难,但由于该逻辑进入通用解码器,它可以从所有训练数据中受益。
具体来说,如果您的特征是width
、height
、volume
和,则此类模型的一个很好的示例用例是weight
。假设您的目标是 shipping cost
。
可以合理地说,中间表示可以很好地用 的概念来描述size
。也可以合理地说,从size
to转换cost
本身就是一个有趣的问题,不管你size
最初是如何测量的。
所以模型公式看起来像这样:
# Feature encoders.
size ~ width + height
size ~ volume + weight
# Target decoder.
cost ~ size
现在,上面我已经仔细描述了模型设计,没有对模型类型做出任何承诺。但是您确实将这个问题标记为相关的神经网络,我认为这是一个不错的选择。
对于您的简单示例,使用 PyTorch,模型可能如下所示:
import torch.nn as nn
import torch.nn.functional as F
class MultiEncoderSingleDecoder(torch.nn.Module):
def __init__(self, hid_sz):
super().__init__()
self.using_encoder = 0
self.encoders = torch.nn.ModuleList([
torch.nn.Linear(2, hid_sz),
torch.nn.Linear(2, hid_sz),
])
self.decoder = torch.nn.Linear(hid_sz, 1)
def set_encoder(self, use_encoder):
self.using_encoder = use_encoder
def forward(self, inp):
encoder = self.encoders[self.using_encoder]
return self.decoder(F.relu(encoder(inp)))
然后用法可能如下所示:
model = MultiEncoderSingleDecoder()
model.set_encoder(0)
# Do some training on the first feature set.
model.set_encoder(1)
# Do some more training on the second feature set.
# ...
推荐阅读
- android - Android SwipeRefreshLayout Spinner 在完成前隐藏
- amazon-eks - 使用 boto3 连接 EKS 集群,然后需要排空特定节点
- c# - Powershell中的Windows徽标键+ Alt + PrtScn同时按下多个键?
- php - PHP - nagate 回调结果
- gradle - 当我在 gradle 中添加一个新的源集时,“Kotlin 未配置”。(智能)
- r - 在 r 中替换 XML (KML) 文件中的特定文本
- r - 您如何聚合行并选择特定日期的值?
- java - 带有 OpenAM 的 Spring OIDC
- f# - Deedle、F# 和读取 csv
- reactjs - 收到非布尔属性“showheader”的“false”。如果要将其写入 DOM,请改为传递一个字符串:showheader="false"