tensorflow - Trying to understand TensorFlow input_shape
Problem description
I am confused about TensorFlow's input_shape.
Suppose "doc" (defined below) contains 3 documents (one per row), and the vocabulary has 4 words (each sublist within a row).
Further suppose that each word is represented by 2 numbers via a word embedding.
The program only works when I specify input_shape=(3,4,2) for a Dense layer. But when I use an LSTM layer, the program only works with input_shape=(4,2), not with input_shape=(3,4,2).
So how should I specify the input shape for such inputs, and how should I make sense of it?
import numpy as np
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.optimizers import Adam

# 3 documents x 4 words per document x 2 embedding values per word
doc = np.array([
    [[1, 0], [0, 0], [0, 0], [0, 0]],
    [[0, 0], [1, 0], [0, 0], [0, 0]],
    [[0, 0], [0, 0], [1, 0], [0, 0]],
], dtype="float32")

model = Sequential()
model.add(Dense(2, input_shape=(3, 4, 2)))  # model.add(LSTM(2, input_shape=(4, 2)))
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])  # metrics must be a list, not a bare string
model.summary()
output = model.predict(doc)
print(model.weights)
print(output)
Solution
The input_shape argument of a keras.layers.LSTM layer expects a tuple of two values, (timesteps, features), describing the shape of a single sample; the batch dimension is left out. Your doc has the shape [batch_size, timesteps, features], i.e. (3, 4, 2), so input_shape=(3,4,2) has one dimension too many.
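The shape bookkeeping can be checked with a NumPy-only sketch (no TensorFlow needed). It splits doc's shape into batch, timestep, and feature axes, and also mimics a linear Dense layer with placeholder random weights to suggest why the Dense version in the question runs: a Dense layer only operates on the last axis, while an LSTM layer requires 3D input of shape (batch, timesteps, features). This illustrates the shapes only; it is not Keras' implementation.

```python
import numpy as np

# doc from the question: 3 documents x 4 words x 2 embedding values per word
doc = np.array([
    [[1, 0], [0, 0], [0, 0], [0, 0]],
    [[0, 0], [1, 0], [0, 0], [0, 0]],
    [[0, 0], [0, 0], [1, 0], [0, 0]],
], dtype=np.float32)

batch_size, timesteps, features = doc.shape
input_shape = doc.shape[1:]  # per-sample shape: what input_shape should describe
print(batch_size, timesteps, features)  # 3 4 2
print(input_shape)                      # (4, 2)

# An LSTM walks over the timestep axis, reading one (batch_size, features) slice per step
first_step = doc[:, 0, :]
print(first_step.shape)                 # (3, 2)

# A linear Dense layer only contracts the last axis (x @ kernel + bias),
# so any leading axes pass through unchanged. The weights are made-up placeholders.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((features, 2)).astype(np.float32)
bias = np.zeros(2, dtype=np.float32)
dense_out = doc @ kernel + bias
print(dense_out.shape)                  # (3, 4, 2)
```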
If you also want to fix the batch size, you can use the batch_input_shape argument instead.
To do so, just replace this line of your code:
model.add(LSTM(2,input_shape=(4,2)))
with this one:
model.add(LSTM(2,batch_input_shape=(3,4,2)))
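The relationship between the two arguments can be sketched in plain Python, assuming the tf.keras convention that a layer's input_shape is turned into a full batch shape by prepending a batch axis (None unless a batch size is given). full_batch_shape is a hypothetical helper for illustration, not a Keras function.

```python
def full_batch_shape(input_shape, batch_size=None):
    # Hypothetical helper: prepend the batch axis to a per-sample shape.
    # None means "any batch size is accepted".
    return (batch_size,) + tuple(input_shape)

# input_shape=(4, 2) leaves the batch size open
print(full_batch_shape((4, 2)))                # (None, 4, 2)
# batch_input_shape=(3, 4, 2) corresponds to fixing batch_size=3
print(full_batch_shape((4, 2), batch_size=3))  # (3, 4, 2)
# Passing the whole doc shape as input_shape adds a fourth axis, which an LSTM rejects
print(full_batch_shape((3, 4, 2)))             # (None, 3, 4, 2)
```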
If you fix the batch size in your model (3 in your case) and then feed a batch of a different size, you will get an error. With input_shape instead, you are free to feed batches of any size to the network.
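A simplified model of that check, in plain Python: a fixed dimension must match exactly, while None accepts any size. accepts is a hypothetical helper, and real Keras shape validation is more involved, so treat this only as an illustration of why a fixed batch size rejects other batch sizes.

```python
def accepts(expected, actual):
    # Hypothetical, simplified shape check: ranks must match, and every
    # dimension that is not None must match exactly.
    if len(expected) != len(actual):
        return False
    return all(e is None or e == a for e, a in zip(expected, actual))

flexible = (None, 4, 2)  # from input_shape=(4, 2)
fixed = (3, 4, 2)        # from batch_input_shape=(3, 4, 2)

for batch in [(3, 4, 2), (5, 4, 2)]:
    print(batch, accepts(flexible, batch), accepts(fixed, batch))
# (3, 4, 2) True True
# (5, 4, 2) True False
```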