首页 > 解决方案 > 如何将 dataset_from_generator 与熊猫数据一起使用(Tensorflow)

问题描述

最初我使用 dataset_from_tensor_slices,但由于数据太大(超过 2gb 限制),我不得不使用 dataset_from_generator。这是我的输入函数,但我收到此错误:features should be a dictionary of Tensors. 给定类型:“class 'tensorflow.python.framework.ops.Tensor'”。从张量切片更改为生成器之前,没有抛出错误(直到 2gb 限制)。特征和标签都是 pandas 数据帧。

def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):



features = {key:np.array(value) for key,value in dict(features).items()}                                             
print(type(features))
print(type(targets))

def gen():
    for feature, target in (features, targets):
        yield feature, target
ds = Dataset.from_generator(gen,(tf.float32,tf.float32), (tf.TensorShape([]), tf.TensorShape([None])))
ds = ds.batch(batch_size).repeat(num_epochs)


if shuffle:
  ds = ds.shuffle(10000)



features, labels = ds.make_one_shot_iterator().get_next()

return features, labels

标签: pythonpandastensorflow

解决方案


推荐阅读