tensorflow - 运行预训练的 CNN 时报错:'NoneType' object is not callable
问题描述
我正在尝试用一个预训练网络对数据做二分类。不幸的是,我遇到了标题中提到的错误。顺便说一句,我使用的是我自己的数据。
这是我的代码:
# Load the VGG16 convolutional base with ImageNet-pretrained weights.
# include_top = FALSE leaves out VGG16's own fully connected classifier, so only
# the convolutional feature extractor is kept. input_shape fixes the expected
# input as 150 x 150 images with 3 channels.
conv_base <- application_vgg16(
weights = "imagenet",
include_top = FALSE,
input_shape = c(150, 150, 3)
)
# Build the full model by stacking a new binary classifier on the pretrained
# base. The steps are:
#   - flatten the convolutional feature maps;
#   - add one 256-unit ReLU hidden layer;
#   - finish with a single sigmoid unit, whose output in (0, 1) is used as the
#     positive-class probability by the binary_crossentropy loss.
# NOTE(review): piping the sequential model into `conv_base` relies on the R
# keras package adding a model object as a layer when called this way. Confirm
# this behaves as expected with the installed keras/tensorflow versions.
model <- keras_model_sequential() %>%
conv_base %>%
layer_flatten() %>%
layer_dense(units = 256, activation = "relu") %>%
layer_dense(units = 1, activation = "sigmoid")
# Data generators ------------------------------------------------------------

# Training images: pixel values are multiplied by 1/255 (mapping 8-bit values
# to [0, 1]) and randomly augmented (rotation, shifts, shear, zoom, horizontal
# flip) to reduce overfitting. fill_mode = "nearest" fills pixels exposed by a
# transform with the nearest existing pixel value.
train_datagen <- image_data_generator(
  rescale = 1 / 255,
  rotation_range = 40,
  width_shift_range = 0.2,
  height_shift_range = 0.2,
  shear_range = 0.2,
  zoom_range = 0.2,
  horizontal_flip = TRUE,
  fill_mode = "nearest"
)

# Validation images are only rescaled. The validation data must not be
# augmented, so that validation metrics reflect unmodified images.
test_datagen <- image_data_generator(rescale = 1 / 255)

# NOTE(review): `train_dir` and `validation_dir` are not defined in this
# snippet. They are presumably directories containing one sub-directory per
# class, which is the layout flow_images_from_directory() expects. Confirm.
train_generator <- flow_images_from_directory(
  directory = train_dir,
  generator = train_datagen,
  target_size = c(150, 150), # resize to 150 x 150; must match input_shape
  batch_size = 20,
  class_mode = "binary" # 1D binary labels, for binary_crossentropy
)

validation_generator <- flow_images_from_directory(
  directory = validation_dir,
  generator = test_datagen,
  target_size = c(150, 150),
  batch_size = 20,
  class_mode = "binary"
)
# Training -------------------------------------------------------------------

# Freeze the pretrained convolutional base so that only the newly added dense
# layers are trained. This must be done before compile() to take effect.
# Otherwise the large gradient updates coming from the randomly initialised
# classifier would overwrite the ImageNet representations learned by VGG16.
freeze_weights(conv_base)

model %>% compile(
  loss = "binary_crossentropy",
  # NOTE(review): newer keras releases name this argument `learning_rate`
  # instead of `lr`. Keep `lr` only if the installed version expects it.
  optimizer = optimizer_rmsprop(lr = 2e-5),
  metrics = c("accuracy")
)

# Assuming the generators use batch_size = 20, each epoch draws
# 100 x 20 = 2000 training images and 50 x 20 = 1000 validation images.
# NOTE(review): presumably these match the dataset sizes. Confirm.
# NOTE(review): fit_generator() is deprecated in newer keras versions in
# favour of fit(), which accepts generators directly.
history <- model %>% fit_generator(
  train_generator,
  steps_per_epoch = 100,
  epochs = 30,
  validation_data = validation_generator,
  validation_steps = 50
)
以上代码先加载预训练网络(VGG16),然后在其上添加一个 flatten 层和 2 个全连接(dense)层。
完整的错误信息如下(最终报出的错误即标题中的 'NoneType' object is not callable):
Detailed traceback:
File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/keras/engine/training.py", line 1296, in fit_generator
steps_name='steps_per_epoch')
File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/keras/engine/training_generator.py", line 265, in model_iteration
batch_outs = batch_function(*batch_data)
File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/keras/engine/training.py", line 1017, in train_on_batch
outputs = self.train_function(ins) # pylint: disable=not-callable
File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/keras/backend.py", line 3476, in __call__
run_metadata=self.run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/client/session.py", line 1472, in __call__
run_metadata_ptr)
Traceback:
1. model %>% fit_generator(train_generator, steps_per_epoch = 100,
. epochs = 30, validation_data = validation_generator, validation_steps = 50)
2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
3. eval(quote(`_fseq`(`_lhs`)), env, env)
4. eval(quote(`_fseq`(`_lhs`)), env, env)
5. `_fseq`(`_lhs`)
6. freduce(value, `_function_list`)
7. withVisible(function_list[[k]](value))
8. function_list[[k]](value)
9. fit_generator(., train_generator, steps_per_epoch = 100, epochs = 30,
. validation_data = validation_generator, validation_steps = 50)
10. call_generator_function(object$fit_generator, list(generator = generator,
. steps_per_epoch = as.integer(steps_per_epoch), epochs = as.integer(epochs),
. verbose = as.integer(verbose), callbacks = normalize_callbacks_with_metrics(view_metrics,
. callbacks), validation_data = validation_data, validation_steps = as_nullable_integer(validation_steps),
. class_weight = as_class_weight(class_weight), max_queue_size = as.integer(max_queue_size),
. workers = as.integer(workers), initial_epoch = as.integer(initial_epoch)))
11. do.call(func, args)
12. (structure(function (...)
. {
. dots <- py_resolve_dots(list(...))
. result <- py_call_impl(callable, dots$args, dots$keywords)
. if (convert) {
. result <- py_to_r(result)
. if (is.null(result))
. invisible(result)
. else result
. }
. else {
. result
. }
. }, class = c("python.builtin.instancemethod", "python.builtin.object"
. ), py_object = <environment>))(generator = <environment>, steps_per_epoch = 100L,
. epochs = 30L, verbose = 1L, callbacks = list(<environment>),
. validation_data = <environment>, validation_steps = 50L,
. class_weight = NULL, max_queue_size = 10L, workers = 1L,
. initial_epoch = 0L, use_multiprocessing = FALSE)
13. py_call_impl(callable, dots$args, dots$keywords)
解决方案
推荐阅读
- docker - Docker - 自动清理卷上的旧文件
- pandas - pandas:根据下一行的条件合并数据框
- android - 尝试运行 Flutter 项目时出现此错误
- azure-ad-b2c - ADB2C - 在外部 Azure Active Directory 登录时使用 login_hint
- sql - 删除在不同列中具有相同键的重复行
- javascript - 具有多个值的数据表按行分组
- matlab - 在 Octave 上以图形方式求解方程组
- javascript - 读取空单元格 node.js xlsx
- javascript - 在 Javascript 中没有正确答案的情况下提交后,如何将分值编辑到 textarea 输入测试?
- node.js - 删除 nodejs typescript 接口中不存在的属性