首页 > 解决方案 > 如何按组选择随机索引?

问题描述

我正在使用 MNIST 数据库,其中我们有图像像素数组 (x_train) 和相应的图像标签 (y_train)。如何为每个数字标签选择一个随机像素阵列?

到目前为止,我能够为 x_train 或 y_train 选择随机值。然而问题是,选择不是考虑每个组一次,而是随机的。

import tensorflow as tf
import numpy as np
import random
from random import randint
import numpy_indexed as npi

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

labels = npi.group_by(y_train).split(y_train)
print(labels)

加载数据库后,我们可以对标签进行分组。我们看到我们有以下标签:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]。

Output: 
[array([0, 0, 0, ..., 0, 0, 0], dtype=uint8), array([1, 1, 1, ..., 1, 1, 1], dtype=uint8), array([2, 2, 2, ..., 2, 2, 2], dtype=uint8), array([3, 3, 3, ..., 3, 3, 3], dtype=uint8), array([4, 4, 4, ..., 4, 4, 4], dtype=uint8), array([5, 5, 5, ..., 5, 5, 5], dtype=uint8), array([6, 6, 6, ..., 6, 6, 6], dtype=uint8), array([7, 7, 7, ..., 7, 7, 7], dtype=uint8), array([8, 8, 8, ..., 8, 8, 8], dtype=uint8), array([9, 9, 9, ..., 9, 9, 9], dtype=uint8)]

我的目标是从 10 个组中选择 10 个随机索引并选择相应的标签和像素数组。

Desired Output:
Set of 10 Images: [(array([40707]), array([[[  0,   0,  ...  0,   0]]], dtype=uint8), array([6], dtype=uint8)), ...

在这种情况下,我们将有索引:[40707]、[像素阵列]、标签:[6]。

到目前为止,我无法限制每个标签选择 10 个随机索引。

# Return a list of 10 random indices as listindex
def digit_indices_randselect():
    listi = []
    for i in range(10):
        i = np.random.choice(np.arange(0, len(y_train)), size = (1,))
        listi.append(i)
    return listi
listindex = digit_indices_randselect()
print('Random list of indices:', listindex)

# For every index in listindex return the corresponding index, pixel array and label

def array_and_label_for_digit_indices_randselect():
    listi = []
    digit_data = []
    labels = []
    for i in listindex:
        digit_array = x_train[i] #digit data (image array) is the data from index i
        label = y_train[i] #corresponding label
        listi.append(i)
        digit_data.append(digit_array)
        labels.append(label)
    list3 = list(zip(listi, digit_data, labels))
    return list3
array_and_label_for_digit_indices_randselect()

如何限制每组的索引选择?或者如何以某种方式拆分数组,以便我可以选择 10 个组并保留原始索引?

标签: pythonlistsplitgroup-bynumpy-ndarray

解决方案


以下代码将为您提供每组的 10 个随机索引

def get_random_indxs(group_value, count=10):
    train_indxs = np.arange(len(y_train), dtype=np.int32)
    group_indxs = train_indxs[y_train == group_value]
    return np.random.choice(group_indxs,count)

for group_val in np.unique(y_train):
    print(get_random_indxs(group_val))


推荐阅读