首页 > 解决方案 > TF 数据集是否有等效的 tf.gather() ?

问题描述

我目前正在尝试使用从 h5py 文件加载的预计算词嵌入。嵌入是为数据集中的每个示例预先计算的,因此我试图通过它们的示例/序列 ID 检索嵌入。但是,嵌入非常大,所以我遇到了一个问题,我不能直接在嵌入上运行 tf.gather() 来获取我想要的嵌入,因为 TF 不会t 生成大于 2GB 的张量。结果,我尝试使用以下代码:

  # precompute_ds is just the tensor of word embeddings
  precompute_ds = h5py.File(kwargs['precompute_path'], 'r')['precomputed']
  precompute_place = tf.placeholder(precompute_ds.dtype, 
                                    shape=precompute_ds.shape)
  word_emb = tf.gather(precompute_place, sequence_ids)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(word_emb, feed_dict={precompute_place: precompute_ds})

return word_emb

但是,由于precompute_ds是 h5py 数据集,我不确定如何为其初始化迭代器并得到以下错误:

FailedPreconditionError (see above for traceback): GetNext() failed
because the iterator has not been initialized. Ensure that you have run
the initializer operation for this iterator before getting the next element.
     [[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=
[[?], [?], [?]], output_types=[DT_INT64, DT_INT64, DT_INT64], 
_device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorV2)]]

所以,我也尝试使用以下代码,遵循TF 网站上的这个示例:

  precompute_ds = h5py.File(kwargs['precompute_path'], 'r')['precomputed']
  precompute_place = tf.placeholder(precompute_ds.dtype, 
                                    shape=precompute_ds.shape)
  ds = tf.data.Dataset.from_tensor_slices(precompute_place)
  word_emb = tf.gather(ds, sequence_ids)
  it = ds.make_initializable_iterator()
  with tf.Session() as sess:
    sess.run(it.initializer, feed_dict={precompute_place: precompute_ds})

return word_emb

但是,这有两个问题:一方面,我很确定即使tf.gather确实在 TF 数据集上工作,这也不会正确填充word_emb. 我现在在想的是我可以ds使用第二种方法正确填充,但是我不知道如何准确地获得sequence_ids我想要的这个特定批次。对于这两种方法中的任何一种,是否有任何建议可以使其正常工作?

谢谢!

标签: pythontensorflowmachine-learningh5py

解决方案


推荐阅读