首页 > 解决方案 > 使用 Keras callback_tensorboard 时出错

问题描述

我正在使用 R 3.2.3、Keras 2.1.6 和 TensorFlow 1.10 来解决文本分类问题。我正在尝试通过 TensorBoard 设置仪器,但我遇到了这个错误:

Error in py_call_impl(callable, dots$args, dots$keywords) : ValueError: To visualize embeddings, embeddings_data must be provided.

这是我的最小模型/训练设置:

# x, y, and tokens loaded from tab files
num_samples <- 30000L
train_sample <- sample(1:dim(x)[1], num_samples)

tb_log <- "tb_log"
tensorboard(tb_log)

model <- keras_model_sequential() %>% 
  layer_embedding(input_dim = dim(tokens)[1], output_dim = 128, input_length = 1000) %>% 
  layer_conv_1d(filters = 32, kernel_size = 7, activation = "relu") %>% 
  layer_max_pooling_1d(pool_size = 5) %>% 
  layer_conv_1d(filters = 32, kernel_size = 7, activation = "relu") %>% 
  layer_global_max_pooling_1d() %>%
  layer_dense(units = 1)

summary(model)

model %>% compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",
  metrics = c("acc")
)

history <- model %>% fit(
  x[train_sample,], y[train_sample],
  epochs = 3,
  batch_size = 128,
  validation_split = 0.5,
  callbacks = c(callback_tensorboard(
    log_dir = tb_log,
    embeddings_freq = 1,
    histogram_freq = 1
  ))
)

模型在第一个 epoch 进行训练,然后该过程因上述错误而终止。如果我callbacks从调用中删除该选项fit,模型将按预期进行训练和工作。我可以看到回调没有embeddings_data参数。我已尝试按照此处embeddings_metadata所述进行传递,但仍然出现相同的错误。如果我只是从回调中删除该选项,我会收到此错误:embedding_freq

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  InvalidArgumentError: You must feed a value for placeholder tensor 'embedding_10_input' with dtype float and shape [?,1000]

我错过了一些明显的东西吗?

更新

第二个错误(InvalidArgumentError)显然是由于尝试使用带有embeddings_freqset 的回调后环境的某些损坏引起的。如果我删除该选项,删除日志文件夹,然后从头开始重新启动我的 R 会话,我可以让它训练并生成直方图等,但在可视化实际嵌入方面仍然没有骰子。

标签: rtensorflowkeras

解决方案


这似乎归结为 TensorFlow、python Keras 模块和 R Keras 模块之间的版本不匹配。对于将来尝试解决此问题的人,您可以检查所有三个版本,如下所示:

python -c "import tensorflow; print(tensorflow.__version__)"
python -c "import keras; print(keras.__version__)"
Rscript -e "library(keras); sessionInfo()"

或者python3 -c ...视情况而定。您的 R 环境使用正确的 Python 环境也很重要,您可以检查:

Rscript -e "reticulate::py_config()"

除此之外,还有一点点试错;我还没有找到任何能够始终如一地记录哪些 Keras 版本支持哪些 TensorFlow 版本等。对于我的情况,神奇的关系最终是1.10为 Python2.7.2和 Keras构建的 TensorFlow 2.2


推荐阅读