Few Shot Learning / Siamese Network - 3-channel input images

Problem description

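I am trying to do few-shot learning on a prepared dataset that contains a number of small classes with 40 training samples each (40-shot learning). To load my data, I use the following code: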

def list_files(startpath):
    X = []
    images = []
    full_path = []
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        l_alpha.append(dirs)
        for f in files:
        #    print('{}{}'.format(subindent, f))
        #    l_char.append(f)
            full_path.append(str(dirs)+'\\'+str(root)+'\\'+str(f))
            for pixel in f:
                #img_data = cv2.imread(str(root)+'\\'+str(f))
            #    # store loaded image
            #    loaded_images.append([img_data])
            ##X.append(np.stack(loaded_images))
            #X = np.concatenate( loaded_images]]
        #full_path = [full_path[i][2:] for i in range(len(full_path))]
            #print(full_path[i][3:])
                images = [np.array(Image.open(v[3:])) for v in full_path]
        images = [images]
    return root, dirs, files, full_path, images

This works fine: the resulting image array has shape (3267, 100, 100, 3). My problem is in the next part of the code, which assembles a batch of data on each call:

def get_batch(batch_size,s="train"):
    n_examples= 40
    """Create batch of n pairs, half same class, half different class"""
    if s == 'train':
        X = Xtrain
        X= X.reshape(-1,100,100,3)
        #X= X.reshape(-1,20,105,105)
        categories = train_classes
    else:
        X = Xval
        X= X.reshape(-1,100,100,3)
        categories = val_classes
    #n_classes, n_examples, w, h, chan = X.shape
    tot_examples, w, h, chan = X.shape
    
    n_classes = 51#tot_examples / len(full_path) *100
    
    # randomly sample several classes to use in the batch
    #categories = rng.choice(n_classes,size=(batch_size,),replace=False)
    categories = rng.choice(int(n_classes),size=(batch_size,),replace=True)
    
    # initialize 2 empty arrays for the input image batch
    #pairs=[np.zeros((batch_size, h, w,1)) for i in range(2)]
    pairs=[np.zeros((batch_size, h, w, chan)) for i in range(2)]
    
    # initialize vector for the targets
    targets=np.zeros((batch_size,))
    
    # make one half of it '1's, so 2nd half of batch has same class
    targets[batch_size//2:] = 1
    for i in range(batch_size):
        category = categories[i]
        print(category)
        idx_1 = rng.randint(0, n_examples)#
        pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)
        idx_2 = rng.randint(0, n_examples)
        
        # pick images of same class for 1st half, different for 2nd
        if i >= batch_size // 2:
            category_2 = category  
        else: 
            # add a random number to the category modulo n classes to ensure 2nd image has a different category
            category_2 = (category + rng.randint(1,n_classes)) % n_classes
        
        pairs[1][i,:,:,:] = X[category_2,idx_2].reshape(w, h,1)
    
    return pairs, targets

The traceback I get is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-129-a37505d52b27> in <module>
----> 1 (inputs,targets) = get_batch(batch_size)

<ipython-input-128-e55ce72910a7> in get_batch(batch_size, s)
     33         print(category)
     34         idx_1 = rng.randint(0, n_examples)#
---> 35         pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)
     36         idx_2 = rng.randint(0, n_examples)
     37 

ValueError: cannot reshape array of size 300 into shape (100,100,3)

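I understand the error message, but I cannot see where the problem comes from, since the number of examples matches the 40 I specified. Can someone help me find what is wrong with my code? I left some comments in it that may make it easier to follow. Thanks in advance.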

Tags: python, tensorflow, keras

Solution


As far as I can trace the shape of X, you previously reshaped it to (-1, 100, 100, 3).

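On the line that raises the error, pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan), X is indexed along its first and second dimensions. After the reshape, the first dimension counts images and the second counts pixel rows, so X[category, idx_1] has shape (100, 3) and a size of 300. That is why it cannot be reshaped to (100, 100, 3).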
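The shape mismatch can be reproduced on a small stand-in array. The sketch below is only an illustration: it assumes every class holds exactly n_examples images and that the images are stored class by class, which the question does not confirm. The array names X_flat and X_by_class and the sizes used are hypothetical.

```python
import numpy as np

# Hypothetical stand-in for the data: 5 classes x 40 images of 100x100x3,
# stored class by class (an assumption about the layout, not taken from the question).
n_classes, n_examples, h, w, chan = 5, 40, 100, 100, 3
X_flat = np.zeros((n_classes * n_examples, h, w, chan), dtype=np.uint8)
# Give every image a distinct constant value so that images can be told apart.
X_flat[:] = np.arange(n_classes * n_examples, dtype=np.uint8).reshape(-1, 1, 1, 1)

# After reshape(-1, 100, 100, 3) the first axis counts images and the second counts rows.
# Two integer indices therefore select one pixel row of one image, not one image of one class.
row = X_flat[2, 7]
print(row.shape, row.size)

# Option 1: restore explicit class and example axes, then index with (class, example).
X_by_class = X_flat.reshape(n_classes, n_examples, h, w, chan)
image = X_by_class[2, 7]
print(image.shape)

# Option 2: keep the flat layout and compute the image index by hand.
same_image = X_flat[2 * n_examples + 7]
```

With either option the selected image already has shape (100, 100, 3), so no further reshape is needed when filling pairs. If the classes do not all contain the same number of images, the 5-D reshape will fail and a per-class list of indices would be needed instead. Note also that the second assignment in get_batch reshapes to (w, h, 1), which would fail for 3-channel images in the same way.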

