首页 > 解决方案 > TypeError:图像数据的形状无效(3072,)

问题描述

这是我的事:

我不想在 colab 上运行,而是想使用 colab 中的代码读取本地 CIFAR10 数据集来播放CNN。首先,我成功下载了 CIFAR10 数据集。然后我用下面的代码来阅读它:

import tensorflow as tf
import pandas as pd
import numpy as np
import math
import timeit
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# %matplotlib inline


img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)
def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return  pickle.load(f)
    elif version[0] == '3':
        return  pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000,3072)
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    # Load the raw CIFAR-10 data
    cifar10_dir = './cifar10/'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    x_train = X_train.astype('float32')
    x_test = X_test.astype('float32')

    x_train /= 255.0
    x_test /= 255.0

    return x_train, y_train, X_val, y_val, x_test, y_test


# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()enter code here

然后,为了显示数据集中的图像,我使用了我提到的链接中的原始代码:

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i], cmap=plt.cm.binary)
    # The CIFAR labels happen to be arrays, 
    # which is why you need the extra index
    plt.xlabel(classes[y_train[i][0]])
plt.show()

最后,出乎意料的是,它给出了一个错误说:

    runfile('F:/Google Drive/DCM_Image_AI/untitled1.py', wdir='F:/Google Drive/DCM_Image_AI')
Traceback (most recent call last):

  File "F:\Google Drive\DCM_Image_AI\untitled1.py", line 85, in <module>
    plt.imshow(x_train[i], cmap=plt.cm.binary)

  File "C:\Users\liuji\Anaconda3\envs\Face_ recognition\lib\site-packages\matplotlib\pyplot.py", line 2677, in imshow
    None else {}), **kwargs)

  File "C:\Users\liuji\Anaconda3\envs\Face_ recognition\lib\site-packages\matplotlib\__init__.py", line 1599, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)

  File "C:\Users\liuji\Anaconda3\envs\Face_ recognition\lib\site-packages\matplotlib\cbook\deprecation.py", line 369, in wrapper
    return func(*args, **kwargs)

  File "C:\Users\liuji\Anaconda3\envs\Face_ recognition\lib\site-packages\matplotlib\cbook\deprecation.py", line 369, in wrapper
    return func(*args, **kwargs)

  File "C:\Users\liuji\Anaconda3\envs\Face_ recognition\lib\site-packages\matplotlib\axes\_axes.py", line 5679, in imshow
    im.set_data(X)

  File "C:\Users\liuji\Anaconda3\envs\Face_ recognition\lib\site-packages\matplotlib\image.py", line 690, in set_data
    .format(self._A.shape))

TypeError: Invalid shape (3072,) for image data

 任何人都可以帮助我解决这个问题。非常感谢。

标签: pythontensorflowmatplotlib

解决方案


首先,我意识到您将像素值除以 255。评论这些行。

x_train /= 255.0
x_test /= 255.0

之后像这样重塑你的形象

np.reshape(image, (32, 32, 3))

这应该有效。


推荐阅读