python - 如何在文本的卷积层中使用 tensorflow hub 层嵌入?
问题描述
我是 TensorFlow 集线器的新手,我正在尝试在我的 Conv1D 网络中使用集线器嵌入层来进行文本分类。
在顺序模型中使用集线器嵌入层没有任何问题:
hub_layer = hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim50/2",
input_shape=[], dtype=tf.string, trainable=False)
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(128))
model.add(tf.keras.layers.Activation('relu'))
model.add(tf.keras.layers.Dense(5))
model.add(tf.keras.layers.Activation('softmax'))
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()
但是,我无法在 Conv1D 模型中使用:
第一个模型:
int_sequences_input = Input(shape=(max_length,))
embedded_sequences = hub_layer(int_sequences_input)
x = layers.Conv1D(128, 5, activation="relu")(embedded_sequences)
x = layers.MaxPooling1D(5)(x)
x = layers.Conv1D(128, 5, activation="relu")(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.5)(x)
preds = layers.Dense(len(class_names), activation="softmax")(x)
model = keras.Model(int_sequences_input, preds)
model.summary()
或者:
第二种型号:
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Conv1D(128, 7, activation='relu'))
model.add(tf.keras.layers.GlobalMaxPooling1D())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
当我收到价值错误时:
ValueError: Input 0 of layer conv1d_11 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 50]
解决方案
生成的嵌入维度是:(num_examples, embedding_dimension)
它与 1D 卷积不兼容,因为它需要 3D 输入。
尝试在中心层之后重塑,如下所示:
model.add(hub_layer)
model.add(tf.keras.layers.Reshape((1,50)))
model.add(tf.keras.layers.Conv1D(16, 3, activation='relu', padding = 'same'))
...
推荐阅读
- postgresql - Postgresql:通过触发器检查表中是否存在值
- c - C 使用结构的动态内存分配
- python - 2019 年 Python 2/3 Asyncio
- docker - Docker specific installation version. Dependency failed
- javascript - How to Javascript to replace HTML text with new text when the HTML may have children elements
- unity3d - Unity 相机抖动/玩家传送
- javascript - Eslint glob (**) 没有递归地考虑所有目录
- multithreading - C++ Visual Studio 2017 表单。如何从单独的线程更新和刷新标签
- ios - Firestore - 使用子集合添加文档的单一操作
- java - java - 无法将多个数据插入数据库