首页 > 解决方案 > 构建自定义 keras 层,我遇到 InvalidArgumentError: Input 'pred' pass float expected bool while building

问题描述

我在构建自定义 keras 层时遇到了一个奇怪的错误。我正在尝试构建一个与 GRU 层非常相似的自定义层,但除了 sampled_z 之外还需要额外的输入,以便在变分自动编码器中进行教师强制。

我成功构建了 VAE 模型,其中 terminal_GRU 表示自定义 GRU 层。

Model: "VAE"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
encoder_input (InputLayer)      (None, 80, 69)       0                                            
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 80, 9)        5598        encoder_input[0][0]              
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 80, 9)        738         conv1d_1[0][0]                   
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 80, 10)       910         conv1d_2[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 800)          0           conv1d_3[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 180)          144180      flatten_1[0][0]                  
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 180)          32580       dense_1[0][0]                    
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 180)          32580       dense_1[0][0]                    
__________________________________________________________________________________________________
z_sampling (Lambda)             (None, 180)          0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
__________________________________________________________________________________________________
decoder (Model)                 (None, 80, 69)       1760451     z_sampling[0][0]                 
                                                                 encoder_input[0][0]              
==================================================================================================
Total params: 1,977,037
Trainable params: 1,977,037
Non-trainable params: 0
__________________________________________________________________________________________________

和解码器模型看起来像

Model: "decoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
encoded_input (InputLayer)      (None, 180)          0                                            
__________________________________________________________________________________________________
reapeat_context (RepeatVector)  (None, 80, 180)      0           encoded_input[0][0]              
__________________________________________________________________________________________________
decoder_GRU1 (GRU)              [(None, 80, 400), (N 697200      reapeat_context[0][0]            
__________________________________________________________________________________________________
decoder_GRU2 (GRU)              [(None, 80, 400), (N 961200      decoder_GRU1[0][0]               
                                                                 decoder_GRU1[0][1]               
__________________________________________________________________________________________________
true_seq_input (InputLayer)     (None, 80, 69)       0                                            
__________________________________________________________________________________________________
terminal_GRU (TGRU)             [(None, 80, 69), (No 102051      decoder_GRU2[0][0]               
                                                                 true_seq_input[0][0]             
==================================================================================================
Total params: 1,760,451
Trainable params: 1,760,451
Non-trainable params: 0
__________________________________________________________________________________________________

但是,当我尝试使用 fit_generator() 方法来训练这个模型时,我遇到了 InvalidArgumentError 如下:

InvalidArgumentError: Input 'pred' passed float expected bool while building NodeDef 'decoder/terminal_GRU/PartitionedCall/cond/switch_pred/_1362' using Op<name=Switch; signature=data:T, pred:bool -> output_false:T, output_true:T; attr=T:type> [Op:__inference_keras_scratch_graph_7791]

有没有人可以告诉我为什么会发生这个错误?令我沮丧的是,我找不到为什么会发生此错误...

标签: pythontensorflowkeras

解决方案


也许尝试将变量名称更改为输入,因为有时 pred 用于定义模型是处于预测模式还是训练模式。

但是,如果您向我们展示一些代码,那将真的很有帮助。

祝你好运


推荐阅读