python - 在 Tensorflow 中实现一个展平层
问题描述
我正在尝试使用 TensorFlow 2.2.0 实现一个扁平化层。我正在按照 Geron 的书(第 2 版)中的说明进行操作。至于扁平层,我首先尝试获取批量输入形状并计算新形状。但是我在张量维度上遇到了这个问题:TypeError: Dimension value must be integer or None or have an __index__ method
import tensorflow as tf
from tensorflow import keras
(X_train, y_train), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
input_shape = X_train.shape[1:]
assert input_shape == (28, 28)
class MyFlatten(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, batch_input_shape):
super().build(batch_input_shape)
def call(self, X):
X_shape = tf.shape(X)
batch_size = X_shape[0]
new_shape = tf.TensorShape([batch_size, X_shape[1]*X_shape[2]])
return tf.reshape(X, new_shape)
def get_config(self):
base_config = super().get_config()
return {**base_config}
## works fine on this example
MyFlatten()(X_train[:10])
## fail when building a model
input_ = keras.layers.Input(shape=[28, 28])
fltten_ = MyFlatten()(input_)
hidden1 = keras.layers.Dense(300, activation="relu")(fltten_)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
output = keras.layers.Dense(10, activation="softmax")(hidden2)
model = keras.models.Model(inputs=[input_], outputs=[output])
model.summary()
解决方案
不要尝试创建一个tf.TensorShape
,它只有在张量的所有维度都已知时才有效,实际上这只会在急切模式下,所以它的模型编译会失败。像这样简单地重塑:
def call(self, X):
X_shape = tf.shape(X)
batch_size = X_shape[0]
new_shape = [batch_size, X_shape[1] * X_shape[2]]
return tf.reshape(X, new_shape)
或者,更一般地说,您可以这样做:
def call(self, X):
X_shape = tf.shape(X)
batch_size = X_shape[0]
new_shape = [batch_size, tf.math.reduce_prod(X_shape[1:])]
return tf.reshape(X, new_shape)
tf.reshape
也会接受类似的东西new_shape = [batch_size, -1]
,但我认为这可能会使展平尺寸的大小未知,具体取决于具体情况。另一方面,相反的事情 ,new_shape = [-1, tf.math.reduce_prod(X_shape[1:])]
也应该可以正常工作。
顺便说一句,我假设您这样做是作为练习并且已经知道这一点,但仅供参考Flatten
,Keras 中已经有一个层(您可以查看它的源代码)。
推荐阅读
- c# - 在 AutoMapper 中,如何从抽象基类项目到接口
- javascript - Angular URL 数据生成
- c# - 将字典值按降序排序
- javascript - 有没有办法确定 chrome 选项卡是否正在通过 api 录制音频/视频
- node.js - POST 请求返回 index.html 正文而不是适当的响应
- html - CSS如何根据高度缩放和纵横比宽度?
- python-3.x - Discord.py 加入/留言
- mysql - 主键设置为自动增量,但仍接收字段没有默认值
- python - 使用带日期的多个条件进行切片
- angular - 如何通过向 resolveComponentFactory 添加动态名称来动态创建 Angular 组件