python - 如何使用 lambda 和 map_fn 在 4d 张量的每个 2d 切片上应用 keras 层?
问题描述
假设我们有一个形状为(64,100,5,32)的张量x ,它对应于 (batchSize,Length,Height,Channels)。现在,我想在第 32 个通道的每个大小为(100,5)的 2D 矩阵上应用 2D 转换层。所以我需要提取 32 个切片并使用相同的 2D 卷积层(参数)对其进行处理。我不知道如何从lambda und map_fn开始(请不要使用时间分布层)。最后,我想要一个大小为(64,100,5,32)的张量。
感谢您截断了如何执行此操作的简短代码。
解决方案
您可以简单地使用带有索引切片的 for 循环(没有 Lambda 层)。这是一个虚拟示例:
n_sample = 3
H,W,C = 100,5,32
X = np.random.uniform(0,1, (n_sample,H,W,C))
inp = Input((H,W,C))
convs = []
conv = Conv2D(1, 3, padding='same') # this is always the same for all the slices
for c in range(inp.shape[-1]):
_x = tf.expand_dims(inp[:,:,:,c], -1)
convs.append(conv(_x))
convs = Concatenate()(convs)
model = Model(inp, convs)
model.compile('adam', 'mse')
model.fit(X,X, epochs=2)
推荐阅读
- swift - 使用pickerView将无法识别的选择器发送到实例
- javascript - 是否可以将 csv 数据加载到 Highstock 图表?
- crystal-reports - 如何将 XML 数据源字段的所有值提取到数组中
- php - PHP - foreach 增量超过 1
- java - 为所有用户 Firebase 更新 UI
- scala - 在 Scala 中使用 Twitter Futures 进行异步日志记录
- sql-server - 来自 SQL Server 存储过程的 HTML 报告
- c# - 如何在 xamarin 中启动更新的活动
- mysql - mysql UPDATE 语句 where pricelist.import_date = ( SELECT max(pricelist.import_date) )
- java - Java10中的HttpRequest.BodyProcessor在哪里