r - 为什么在 R 中使用 keras 训练自动编码器时会收到此错误?
问题描述
我有一个包含不同大小的 8 位 rgb(3 通道)图像的目录。我正在尝试使用它们在linux mint 19机器上使用keras 2.2.5.0和tensorflow 2.0.0在R 3.6.3中训练自动编码器。数据集在这里(压缩):https ://github.com/hrj21/processing-imagestream-images/blob/master/ciliated_cells.zip
图像分为两个标记的类,但我不关心这个类结构。
当我运行该fit_generator()
函数时,我收到错误:
Error in py_call_impl(callable, dots$args, dots$keywords) :
IndexError: list index out of range
我确定这是我做错的事情,但我对 keras 的经验不足,无法理解那是什么。您可以提供的任何帮助将不胜感激。这是代码:
# Load package ------------------------------------------------------------
library(keras)
# Defining the file paths -------------------------------------------------
base_dir <- "ciliated_cells"
train_dir <- file.path(base_dir, "train")
validation_dir <- file.path(base_dir, "validation")
test_dir <- file.path(base_dir, "test")
# Define data generators --------------------------------------------------
# To scale and resize images
datagen <- image_data_generator(rescale = 1/255)
train_generator <- flow_images_from_directory(
train_dir,
datagen,
target_size = c(150, 150),
batch_size = 88,
class_mode = NULL
)
validation_generator <- flow_images_from_directory(
validation_dir,
datagen,
target_size = c(150, 150),
batch_size = 36,
class_mode = NULL
)
test_generator = flow_images_from_directory(
test_dir,
datagen,
target_size = c(150, 150),
batch_size = 30,
class_mode = NULL,
shuffle = FALSE) # keep data in same order as labels
# Defining the model architecture from scratch ----------------------------
input <- layer_input(shape = c(150, 150, 3))
output <- input %>%
layer_flatten(input_shape = c(150, 150, 3)) %>%
layer_flatten() %>%
layer_dense(units = 32, activation = "relu") %>%
layer_dense(units = 16, name = "code") %>%
layer_dense(units = 32, activation = "relu") %>%
layer_dense(units = 150 * 150 * 3) %>%
layer_reshape(c(150, 150, 3))
model <- keras_model(input, output)
# Compiling and fitting the model -----------------------------------------
model %>% compile(
loss = "mse",
optimizer = optimizer_rmsprop(lr = 2e-5)
)
history <- model %>% fit_generator(
train_generator,
steps_per_epoch = 1,
epochs = 100,
validation_data = validation_generator,
validation_steps = 1
)
这是输出sessionInfo()
:
R version 3.6.3 (2020-02-29)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Linux Mint 19
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
locale:
[1] LC_CTYPE=en_GB.UTF-8 LC_NUMERIC=C LC_TIME=en_GB.UTF-8
[4] LC_COLLATE=en_GB.UTF-8 LC_MONETARY=en_GB.UTF-8 LC_MESSAGES=en_GB.UTF-8
[7] LC_PAPER=en_GB.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] forcats_0.5.0 stringr_1.4.0 dplyr_0.8.5 purrr_0.3.3 readr_1.3.1 tidyr_1.0.2
[7] tibble_3.0.0 ggplot2_3.3.0 tidyverse_1.3.0 keras_2.2.5.0
loaded via a namespace (and not attached):
[1] reticulate_1.15-9000 tidyselect_1.0.0 haven_2.2.0 lattice_0.20-41
[5] colorspace_1.4-1 vctrs_0.2.4 generics_0.0.2 base64enc_0.1-3
[9] rlang_0.4.5 pillar_1.4.3 withr_2.1.2 glue_1.4.0
[13] DBI_1.1.0 rappdirs_0.3.1 dbplyr_1.4.2 modelr_0.1.6
[17] readxl_1.3.1 lifecycle_0.2.0 tensorflow_2.0.0 munsell_0.5.0
[21] gtable_0.3.0 cellranger_1.1.0 rvest_0.3.5 tfruns_1.4
[25] fansi_0.4.1 broom_0.5.5 Rcpp_1.0.4.6 backports_1.1.6
[29] scales_1.1.0 jsonlite_1.6.1 fs_1.4.1 hms_0.5.3
[33] packrat_0.5.0 stringi_1.4.6 grid_3.6.3 cli_2.0.2
[37] tools_3.6.3 magrittr_1.5 crayon_1.3.4 whisker_0.4
[41] pkgconfig_2.0.3 zeallot_0.1.0 ellipsis_0.3.0 Matrix_1.2-18
[45] xml2_1.3.1 reprex_0.3.0 lubridate_1.7.4 assertthat_0.2.1
[49] httr_1.4.1 rstudioapi_0.11 R6_2.4.1 nlme_3.1-145
[53] compiler_3.6.3
解决方案
所以我意识到了我的错误。我的数据生成器正在生成输入图像,而不是输出图像(应该是相同的)供自动编码器学习。因此解决方案是将class_mode
每个flow_images_from_directory()
函数内部的参数更改为“输入”。'fit_generator()` 函数然后运行没有问题。没有这个,自动编码器就不会“知道”它试图在输出层再现输入图像。
推荐阅读
- java - 为什么 Sentry 不能在 Java 中工作,而类似的代码在 python 中工作?
- android - 使用 Android Studio 获取 Android 应用的当前调度程序信息
- javascript - 未捕获的类型错误:无法设置未定义的属性“0” - Javascript
- javascript - 如何在 Discord.js 事件“guildCreate”中发送消息
- html - 背景颜色跨越我的导航栏的宽度,而不是整个背景
- javascript - 在 Promise() 中使用 while 循环
- jersey - 使用 Jersey 构建 REST 服务
- android - 为什么用滑行对图像进行圆角会增加我的 ram 使用量?
- google-photos - Google 照片 API 中的照片 ID
- ios - 无法更改 CNContactViewController forNewContact 中的个人资料图片,当联系人已存在于电话联系人中时