tensorflow - 在 4 维(4D 张量)的输入上使用 MaxPool1D?
问题描述
我正在尝试为多实例学习创建一个 NN 架构,因此这些实例实际上是时间序列段的包。我想在特征(最后一个维度)上执行 COnv1D 和 MaxPool1D,我将输入指定为具有 4 个维度,并且对 Conv1D 工作正常,但在 MaxPool1D 中引发错误:
n = 6
sample_size = 300
code_size = 50
learning_rate = 0.001
bag_size = None
# autoencoder: n_bags X bag_size X n_samples (timesteps) X n_measurements
input_window = Input(shape=(bag_size,sample_size, n))
x = Conv1D(filters=40, kernel_size=21, activation='relu', padding='valid')(input_window)
x = MaxPooling1D(pool_size=2)(x)
错误是:
ValueError: Input 0 of layer max_pooling1d_4 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, None, 280, 40]
根据TF 文档 MaxPool1D 仅适用于 3D 张量。有解决办法吗?
解决方案
虽然尚不清楚您要在哪个轴上进行池化,但您可以使用MaxPooling2D
正确的池化大小,在这种情况下,IIUC 将是(1,2)
from tensorflow.keras import layers, Model
n = 6
sample_size = 300
code_size = 50
learning_rate = 0.001
n_bags= None
# autoencoder: n_bags X n_instances_in_bag X n_samples (timesteps) X n_measurements
input_window = layers.Input(shape=(n_bags,sample_size, n))
x = layers.Conv1D(filters=40, kernel_size=21, activation='relu', padding='valid')(input_window)
x = layers.MaxPooling2D(pool_size=(1,2))(x)
model = Model(input_window, x)
model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, None, 300, 6)] 0
_________________________________________________________________
conv1d_3 (Conv1D) (None, None, 280, 40) 5080
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, None, 140, 40) 0
=================================================================
Total params: 5,080
Trainable params: 5,080
Non-trainable params: 0
_________________________________________________________________
推荐阅读
- javascript - 无法呈现反应引导表
- json - 在 Jinja2 模板中循环 Ansible 变量数组以创建嵌套字典
- jsdoc - 如何使用 JSDoc 注释返回此副本的方法?
- python - Python 解码 | 字节到 json
- node.js - ScrollConsoleScreenBuffer 影响裁剪矩形外的数据
- sql - 是否可以在 SQL Server 中使用通配符作为 OPENJSON 的参数?
- swift - 如何在后台强制 macOS 应用程序更新?
- android - NoSuchFieldException:类 Landroid/widget/ImageView 中没有字段 mMaxWidth
- git - 我在 git 中一次提交了 10 个文件,但现在我怎样才能只恢复其中的 2 个
- reactjs - 在useState问题之前调用react useCallback