首页 > 解决方案 > 用于训练考虑 keras 中最后一层的网络

问题描述

这是我的模型代码:

Model=Sequential()
input_img = Input(shape=(180,180,3))  # adapt this if using channels_first` image data format


x = Conv2D(64, (3, 3), padding='valid')(input_img)

x = Conv2D(64, (3, 3), padding='valid',strides=2)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

y = Conv2D(64, (3, 3), padding='valid')(x)
model=Model(input_img,y)

生成器部分如下

train_datagen = ImageDataGenerator(rescale=1./255)


test_datagen = ImageDataGenerator(rescale=1./255)


train_generator=train_datagen.flow_from_directory(
'\Dipti\medical_image_comp',
    target_size=(180,180),
    batch_size=128,
    class_mode=None)

validation_generator = test_datagen.flow_from_directory(
    'D:\Dipti\medical_image_comp\scale0',
    target_size=(180,180),
    batch_size=128,
    class_mode=None)

通过以下方式拟合这个简单的网络:

history=model.fit_generator(
    train_generator,
    epochs=100,
    steps_per_epoch=training_samples/batch_size,
    validation_data=validation_generator,
    validation_steps=testing_samples/batch_size)

    The following error occurs:

纪元 1/100

   ValueError                                Traceback (most recent 
 call last)
 <ipython-input-41-bf2c0dd3bbcf> in <module>()
  4          epochs=100,
  5         validation_data=validation_generator,

- ---> 6 个验证步骤=测试样本/批次大小)

  ~\Anaconda3\lib\site-packages\keras\legacy\interfaces.py in 
  wrapper(*args, **kwargs)
    89                 warnings.warn('Update your `' + object_name +
    90                               '` call to the Keras 2 API: ' +      signature, stacklevel=2)

- --> 91 return func(*args, **kwargs) 92 wrapper._original_function = func 93 return wrapper

~\Anaconda3\lib\site-packages\keras\models.py in fit_generator(self,generator,steps_per_epoch,epochs,verbose,callbacks,validation_data,validation_steps,class_weight,max_queue_size,workers,use_multiprocessing,shuffle,initial_epoch)1254 use_multiprocessing=use_multiprocessing ,1255 shuffle=shuffle,-> 1256 initial_epoch=initial_epoch)1257 1258 @interfaces.legacy_generator_methods_support

~\Anaconda3\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs) 89 warnings.warn('更新您' + object_name + 90 '对 Keras 2 API 的调用:' + 签名,stacklevel=2) ---> 91 return func(*args, **kwargs) 92 wrapper._original_function = func 93 return wrapper

  ~\Anaconda3\lib\site-packages\keras\engine\training.py in 
  fit_generator(self, generator, steps_per_epoch, epochs, verbose, 
    callbacks, validation_data, validation_steps, class_weight, 
       max_queue_size, workers, use_multiprocessing, shuffle, 
         initial_epoch)
        2160                                          'a tuple `(x, 
          y, sample_weight)` '
           2161                                          'or `(x, 
          y)`. Found: ' +
        -> 2162                                          
        str(generator_output))
          2163                     # build batch logs
           2164                     batch_logs = {}

        ValueError: Output of generator should be a tuple `(x, y, 
       sample_weight)` or `(x, y)`. Found: [[[[1.         
    0.91372555 1.        ]
  [0.8980393  0.78823537 0.87843144]
  [0.8705883  0.7607844  0.85098046]
   ...
  [0.8313726  0.7411765  0.8117648 ]
      [0.85098046 0.7607844  0.8313726 ]
         [0.83921576 0.7490196  0.8196079 ]]

             [[0.9333334  0.8352942  0.9215687 ]
          [0.8980393  0.8000001  0.8862746 ]
           [0.9294118  0.8313726  0.9176471 ]
              ...
          [0.7803922  0.6901961  0.7607844 ]
          [0.8196079  0.7294118  0.8000001 ]
        [0.8588236  0.7686275  0.83921576]]

        [[0.9176471  0.8235295  0.909804  ]
        [0.854902   0.7607844  0.8470589 ]
        [0.8745099  0.7803922  0.86666673]
        ...
         [0.7686275  0.6784314  0.7490196 ]
        [0.79215693 0.7019608  0.7725491 ]
         [0.83921576 0.7490196  0.8196079 ]]

       ...

          [[0.81568635 0.6784314  0.7725491 ]
          [0.80392164 0.6666667  0.7607844 ]
            [0.8196079  0.68235296 0.77647066]
                ...
          [0.8470589  0.6784314  0.78823537]
          [0.8352942  0.6666667  0.77647066]
            [0.8745099  0.7058824  0.81568635]]

         [[0.7686275  0.6313726  0.7254902 ]
      [0.7607844  0.62352943 0.7176471 ]
           [0.79215693 0.654902   0.7490196 ]
           ...
            [0.8431373  0.6745098  0.7843138 ]
              [0.83921576 0.67058825 0.7803922 ]
            [0.882353   0.7137255  0.8235295 ]]

           [[0.8235295  0.6862745  0.7725491 ]
         [0.7725491  0.63529414 0.72156864]
              [0.78823537 0.6509804  0.74509805]
          ...
      [0.8588236  0.6901961  0.8000001 ]
         [0.86666673 0.69803923 0.8078432 ]
          [0.8862746  0.7176471  0.82745105]]]


        [[[0.8705883  0.8705883  0.8705883 ]
         [0.8705883  0.8705883  0.8705883 ]
        [0.8705883  0.8705883  0.8705883 ]
                   ...
          [0.8705883  0.8705883  0.8705883 ]
              [0.8705883  0.8705883  0.8705883 ]
           [0.8705883  0.8705883  0.8705883 ]]

       [[0.8705883  0.8705883  0.8705883 ]
         [0.8705883  0.8705883  0.8705883 ]
           [0.8705883  0.8705883  0.8705883 ]
         ...
        [0.8705883  0.8705883  0.8705883 ]
      [0.8705883  0.8705883  0.8705883 ]
       [0.8705883  0.8705883  0.8705883 ]]

       [[0.8705883  0.8705883  0.8705883 ]
           [0.8705883  0.8705883  0.8705883 ]
          [0.8705883  0.8705883  0.8705883 ]
            ...
         [0.8705883  0.8705883  0.8705883 ]
         [0.8705883  0.8705883  0.8705883 ]
          [0.8705883  0.8705883  0.8705883 ]]

       ...

           [[0.8705883  0.8705883  0.8705883 ]
           [0.8705883  0.8705883  0.8705883 ]
     [0.8705883  0.8705883  0.8705883 ]
       ...
           [0.8705883  0.8705883  0.8705883 ]
          [0.8705883  0.8705883  0.8705883 ]
            [0.8705883  0.8705883  0.8705883 ]]

          [[0.8705883  0.8705883  0.8705883 ]
        [0.8705883  0.8705883  0.8705883 ]
            [0.8705883  0.8705883  0.8705883 ]
           ...
          [0.8705883  0.8705883  0.8705883 ]
              [0.8705883  0.8705883  0.8705883 ]
          [0.8705883  0.8705883  0.8705883 ]]

            [[0.8705883  0.8705883  0.8705883 ]
           [0.8705883  0.8705883  0.8705883 ]
              [0.8705883  0.8705883  0.8705883 ]
            ...
                [0.8705883  0.8705883  0.8705883 ]
             [0.8705883  0.8705883  0.8705883 ]
             [0.8705883  0.8705883  0.8705883 ]]]


            [[[0.92549026 0.82745105 0.90196085]
                  [0.89019614 0.7843138  0.8588236 ]
             [0.9176471  0.8078432  0.8941177 ]
           ...
          [0.7960785  0.47450984 0.6627451 ]
        [0.76470596 0.43529415 0.627451  ]
       [0.77647066 0.44705886 0.6392157 ]]

          [[0.9058824  0.8000001  0.8745099 ]
          [0.8941177  0.7803922  0.8588236 ]
           [0.86666673 0.7411765  0.8313726 ]
             ...
                [0.80392164 0.48235297 0.67058825]
         [0.79215693 0.47058827 0.65882355]
           [0.8588236  0.5294118  0.72156864]]

         [[0.83921576 0.7254902  0.80392164]
           [0.87843144 0.75294125 0.8352942 ]
       [0.8235295  0.6901961  0.7843138 ]
     ...
        [0.8078432  0.48627454 0.6745098 ]
          [0.80392164 0.48235297 0.67058825]
        [0.8862746  0.5647059  0.75294125]]

         ...

我无法获得这样一个简单的网络。我已经建立了很多具有相同概念的模型,但是在这里这个网络无法训练。请建议我用 dsidirectory 的流程训练这样一个简单的网络的方法使用 Adam 优化器和 MSE 作为损失函数的概念。我希望你明白我的意思

先生,通过这个小网络,我只是想减小图像的大小,在训练完这个网络之后,我必须将该网络的输出应用到一个图像编解码器,并且还必须做相反的过程来生成重建的图像。然后出于测试目的,我必须比较原始图像和比较图像。因为这基本上是一个压缩任务,减少图像的大小,所以特别是我的工作不需要像分类和回归那样的标签。我想复制题为“使用卷积神经网络的端到端压缩框架”的论文的结果,这个小网络基本上是我想用它们的参数训练的第一个模块。你也可以检查一下我希望你现在了解整个问题的论文

标签: pythonmachine-learningkerasgeneratorconv-neural-network

解决方案


可能在您的生成器中,您将标签作为元组的第一个元素返回,输入图像作为第二个元素返回。交换这两个,问题就解决了。


推荐阅读