python - Few Shot Learning / Siamese Network - 3-channel input images
Problem description
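I am trying to do few-shot learning on a prepared dataset that contains a handful of different classes with 40 training samples each (40-shot learning). To load my data, I use the following code: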
```python
import os

import numpy as np
from PIL import Image


def list_files(startpath):
    """Print the directory tree under startpath and load every image file.

    Returns the last (root, dirs, files) triple visited by os.walk, the list
    of full image paths, and all images stacked into one array of shape
    (n_images, height, width, channels).
    """
    full_path = []  # full path of every file found under startpath
    for root, dirs, files in os.walk(startpath):
        # Indent the printed tree according to the directory depth
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * level
        print('{}{}/'.format(indent, os.path.basename(root)))
        for f in files:
            # os.path.join builds a valid path on any OS, so no manual
            # '\\' concatenation or slicing of the string is needed
            full_path.append(os.path.join(root, f))
    # Load each image as an (H, W, C) array and stack them along a new first axis
    images = np.stack([np.array(Image.open(p)) for p in full_path])
    return root, dirs, files, full_path, images
```
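`list_files` returns the image paths but not which class each image belongs to, and that is exactly the information a pair sampler needs. A minimal sketch, assuming the dataset is laid out with one sub-folder per class (so the parent folder name is the class name); `labels_from_paths` is a hypothetical helper, not part of the original code:

```python
import os

import numpy as np


def labels_from_paths(paths):
    """Map each file path to an integer class id taken from its parent folder.

    Assumption: one sub-folder per class, e.g. data/cat/001.png -> class "cat".
    Returns (labels, class_names) where labels[i] indexes into class_names.
    """
    folder_names = [os.path.basename(os.path.dirname(p)) for p in paths]
    class_names = sorted(set(folder_names))  # stable, alphabetical class order
    index = {name: i for i, name in enumerate(class_names)}
    labels = np.array([index[name] for name in folder_names], dtype=np.int64)
    return labels, class_names


if __name__ == "__main__":
    demo = [os.path.join("data", "b", "1.png"),
            os.path.join("data", "a", "2.png"),
            os.path.join("data", "b", "3.png")]
    labels, names = labels_from_paths(demo)
    print(labels, names)  # [1 0 1] ['a', 'b']
```

Calling it as `labels, class_names = labels_from_paths(full_path)` gives one label per row of the stacked images array, in the same order.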
`list_files` works as expected: the loaded images have shape (3267, 100, 100, 3). My problem is in the next part of the code, which builds a batch of data on each call:
```python
def get_batch(batch_size,s="train"):
    n_examples= 40
    """Create batch of n pairs, half same class, half different class"""
    if s == 'train':
        X = Xtrain
        X= X.reshape(-1,100,100,3)
        #X= X.reshape(-1,20,105,105)
        categories = train_classes
    else:
        X = Xval
        X= X.reshape(-1,100,100,3)
        categories = val_classes
    #n_classes, n_examples, w, h, chan = X.shape
    tot_examples, w, h, chan = X.shape
    n_classes = 51#tot_examples / len(full_path) *100
    # randomly sample several classes to use in the batch
    #categories = rng.choice(n_classes,size=(batch_size,),replace=False)
    categories = rng.choice(int(n_classes),size=(batch_size,),replace=True)
    # initialize 2 empty arrays for the input image batch
    #pairs=[np.zeros((batch_size, h, w,1)) for i in range(2)]
    pairs=[np.zeros((batch_size, h, w, chan)) for i in range(2)]
    # initialize vector for the targets
    targets=np.zeros((batch_size,))
    # make one half of it '1's, so 2nd half of batch has same class
    targets[batch_size//2:] = 1
    for i in range(batch_size):
        category = categories[i]
        print(category)
        idx_1 = rng.randint(0, n_examples)#
        pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)
        idx_2 = rng.randint(0, n_examples)
        # pick images of same class for 1st half, different for 2nd
        if i >= batch_size // 2:
            category_2 = category
        else:
            # add a random number to the category modulo n classes to ensure 2nd image has a different category
            category_2 = (category + rng.randint(1,n_classes)) % n_classes
        pairs[1][i,:,:,:] = X[category_2,idx_2].reshape(w, h,1)
    return pairs, targets
```
The traceback I get is:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-129-a37505d52b27> in <module>
----> 1 (inputs,targets) = get_batch(batch_size)

<ipython-input-128-e55ce72910a7> in get_batch(batch_size, s)
     33         print(category)
     34         idx_1 = rng.randint(0, n_examples)#
---> 35         pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)
     36         idx_2 = rng.randint(0, n_examples)
     37 

ValueError: cannot reshape array of size 300 into shape (100,100,3)
```
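I understand the error message, but I don't see where the problem is, since the number of examples corresponds to the specified 40. Can someone help me find what is wrong with my code? I added some comments that may make the code easier to follow. Thanks in advance.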
Solution
As far as I can trace the shape of `X`, you previously reshaped it to `(-1, 100, 100, 3)`.
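A quick check with a small dummy array (10 images instead of 3267, same per-image shape) shows what that reshape does and does not do. The dummy array stands in for `Xtrain` and is purely illustrative:

```python
import numpy as np

# Dummy stand-in for Xtrain: 10 images of shape (100, 100, 3)
X = np.zeros((10, 100, 100, 3))

# Reshaping to (-1, 100, 100, 3) leaves the shape unchanged:
# there is still no separate "class" axis, axis 0 indexes single images
X = X.reshape(-1, 100, 100, 3)
print(X.shape)  # (10, 100, 100, 3)

# Two integer indices pick image 2, then row 5 of that image
picked = X[2, 5]
print(picked.shape, picked.size)  # (100, 3) 300
```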
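In the line where the error occurs, `pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)` indexes `X` along its first and second dimensions: `X[category, idx_1]` selects image number `category` and then row `idx_1` of that image. In other words, `category` is used as an image index rather than a class index, because `X` has no class axis. The shape of `X[category, idx_1]` is therefore `(100, 3)`, so its size is 300, and it cannot be reshaped to `(100, 100, 3)`.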
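One way to make `X[category, idx]` mean "image `idx` of class `category`" is to regroup the flat array into shape `(n_classes, n_examples, h, w, chan)` before batching; the second image then also needs `chan` channels, not the `reshape(w, h, 1)` used for `pairs[1]` in the question. A minimal sketch, assuming you have one integer label per image (for example derived from the class sub-folder names) and that every class has at least `n_examples` images. `group_by_class` is hypothetical, and this reworked `get_batch` takes different parameters from the original; `rng` is assumed to be a `np.random.RandomState`, which provides the `choice` and `randint` methods the original code calls:

```python
import numpy as np


def group_by_class(X, labels, n_examples):
    """Regroup a flat (N, h, w, chan) array into (n_classes, n_examples, h, w, chan).

    Keeps the first n_examples images of each class; assumes every class
    has at least that many images.
    """
    classes = np.unique(labels)
    return np.stack([X[labels == c][:n_examples] for c in classes])


def get_batch(X_grouped, batch_size, rng):
    """Create a batch of pairs: first half different classes, second half same class."""
    n_classes, n_examples, h, w, chan = X_grouped.shape
    # Sample one class per pair (with replacement, as in the original code)
    categories = rng.choice(n_classes, size=(batch_size,), replace=True)
    pairs = [np.zeros((batch_size, h, w, chan)) for _ in range(2)]
    targets = np.zeros((batch_size,))
    targets[batch_size // 2:] = 1  # second half of the batch: same-class pairs
    for i in range(batch_size):
        category = categories[i]
        idx_1 = rng.randint(0, n_examples)
        # X_grouped[category, idx_1] already has shape (h, w, chan): no reshape needed
        pairs[0][i] = X_grouped[category, idx_1]
        idx_2 = rng.randint(0, n_examples)
        if i >= batch_size // 2:
            category_2 = category
        else:
            # Shift by 1..n_classes-1 so the second image comes from another class
            category_2 = (category + rng.randint(1, n_classes)) % n_classes
        pairs[1][i] = X_grouped[category_2, idx_2]
    return pairs, targets
```

Usage would look like `X_grouped = group_by_class(Xtrain, train_labels, 40)` followed by `pairs, targets = get_batch(X_grouped, 32, np.random.RandomState(0))`, where `train_labels` is a hypothetical per-image label array. Grouping once up front keeps the batching loop free of reshapes, so a shape mismatch shows up immediately in `group_by_class` instead of inside the loop.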