IndexError: tuple index out of range in my Keras model.fit_generator

Problem description

I am completely new to Keras/TensorFlow. Below is my fit_generator call:

train_dataset=train_fn_inputs(batch_size, None) 
val_data=validation_fn_inputs(batch_size, None) 
total_records = 44712  
val_records = 11178  
steps_per_epoch=int(total_records // batch_size)

hist=model.fit_generator(#aug.flow(X_def, y_def, batch_size=batch_size), 
               #get_batches(X_def, y_def, batch_size), 
               train_dataset,
               steps_per_epoch=steps_per_epoch, #(training_df.shape[0])//batchsize,
               epochs=5,
               verbose = 1,
               #callbacks=[early_stopping],
               #validation_data=val_data, 
               #validation_steps=val_records//batch_size,
               workers=0
          )

The function is defined as:

def train_fn_inputs(bs,  aug=None):
    train_files, total_records = get_training_data_old()
    steps_per_epoch = int(total_records / batch_size)    
    raw_dataset = tf.data.TFRecordDataset(train_files)     #.repeat()
    parsed_image_dataset = raw_dataset.map(_parse_image_function).repeat().shuffle(buffer_size=buf_size).batch(batch_size).make_initializable_iterator()  

    image, label = parsed_image_dataset.get_next()
    image = tf.reshape(image, [3, IMG_WIDTH, IMG_HEIGHT, bs])
    #label = tf.reshape(label, [bs, 75, 25])

    while True:
       yield (np.array(image), np.array(label))

But I get this error:

File "...\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 184, in model_iteration
    batch_size = int(nest.flatten(batch_data)[0].shape[0])

IndexError: tuple index out of range

Tags: tensorflow, keras

Solution
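
A likely cause, assuming this runs under TensorFlow 1.x in graph mode: get_next() returns symbolic tensors rather than data, and wrapping one in np.array(...) does not evaluate it. On the TF 1.x releases assumed here, NumPy instead builds a zero-dimensional object array whose shape is the empty tuple (), so the shape[0] lookup shown in the traceback fails with this IndexError. The sketch below reproduces the mechanism with NumPy alone; SymbolicStandIn is a hypothetical placeholder class, not a TensorFlow type.

```python
import numpy as np


class SymbolicStandIn:
    """Hypothetical stand-in for an object NumPy cannot turn into numbers."""


# NumPy wraps the unknown object in a 0-d array of dtype=object.
wrapped = np.array(SymbolicStandIn())
print(wrapped.shape)  # ()

try:
    # Same expression shape as the failing line in training_generator.py.
    batch_size = int(wrapped.shape[0])
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: tuple index out of range
```

If this is what is happening, the generator has to yield evaluated arrays rather than graph tensors.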

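One way to avoid the error is to have the generator yield real NumPy arrays with the batch dimension first, since the traceback shows Keras reading the batch size from shape[0] of the first array. Note that the reshape in train_fn_inputs puts bs last, so even with real data shape[0] would be 3 rather than the batch size. The sketch below assumes the data fits in memory as arrays; batch_generator and the toy shapes are hypothetical and stand in for the TFRecord pipeline. Under TF 1.x the tensors from get_next() would instead need to be evaluated (for example with a session's run) before being yielded, or the tf.data.Dataset could be passed to model.fit directly if the installed version supports that.

```python
import numpy as np


def batch_generator(images, labels, batch_size, seed=0):
    """Yield (images, labels) NumPy batches forever, batch axis first."""
    rng = np.random.RandomState(seed)
    n = len(images)
    while True:
        order = rng.permutation(n)  # reshuffle once per pass over the data
        # Drop the last partial batch so every batch has the same size.
        for start in range(0, n - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            yield images[idx], labels[idx]


# Toy data standing in for decoded TFRecords (shapes are assumptions).
images = np.zeros((10, 32, 32, 3), dtype=np.float32)
labels = np.arange(10)

gen = batch_generator(images, labels, batch_size=4)
x, y = next(gen)
print(x.shape, y.shape)  # (4, 32, 32, 3) (4,)
```

A generator like this can be passed to model.fit_generator in place of train_dataset, keeping steps_per_epoch as the number of full batches per pass.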
