python - 构建自定义 keras 层,我遇到 InvalidArgumentError: Input 'pred' pass float expected bool while building
问题描述
我在构建自定义 keras 层时遇到了一个奇怪的错误。我正在尝试构建一个与 GRU 层非常相似的自定义层,但除了 sampled_z 之外还需要额外的输入,以便在变分自动编码器中进行教师强制。
我成功构建了 VAE 模型,其中 terminal_GRU 表示自定义 GRU 层。
Model: "VAE"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) (None, 80, 69) 0
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 80, 9) 5598 encoder_input[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 80, 9) 738 conv1d_1[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 80, 10) 910 conv1d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 800) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 180) 144180 flatten_1[0][0]
__________________________________________________________________________________________________
z_mean (Dense) (None, 180) 32580 dense_1[0][0]
__________________________________________________________________________________________________
z_log_var (Dense) (None, 180) 32580 dense_1[0][0]
__________________________________________________________________________________________________
z_sampling (Lambda) (None, 180) 0 z_mean[0][0]
z_log_var[0][0]
__________________________________________________________________________________________________
decoder (Model) (None, 80, 69) 1760451 z_sampling[0][0]
encoder_input[0][0]
==================================================================================================
Total params: 1,977,037
Trainable params: 1,977,037
Non-trainable params: 0
__________________________________________________________________________________________________
和解码器模型看起来像
Model: "decoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoded_input (InputLayer) (None, 180) 0
__________________________________________________________________________________________________
reapeat_context (RepeatVector) (None, 80, 180) 0 encoded_input[0][0]
__________________________________________________________________________________________________
decoder_GRU1 (GRU) [(None, 80, 400), (N 697200 reapeat_context[0][0]
__________________________________________________________________________________________________
decoder_GRU2 (GRU) [(None, 80, 400), (N 961200 decoder_GRU1[0][0]
decoder_GRU1[0][1]
__________________________________________________________________________________________________
true_seq_input (InputLayer) (None, 80, 69) 0
__________________________________________________________________________________________________
terminal_GRU (TGRU) [(None, 80, 69), (No 102051 decoder_GRU2[0][0]
true_seq_input[0][0]
==================================================================================================
Total params: 1,760,451
Trainable params: 1,760,451
Non-trainable params: 0
__________________________________________________________________________________________________
但是,当我尝试使用 fit_generator() 方法来训练这个模型时,我遇到了 InvalidArgumentError 如下:
InvalidArgumentError: Input 'pred' passed float expected bool while building NodeDef 'decoder/terminal_GRU/PartitionedCall/cond/switch_pred/_1362' using Op<name=Switch; signature=data:T, pred:bool -> output_false:T, output_true:T; attr=T:type> [Op:__inference_keras_scratch_graph_7791]
有没有人可以告诉我为什么会发生这个错误?令我沮丧的是,我找不到为什么会发生此错误...
解决方案
也许尝试将变量名称更改为输入,因为有时 pred 用于定义模型是处于预测模式还是训练模式。
但是,如果您向我们展示一些代码,那将真的很有帮助。
祝你好运
推荐阅读
- laravel - laravel 8中“路由”资源的问题
- python - 将 Python 字符串填充到指定长度
- firebase - Firebase云功能模拟器:不在线部署如何编译本地运行
- node.js - 如果已经存在则忽略索引并使用 nodejs 在 elasticsearch 中创建/添加新索引
- javascript - Javascript 搜索是否不如使用 Sphinx 或 Elasticsearch 等搜索客户端安全?
- ruby-on-rails - bundle install on ruby on rails 抛出 mingw32 错误
- javascript - Material-UI 中的 DefaultTheme 导致“扩充中的模块名称无效”错误
- oracle - Windows Server 2008R 中的 Pro*C Oracle 11g 数据库连接
- google-chrome - PWA:在隐身模式下是否会触发 beforeinstallprompt 事件?
- javascript - 将名称随机推入两个数组之一