TensorFlow dataset cannot determine the number of training steps per epoch after filtering

Problem description

I create a TensorFlow dataset and then filter it to split it into a training set and a test set, as follows:

import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

test_index, train_index = split_data_for_train_and_test(text_data, label_data, test_ratio)
y = to_categorical(label_data)
dataset = tf.data.Dataset.from_tensor_slices((sequences_matrix, y.astype(np.int8)))
       
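dataset = dataset.cache()
# Pair every element with its position: elements become (i, (x, y)).
dataset = dataset.enumerate()

@tf.function
def filter_train_function(i, _):
    # Keep the element if its position i is one of the training indices.
    return tf.py_function(lambda idx: idx.numpy() in train_index, inp=[i], Tout=tf.bool)

@tf.function
def filter_test_function(i, _):
    # Keep the element if its position i is one of the test indices.
    return tf.py_function(lambda idx: idx.numpy() in test_index, inp=[i], Tout=tf.bool)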
    
train_dataset = dataset \
            .filter(filter_train_function) \
            .map(lambda i, data: data)
test_dataset = dataset \
            .filter(filter_test_function) \
            .map(lambda i, data: data)
history = deep_model.fit(train_dataset.batch(batch_size), epochs=epoch_num,
                                     verbose=1, initial_epoch=initial_epoch,
                                     validation_data=test_dataset.batch(batch_size),
                                     callbacks=callbacks)
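For reference, here is a minimal, self-contained sketch of the same enumerate-and-filter split on toy data. The arrays `features` and `labels` and the 8/2 split are made-up stand-ins for the question's `sequences_matrix`, labels and `split_data_for_train_and_test`, and `make_index_filter` is a hypothetical helper. It also swaps the `tf.py_function` membership test for a pure TensorFlow check (`tf.reduce_any` over `tf.equal`), which keeps the predicate inside the graph; this is an alternative, not the question's code.

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins for the question's sequences_matrix and labels (assumption).
features = np.arange(20, dtype=np.float32).reshape(10, 2)
labels = np.arange(10, dtype=np.int8)

# Hypothetical split: positions 0..7 for training, 8..9 for testing.
train_index = np.arange(8, dtype=np.int64)
test_index = np.arange(8, 10, dtype=np.int64)

# enumerate() turns each element into (i, (x, y)) with i as int64.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).enumerate()


def make_index_filter(indices):
    """Hypothetical helper: build a predicate keeping elements whose position is in `indices`."""
    index_tensor = tf.constant(indices, dtype=tf.int64)

    def predicate(i, _):
        # Pure-TensorFlow membership test: True if i equals any entry of index_tensor.
        return tf.reduce_any(tf.equal(i, index_tensor))

    return predicate


# Filter by position, then drop the position again.
train_dataset = dataset.filter(make_index_filter(train_index)).map(lambda i, data: data)
test_dataset = dataset.filter(make_index_filter(test_index)).map(lambda i, data: data)

train_labels = [int(y) for _, y in train_dataset.as_numpy_iterator()]
test_labels = [int(y) for _, y in test_dataset.as_numpy_iterator()]
print(train_labels)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(test_labels)   # [8, 9]
```

This change alone does not make the epoch length known: a dataset produced by `filter()` still reports an unknown cardinality, whichever predicate is used.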

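The problem is that after filtering, when the datasets are passed to the deep model's fit(), the model cannot work out the number of steps per epoch and prints 1/Unknown, 2/Unknown, and so on: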

 29/Unknown - 24s 824ms/step - loss: 2.2388 - accuracy: 0.2600
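The `Unknown` comes from the dataset's cardinality: `filter()` cannot know ahead of time how many elements its predicate will keep, so the filtered dataset reports `tf.data.UNKNOWN_CARDINALITY` and Keras has no step count to display. The sketch below, on a toy `tf.data.Dataset.range(10)` rather than the question's data, shows this and one possible workaround that is not from the original post: asserting the known split size with `tf.data.experimental.assert_cardinality`. It assumes TF 2.3 or later (for `Dataset.cardinality()` and `tf.data.UNKNOWN_CARDINALITY`).

```python
import tensorflow as tf

# Toy dataset of 10 elements (stand-in for the question's data; an assumption).
dataset = tf.data.Dataset.range(10)
print(int(dataset.cardinality()))  # 10: known statically

# filter() cannot know in advance how many elements the predicate keeps.
filtered = dataset.filter(lambda x: x < 8)
print(int(filtered.cardinality()) == tf.data.UNKNOWN_CARDINALITY)  # True

# If the split size is known, assert it so that downstream ops such as
# batch() (and Keras' fit()) can compute the number of steps per epoch.
filtered = filtered.apply(tf.data.experimental.assert_cardinality(8))
batched = filtered.batch(3)
print(int(batched.cardinality()))  # 3, i.e. ceil(8 / 3) since drop_remainder=False
```

In the question's code this would correspond to applying `tf.data.experimental.assert_cardinality(len(train_index))` to `train_dataset` (and `len(test_index)` to `test_dataset`) before `.batch(batch_size)`, assuming the index lists hold unique, in-range positions. The assertion is checked while iterating: if the real number of elements differs, iteration fails with an error rather than producing a wrong step count.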

Tags: python, tensorflow, tensorflow-datasets

Solution

