ValueError: Error when checking input: expected input_1 to have shape (50,) but got array with shape (1,) with ELMo embeddings and LSTM

Problem description

I'm trying to reproduce the example at this link:
https://www.depends-on-the-definition.com/named-entity-recognition-with-residual-lstm-and-elmo/
In short, I'm trying to use ELMo embeddings for a sequence tagging task. I'm following this tutorial, but when I try to fit the model I get the following error:

ValueError: Error when checking input: expected input_1 to have shape (50,) but got array with shape (1,)

The code that gives me the error is this:

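import numpy as np
import tensorflow as tf
from keras.models import Model
from keras.layers import Input, LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional, Lambda
from keras.layers.merge import add
# ElmoEmbedding, max_len, n_tags, batch_size, X_tr and y_tr are assumed to be defined as in the linked tutorial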
input_text = Input(shape=(max_len,), dtype=tf.string)
embedding = Lambda(ElmoEmbedding, output_shape=(max_len, 1024))(input_text)
x = Bidirectional(LSTM(units=512, return_sequences=True,
                       recurrent_dropout=0.2, dropout=0.2))(embedding)
x_rnn = Bidirectional(LSTM(units=512, return_sequences=True,
                           recurrent_dropout=0.2, dropout=0.2))(x)
x = add([x, x_rnn])  # residual connection to the first biLSTM
out = TimeDistributed(Dense(n_tags, activation="softmax"))(x)
model = Model(input_text, out)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["categorical_accuracy"])
X_tr, X_val = X_tr[:1213*batch_size], X_tr[-135*batch_size:]
y_tr, y_val = y_tr[:1213*batch_size], y_tr[-135*batch_size:]
y_tr = y_tr.reshape(y_tr.shape[0], y_tr.shape[1], 1)
y_val = y_val.reshape(y_val.shape[0], y_val.shape[1], 1)
history = model.fit(np.array(X_tr), y_tr, validation_data=(np.array(X_val), y_val),batch_size=batch_size, epochs=3, verbose=1)

The error is raised by the last line of this code, when I try to fit the model. Can someone help me understand how to solve this problem?

Tags: python, tensorflow, keras, lstm, elmo

Solution


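Your input shape is specified as (50,), but the current output of np.array(X_tr) is a single-row array of shape (1,). Given the limited information about your data, I would check the length of the array (X_tr); if it is 50, simply transpose it with .T: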

X_tr_arr = np.array(X_tr)
X_tr_t = X_tr_arr.T

https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.ndarray.T.html
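As a rough illustration of the shape check suggested above, the sketch below shows what .T does to a column of 50 values and to a flat 1-D array. It also shows one possible way to give every sample exactly 50 entries when the sentences have different lengths. The value max_len = 50 is inferred from the error message; the pad_tokens helper, the "__PAD__" token, and the toy sentences are hypothetical and only serve the example, so adapt them to however X_tr is actually built.

```python
import numpy as np

max_len = 50  # assumed: the length the Input layer expects, per the error message


def pad_tokens(sentences, max_len, pad="__PAD__"):
    """Hypothetical helper: truncate or pad each token list to max_len entries."""
    padded = []
    for tokens in sentences:
        tokens = list(tokens[:max_len])
        padded.append(tokens + [pad] * (max_len - len(tokens)))
    return padded


# A column of 50 values has shape (50, 1): Keras sees 50 samples of length 1.
# Transposing it gives (1, 50): one sample of length 50.
column = np.array([["w%d" % i] for i in range(max_len)])
print(column.shape, column.T.shape)

# Transposing a 1-D array leaves its shape unchanged, so .T cannot help here.
flat = np.arange(max_len)
print(flat.shape, flat.T.shape)

# Toy sentences of unequal length (illustrative data only).
sentences = [["EU", "rejects", "German", "call"], ["Peter", "Blackburn"]]
X = np.array(pad_tokens(sentences, max_len))
print(X.shape)  # (number of sentences, max_len)
```

If np.array(X_tr).shape already reports (n_samples, 50), the input array is probably not the cause and the problem lies elsewhere in the pipeline.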

