首页 > 解决方案 > 简单最近质心分类器中的 numpy

问题描述

问题:如何在 (1) 效率、(2) 内存和 (3) 样式/简单性方面改进以下代码?

train_means = np.mean(train_data.reshape(784, 10, 1000), axis=2)
train_classes, test_classes = (np.argmin(np.sum(np.square(data[:,:,np.newaxis] - train_means[:,np.newaxis,:]), axis=0), axis=1) for data in (train_data, test_data))
train_acc, test_acc = (np.mean(np.equal(classes, np.repeat(np.arange(10), classes.size // 10))) for classes in (train_classes, test_classes))

以下是变量和问题的说明:

# train_data: shape of (784, 10000)
#  test_data: shape of (784, 1000)
#
# In both cases, each column is a vectorization of a 28 x 28 image of a handwritten digit
# (0-9) where the correct classification is 0 for the first tenth of the columns, 1 for
# the second tenth, and so on. The goal is to write a nearest centroid classifier and
# output its accuracy on each data set.

例如,我担心的一个问题是在减法中创建data[:,:,np.newaxis] - train_means[:,np.newaxis,:]了一个(784, 10000, 10)数组,这比我需要使用的内存要多。我可以避免分配这么多内存而不牺牲任何效率(理想情况下是代码的任何简单性)吗?

另一个问题是我用来将程序应用于训练数据和测试数据的理解。这会被鼓励还是被认为是令人费解的(或者这可能只是无关的个人偏好)?

背景:这是我尝试编写一个最近质心分类器的 numpy 优化版本,以对来自 MNIST 手写数字数据集的一些图像进行分类。我对 numpy 有点陌生,并且对在广播和矢量化操作的帮助下可以如此简洁地编写这段代码感到惊讶,但我想知道我是否仍然错过了一些可能重要的改进。

标签: pythonnumpyvectorization

解决方案


推荐阅读