首页 > 解决方案 > 如何按文本长度过滤Tensorflow TextLineDataset

问题描述

我想过滤为 3 <= length_of_text <=15 但我不能这样做。

import tensorflow as tf

dataset = tf.data.TextLineDataset("data.txt")


def drop_outliers(line):

    return (3<= tf.size(line) <=15).numpy()

dataset = dataset.filter(lambda line: tf.py_function(func = drop_outliers,
                                                inp=[line],
                                                Tout = tf.bool))

iterator = iter(dataset)
print(iterator.get_next())

运行此代码时出现“序列结束”错误。

标签: pythontensorflowpipeline

解决方案


官方文档Iterator.get_next(),您会看到OutOfRangeError到达序列末尾的时间,

Raises tf.errors.OutOfRangeError:如果已经到达迭代器的末尾。

所以,错误不是因为TextLineDatasetor dataset.filter()。你可以使用dataset.as_numpy_iterator()喜欢,

out = list( dataset.as_numpy_iterator() )

或者用块包围dataset.get_next()方法,try except

for i in range( seq_length ):
   try:
       element = iterator.get_next()
   except tf.errors.OutOfRangeError:
       print( "End of sequence reached" )
       break

推荐阅读