首页 > 解决方案 > 模型 keras 训练期间 fit_generator 中的问题

问题描述

我是 Python 和 CNN 的初学者。我确实为 2 个类之间的训练模型编写了一个简单的代码我有一个文件夹有 2 个用于训练的文件夹和 2 个用于验证的文件夹

import keras,os
from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator

Classifier = Sequential()
classifier.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2))) 
classifier.add(Flatten()) 
classifier.add(Dense(units = 128, activation = 'relu')) 
classifier.add(Dense(units = 1, activation = 'sigmoid')) 
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy' , metrics = ['raccuracy'])

train_datagen = ImageDataGenerator(rescale = 1./255, shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory(r'C:\Users\user1\Documents\code\Test', target_size = (244, 244), color_mode="rgb", batch_size = 32, class_mode = 'binary', shuffle=True,)
test_set = test_datagen.flow_from_directory(r'C:\Users\user1\Documents\code\Valid', target_size = (244, 244), color_mode="rgb", batch_size = 32, class_mode = 'binary', shuffle=True,)
STEP_SIZE_TRAIN=training_set.n    #train_generator.batch_size
STEP_SIZE_VALID=test_set.n    #valid_generator.batch_size
classifier.fit_generator(generator = training_set, steps_per_epoch = STEP_SIZE_TRAIN,epochs = 5,validation_data = test_set,validation_steps = STEP_SIZE_VALID)

一切顺利,但我有一个错误。我不知道是什么问题。它在第一个 epoch 立即开始,如下所示

Found 518 images belonging to 2 classes.
Found 40 images belonging to 2 classes.
Epoch 1/5
TypeError                                 Traceback (most recent call last) <ipython-input-21-c93c80bb7785> in <module>
     20 STEP_SIZE_VALID=test_set.n    #valid_generator.batch_size
     21 
---> 22 classifier.fit_generator(generator = training_set, 
     23                          steps_per_epoch = STEP_SIZE_TRAIN,
     24                          epochs = 5,

~\anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',


~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)    1813     """    1814    
_keras_api_gauge.get_cell('fit_generator').set(True)
-> 1815     return self.fit(    1816         generator,    1817         steps_per_epoch=steps_per_epoch,
     ~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.  

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)    1096                 batch_size=batch_size):   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)    1099               if data_handler.should_sync:    1100                 context.async_wait()

~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    805       # In this case we have created variables on the first call, so we run the
    806       # defunned version which is guaranteed to never create variables.
--> 807       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    808     elif self._stateful_fn is not None:
    809       # Release the lock early so that multiple threads can perform the call TypeError: 'NoneType' object is not callable.

Can any one please tell me what is the problem ?

标签: pythontensorflowkerasconv-neural-network

解决方案


推荐阅读