首页 > 解决方案 > 不理解类部分并通过读取 h5 数据集文件进行重塑

问题描述

您好,有人可以逐步解释以下代码中发生了什么吗?特别是零件类和重塑?tnx

def load_data():
    train_dataset = h5py.File('datasets/train_catvnoncat.h5', "r")
    train_set_x_orig = np.array(train_dataset["train_set_x"][:]) # your train set features
    train_set_y_orig = np.array(train_dataset["train_set_y"][:]) # your train set labels

    test_dataset = h5py.File('datasets/test_catvnoncat.h5', "r")
    test_set_x_orig = np.array(test_dataset["test_set_x"][:]) # your test set features
    test_set_y_orig = np.array(test_dataset["test_set_y"][:]) # your test set labels

    classes = np.array(test_dataset["list_classes"][:]) # the list of classes

    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes

标签: numpyipythonh5py

解决方案


大多数行只是datasetsh5文件中加载。np.array(...)不需要包装器 。test_dataset[name][:]足以加载数组。

test_set_y_orig = test_dataset["test_set_y"][:]

test_dataset是打开的文件。 test_dataset["test_set_y"]dataset那个文件上的一个。将[:]数据集加载到numpy数组中。查找h5py文档以获取有关加载dataset.

我从

train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))

加载的数组是 1d,形状为(n,),而此重塑只是添加一个初始维度,使其成为(1,n)。我会把它编码为

train_set_y_orig = train_set_y_orig[None,:]

但结果是一样的。

数组没有什么特别之处classes(尽管它很可能是一个字符串数组)。


推荐阅读