首页 > 解决方案 > 如何为神经网络制作一个最小且可重复的示例?

问题描述

我想知道如何为 Stack Overflow 制作一个最小且可重现的深度学习示例。我想确保人们有足够的信息来查明我的代码的确切问题。仅提供追溯就足够了吗?

    c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training_utils.py 
                         in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
        135                             ': expected ' + names[i] + ' to have shape ' +
        136                             str(shape) + ' but got array with shape ' +
    --> 137                             str(data_shape))
        138     return data
        139 

还是我应该简单地发布错误消息?

值错误:检查输入时出错:预期的 dense_1_input 具有形状 (4,) 但得到的数组具有形状 (1,)

标签: pythontensorflowkerasneural-networkpytorch

解决方案


以下是制作可重现的最小深度学习示例的一些技巧。无论是对于Keras,Pytorch还是, 这都是很好的建议Tensorflow

  • 我们无法使用您的数据,但在大多数情况下,这并不重要。我们需要的只是正确的形状。
    • 使用随机生成的正确形状的数字。
      • 例如,np.random.randint(0, 256, (100, 30, 30, 3)对于 100 张30x30大小的彩色图片
      • 例如,np.random.choice(np.arange(10), 100)对于 10 个类别的 100 个样本
  • 我们不需要查看您的整个管道。
    • 只提供运行代码的最低要求。
  • 充分利用Keras其调试能力。
    • 包括回溯。它很可能会指出确切的问题。
  • 神经网络都是关于拟合正确的形状。
    • 至少,始终提供输入形状。
  • 使其易于测试和重现。
    • 发布您的整个神经网络架构。
    • 包括您的库导入。定义所有变量。

这是一个完美的最小且可重现的示例:


“我有一个错误。当我运行这段代码时,它给了我这个错误:”

ValueError:检查目标时出错:预期dense_2具有形状(10,)但得到了具有形状的数组

“这是我的架构,带有生成的数据:”

import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D

xtrain, xtest = np.random.rand(2, 1000, 30, 30, 3)
ytrain, ytest = np.random.choice(np.arange(10), 2000).reshape(2, 1000) 

model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=xtrain.shape[1:]),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')])

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])

model.fit(xtrain, ytrain,
          batch_size=16,
          epochs=10,
          validation_data=(xtest, ytest))

推荐阅读