python - 线程 Thread-2 中的异常:Python 机器学习错误:Tensorflow 列表超出范围
问题描述
我正在尝试为我的数据制作分类器。但作为机器学习和 Python 方面的新手,我不断收到一个我无法弄清楚的奇怪错误。我的代码是
代码
from sklearn.preprocessing import OneHotEncoder
import tensorflow as tf
import numpy as np
import scipy.io as cio
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpg
from random import shuffle
import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
import cv2
a = cio.loadmat("D:/compCarsThesisData/data/misc/make_model_name.mat")
images = "D:/compCarsThesisData/data/image/"
IMG_SIZE = 64
MODEL_NAME = 'Classification'
LR = 1e-3
b = a['make_names']
d = []
for i in range(b.size):
d.append(b[i][0][0])
print(d)
labels_dic = {v: k for v, k in enumerate(d)}
print(labels_dic)
indices = np.arange(163)
depth = 163
y = tf.one_hot(indices,depth)
sess = tf.Session()
result = sess.run(y)#,feed_dict=None,options=None, run_metadata=labels_dic)
print(result)
labels = []
labels.append((result,labels_dic))
print(labels)
for root, _, files in os.walk(images):
cdp = os.path.abspath(root)
for f in files:
name,ext = os.path.splitext(f)
if ext == ".jpg":
cip = os.path.join(cdp,f)
ci = mpg.imread(cip)
image = cv2.cv2.resize(ci,(IMG_SIZE,IMG_SIZE))
image = np.array(image)
print(image)
training_data = []
training_data.append((image,labels))
shuffle(training_data)
np.save('training_data_make_model', training_data)
testing_data = []
testing_data.append((image,labels))
print("TestingDATA",testing_data)
shuffle(testing_data)
# if the data already created
# training_data = np.load('training_data_make_model.npy')
# testing_data = np.load('training_data_make_model.npy')
train = training_data[:-50000]
test = testing_data[-50000:]
X_train = np.array([i[0] for i in train]).reshape(-1, IMG_SIZE, IMG_SIZE, 3)
y_train = [i[1] for i in train]
X_test = np.array([i[0] for i in test]).reshape(-1, IMG_SIZE, IMG_SIZE, 3)
y_test = [i[1] for i in test]
print("YTEST",y_test)
tf.reset_default_graph()
convnet = input_data(shape=[None,IMG_SIZE,IMG_SIZE,3],name='input')
convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 128, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)
convnet = fully_connected(convnet, 2, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(convnet, tensorboard_dir='log', tensorboard_verbose=0)
model.fit({'input': X_train}, {'targets': y_train}, n_epoch=10,
validation_set=0.1,
snapshot_step=500, show_metric=True, run_id=MODEL_NAME)
.mat 文件包含单元格数组,因为现在我使用 ['make_names'] 将汽车名称作为标签并将它们转换为热编码并将它们与训练数据一起附加,如上所示。
.mat 文件
带有标签和 .mat 文件的附加数据如下所示
带有一个热编码的标签
[[1. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
...
[0. 0. 0. ... 1. 0. 0.]
[0. 0. 0. ... 0. 1. 0.]
[0. 0. 0. ... 0. 0. 1.]]
[(array([[1., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.]], dtype=float32), {0: 'ABT', 1: 'BAC', 2: 'Conquest', 3: 'DS', 4: 'Dacia', 5: 'Fisker', 6: 'GMC', 7: 'Gumpert', 8: 'Hennessey', 9: 'Icona', 10: 'Jeep', 11: 'KTM', 12: 'MELKUS', 13: 'MG', 14: 'MINI', 15: 'Mazzanti', 16: 'Noble', 17: 'PGO', 18: 'SPIRRA', 19: 'SSC', 20: 'Scion', 21: 'TESLA', 22: 'TVR', 23: 'Tramontana', 24: 'Zenvo', 25: 'smart', 26: 'Yiqi', 27: 'Mitsubishi', 28: 'Shangqidatong', 29: 'Spyker N.V.', 30: 'Dongnan', 31: 'Dongfeng', 32: 'Dongfengxiaokang', 33: 'Dongfengfengdu', 34: 'Dongfengfengshen', 35: 'Dongfengfengxing', 36: 'Zxauto', 37: 'Zhonghua', 38: 'Toyota', 39: 'Zinoro', 40: 'Jiulong', 41: 'Isuzu', 42: 'Wuling', 43: 'AC Chnitzer', 44: 'Zoyte', 45: 'Iveco', 46: 'Bufori', 47: 'Porsche', 48: 'Mitsuoka', 49: 'Chrysler', 50: 'Lamorghini ', 51: 'Kombat', 52: 'Cadillac', 53: 'Buck', 54: 'Lifan', 55: 'Lorinser', 56: 'Rolls-Royce', 57: 'BAW', 58: 'Baihc', 59: 'Beiqiweiwang', 60: 'Beiqihuansu', 61: 'Beiqi New Energy', 62: 'Huapu', 63: 'Huatai', 64: 'Huaqi', 65: 'Carlsson', 66: 'Shuanghuan', 67: 'Shuanglong', 68: 'Geely', 69: 'Venucia', 70: 'Haval', 71: 'Hafei', 72: 'Volkswagen', 73: 'Daihatsu', 74: 'Chrey', 75: 'Besturn', 76: 'Benz', 77: 'Audi', 78: 'Wisemann', 79: 'Wealeak', 80: 'BWM', 81: 'Baojun', 82: 'Bentley', 83: 'Brabus', 84: 'Bugatti', 85: 'Pagani', 86: 'Guangqichuanqi', 87: 'GAC', 88: 'Karry', 89: 'Ciimo', 90: 'CHTC', 91: 'Jaguar', 92: 'Morgan', 93: 'Subaru', 94: 'Skoda', 95: 'Xinkai', 96: 'Nissan', 97: 'Changhe', 98: 'RANZ', 99: 'Honda', 100: 'Lincoln', 101: 'Peugeot', 102: 'Opel', 103: 'Oley', 104: 'BYD', 105: 'Jonway', 106: 'Huizhong', 107: 'Jianghuai', 108: 'Jiangling', 109: 'Vauxhall', 110: 'Volvo', 111: 'Ferrari', 112: 'Haige', 113: 'Haima', 114: 'Haima(Zhengzhou)', 115: 'Cheetah', 116: 'Maserati', 117: 'Hyundai ', 118: 'Everus', 119:
'Ruiqi', 120: 'Fuqiqiteng', 121: 'Ford', 122: 'Futian', 123: 'Fudi', 124: 'Koenigsegg', 125: 'HongQi', 126: 'Luxgen', 127: 'SAAB', 128: 'Denza', 129: 'Yingzhi', 130: 'Infiniti', 131: 'Roewe', 132: 'Lotus', 133: 'FIAT', 134: 'Saab', 135: 'Lancia', 136: 'Seat', 137: 'Qoros', 138: 'Acura', 139: 'KIA', 140: 'Lotus', 141: 'LAND-ROVER', 142: 'McLaren', 143: 'Maybach', 144: 'Dodge', 145: 'Mustang', 146: 'Jinlv', 147: 'Jinbei', 148: 'Suzuki', 149: 'GreatWall', 150: 'Changan Business', 151: 'Changan', 152: 'Alfa Romeo', 153: 'Aston Martin', 154: 'Lufeng', 155: 'Shanqitongjia', 156: 'Chevy', 157: 'Citroen', 158: 'Lexus', 159: 'Renault', 160: 'Shouwang', 161: 'MAZDA', 162: 'Huanghai'})]
但是每次我尝试运行它时,它都会给我一个错误。
错误
Run id: Classification
Log directory: log/
Exception in thread Thread-2:
Traceback (most recent call last):
File "C:\Users\zeele\Miniconda3\lib\threading.py", line 916, in _bootstrap_inner
self.run()
File "C:\Users\zeele\Miniconda3\lib\threading.py", line 864, in run
self._target(*self._args, **self._kwargs) 1, in fill_batch_ids_queue
File "C:\Users\zeele\Miniconda3\lib\site-packages\tflearn\data_flow.py", line 201, in fill_batch_ids_queue 5, in next_batch_ids
ids = self.next_batch_ids()
File "C:\Users\zeele\Miniconda3\lib\site-packages\tflearn\data_flow.py", line 215, in next_batch_ids
batch_start, batch_end = self.batches[self.batch_index]
IndexError: list index out of range
解决方案
推荐阅读
- javascript - 使用 Javascript 中的新查询解析服务器更新实时查询订阅
- c++ - 不同/较旧的处理器是否以不同的方式运行 c++ 代码?
- linux - 为什么Linux内核的arm端口在上下文切换时没有备份“cpsr”寄存器?
- javascript - Npm ES 模块包不适用于 Express / Babel
- visual-studio-code - 如何在 Visual Studio Code 的母版页中列出和编辑链接的 js/css 文件?
- java - 使用 JSOUP 节点的属性获取元素内容
- wordpress - 如果流量来源来自 Google 搜索,如何在 Wordpress URL 末尾添加#anchor
- mysql - 尝试使用参数将 datagridview 行集合插入数据库时出错
- algorithm - 一半的假设是如何有 i < j 的?
- laravel - Laravel 单元测试断言函数总是失败