Incorrect TensorFlow shapes in model.fit

Problem description

I am trying to fit a TF model using an approach similar to the one here: How to Merge Numerical and Embedding Sequential Models to Treat categories in RNN

I am trying to predict the nth event using the previous 5 measurements. Each X variable is a different categorical variable, so my model looks like this:

from tensorflow.keras.layers import Input, Embedding, LSTM, concatenate
from tensorflow.keras.models import Model

num_cats = 4 # number of categorical features
n_steps = 5 # number of timesteps in each sample
cat_size = [7954, 3500, 3000, 2000] # number of categories in each categorical feature
cat_embd_dim = [7953, 350, 300, 200] # embedding dimension for each categorical feature

# One integer-index input of length n_steps per categorical feature
cat_inputs = []
for i in range(num_cats):
    cat_inputs.append(Input(shape=(n_steps,), name='cat' + str(i+1) + '_input'))

# Embed each feature separately: (None, n_steps) -> (None, n_steps, cat_embd_dim[i])
cat_embedded = []
for i in range(num_cats):
    embed = Embedding(cat_size[i], cat_embd_dim[i])(cat_inputs[i])
    cat_embedded.append(embed)

# Concatenate along the last axis: (None, n_steps, sum(cat_embd_dim))
cat_merged = concatenate(cat_embedded)

lstm_out = LSTM(1)(cat_merged)
model = Model(cat_inputs, lstm_out)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit([X_tr_cat1, X_tr_cat2, X_tr_cat3, X_tr_cat4], y_train, epochs=10)

The model summary is:

Model: "functional_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
cat1_input (InputLayer)         [(None, 5)]          0                                            
__________________________________________________________________________________________________
cat2_input (InputLayer)         [(None, 5)]          0                                            
__________________________________________________________________________________________________
cat3_input (InputLayer)         [(None, 5)]          0                                            
__________________________________________________________________________________________________
cat4_input (InputLayer)         [(None, 5)]          0                                            
__________________________________________________________________________________________________
embedding_28 (Embedding)        (None, 5, 7953)      63258162    cat1_input[0][0]                 
__________________________________________________________________________________________________
embedding_29 (Embedding)        (None, 5, 350)       1530900     cat2_input[0][0]                 
__________________________________________________________________________________________________
embedding_30 (Embedding)        (None, 5, 300)       1014600     cat3_input[0][0]                 
__________________________________________________________________________________________________
embedding_31 (Embedding)        (None, 5, 200)       491000      cat4_input[0][0]                 
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 5, 8803)      0           embedding_28[0][0]               
                                                                 embedding_29[0][0]               
                                                                 embedding_30[0][0]               
                                                                 embedding_31[0][0]               
__________________________________________________________________________________________________
lstm_6 (LSTM)                   (None, 1)            35220       concatenate_7[0][0]              
==================================================================================================
Total params: 66,329,882
Trainable params: 66,329,882
Non-trainable params: 0
__________________________________________________________________________________
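Several shapes and parameter counts in this summary follow directly from the layer sizes in the code. The sketch below reproduces a few of them with plain arithmetic (no TensorFlow needed), assuming the Keras defaults: an `Embedding` stores one `output_dim` vector per category, and an `LSTM` uses `use_bias=True`. The helper names `embedding_params` and `lstm_params` are hypothetical, introduced only for illustration.

```python
# Hyperparameters copied from the question's code
cat_size = [7954, 3500, 3000, 2000]
cat_embd_dim = [7953, 350, 300, 200]
lstm_units = 1

def embedding_params(input_dim, output_dim):
    # Hypothetical helper: one trainable output_dim-vector per category
    return input_dim * output_dim

def lstm_params(input_dim, units):
    # Hypothetical helper, assuming use_bias=True:
    # kernel (input_dim x 4*units) + recurrent kernel (units x 4*units) + bias (4*units)
    return 4 * units * (input_dim + units + 1)

# Concatenating the embeddings along the last axis sums their widths
concat_width = sum(cat_embd_dim)
print(concat_width)                                      # 8803
print(embedding_params(cat_size[0], cat_embd_dim[0]))    # 63258162
print(lstm_params(concat_width, lstm_units))             # 35220
```

These match the `concatenate_7`, `embedding_28` and `lstm_6` rows above.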

Running model.fit gives me ValueError: Data cardinality is ambiguous: x sizes: 39770, 39770, 39770, 39770 y sizes: 7954 Please provide data which shares the same first dimension.

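Each X variable has shape (39770, 5), i.e. 39770 sequences of 5 measurements each. My data has 5 sequences per person, so there are 7954 people. y_train has shape (7954, 5), i.e. the 5 recorded outcomes for each person's sequences.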
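The error message says the first (sample) dimensions disagree: Keras pairs inputs and targets row by row, so every array passed to `fit` must have the same length along axis 0, here 39770 for each X but 7954 for `y_train`. Since 7954 × 5 = 39770, one possible fix is to flatten `y_train` to one target per sequence. This is a sketch under two assumptions not confirmed by the question: the rows of each X array are ordered person by person (person 0's 5 sequences first, then person 1's, and so on), and `y_train[p, k]` is the outcome for sequence `k` of person `p`. The zero-filled arrays are hypothetical placeholders for the real data.

```python
import numpy as np

n_people, seqs_per_person, n_steps = 7954, 5, 5
n_seqs = n_people * seqs_per_person  # 39770

# Hypothetical placeholders with the shapes described in the question
X_tr_cat1 = np.zeros((n_seqs, n_steps), dtype="int32")
y_train = np.zeros((n_people, seqs_per_person), dtype="float32")

# First dimensions disagree, which is what triggers "Data cardinality is ambiguous"
print(X_tr_cat1.shape[0], y_train.shape[0])  # 39770 7954

# Assumption: row-major flattening maps y_train[p, k] to row p * 5 + k,
# matching X rows ordered person by person, sequence by sequence
y_flat = y_train.reshape(-1, 1)
print(y_flat.shape)  # (39770, 1)
```

If the rows of X are not ordered this way, they would need to be matched to their targets explicitly (for example by a person and sequence ID) before flattening.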

Can someone explain how to fix this error?

Tags: python, tensorflow, keras, neural-network, lstm
