python-3.x - Keras - 将不同数据点的不同参数传递到 Lambda 层
问题描述
我正在研究 Keras/TF 背景下的 CNN 模型。在最终卷积层结束时,我需要汇集过滤器的输出图。GlobalAveragePooling
我不得不根据输出地图宽度上存在的时间范围来进行池化,而不是使用或任何其他类型的池化。
因此,假设一个过滤器的样本输出n x m
是n
时间框架和m
特征输出。在这里,我只需要从<= n的帧n1 to n2
中汇集输出。所以我的输出切片是,我将在其上应用池化。我遇到了 keras 来做到这一点。但是我被困在一个点上,每个点都会有所不同。所以我的问题是如何将每个数据点的自定义参数传递到一个?还是我以错误的方式处理这个问题?n1
n2
(n2-n1)*m
Lambda Layer
n1
n2
Lambda Layer
示例片段:
# for slicing a tensor
def time_based_slicing(x, crop_at):
dim = x.get_shape()
len_ = crop_at[1] - crop_at[0]
return tf.slice(x, [0, crop_at[0], 0, 0], [1, len_, dim[2], dim[3]])
# for output shape
def return_out_shape(input_shape):
return tuple([input_shape[0], None, input_shape[2], input_shape[3]])
# lambda layer addition
model.add(Lambda(time_based_slicing, output_shape=return_out_shape, arguments={'crop_at': (2, 5)}))
crop_at
拟合循环时,需要为每个数据点自定义上述参数。对此的任何指示/线索都会有所帮助。
解决方案
从 Sequential API 切换 - 当您需要使用多个输入时,它开始崩溃:使用功能 API https://keras.io/models/model/
假设您的 lambda 函数是正确的:
def time_based_slicing(inputs_list):
x, crop_at = inputs_list
... (will probably need to do some work to subset crop_at since it will be a tensor now instead of constants
inp = Input(your_shape)
inp_additional = Inp((2,)
x=YOUR_CNN_LOGIC(inp)
out = Lambda(time_based_slicing)([x,inp_additional])
推荐阅读
- python - write 函数只返回字符串长度而不实际将字符串写入目标文件
- git - 如何使用 Git 计算在给定日期之前写入的行数
- c# - 如何断言已抛出异常
- java - JPA 查询中的 NullPointerException
- c# - 自定义排序字典
使用 LINQ 向下遍历 - search - 为什么 indexed_search 仅在 TYPO3 8 / 9 的子页面上不起作用?
- node.js - 是否可以使用 js-clien 运行 Inifinispan 连续查询
- excel - 基于正则表达式的 Excel“智能”自动完成(新功能?)
- php - 如何在 PHP 中将 GROUP BY + ORDER BY + WHERE 与 GET 一起使用?
- java - 如何在 try catch throw 中抛出错误对话框?