python - TF 数据集是否有等效的 tf.gather() ?
问题描述
我目前正在尝试使用从 h5py 文件加载的预计算词嵌入。嵌入是为数据集中的每个示例预先计算的,因此我试图通过它们的示例/序列 ID 检索嵌入。但是,嵌入非常大,所以我遇到了一个问题,我不能直接在嵌入上运行 tf.gather() 来获取我想要的嵌入,因为 TF 不会t 生成大于 2GB 的张量。结果,我尝试使用以下代码:
# precompute_ds is just the tensor of word embeddings
precompute_ds = h5py.File(kwargs['precompute_path'], 'r')['precomputed']
precompute_place = tf.placeholder(precompute_ds.dtype,
shape=precompute_ds.shape)
word_emb = tf.gather(precompute_place, sequence_ids)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(word_emb, feed_dict={precompute_place: precompute_ds})
return word_emb
但是,由于precompute_ds
是 h5py 数据集,我不确定如何为其初始化迭代器并得到以下错误:
FailedPreconditionError (see above for traceback): GetNext() failed
because the iterator has not been initialized. Ensure that you have run
the initializer operation for this iterator before getting the next element.
[[{{node IteratorGetNext}} = IteratorGetNext[output_shapes=
[[?], [?], [?]], output_types=[DT_INT64, DT_INT64, DT_INT64],
_device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorV2)]]
所以,我也尝试使用以下代码,遵循TF 网站上的这个示例:
precompute_ds = h5py.File(kwargs['precompute_path'], 'r')['precomputed']
precompute_place = tf.placeholder(precompute_ds.dtype,
shape=precompute_ds.shape)
ds = tf.data.Dataset.from_tensor_slices(precompute_place)
word_emb = tf.gather(ds, sequence_ids)
it = ds.make_initializable_iterator()
with tf.Session() as sess:
sess.run(it.initializer, feed_dict={precompute_place: precompute_ds})
return word_emb
但是,这有两个问题:一方面,我很确定即使tf.gather
确实在 TF 数据集上工作,这也不会正确填充word_emb
. 我现在在想的是我可以ds
使用第二种方法正确填充,但是我不知道如何准确地获得sequence_ids
我想要的这个特定批次。对于这两种方法中的任何一种,是否有任何建议可以使其正常工作?
谢谢!
解决方案
推荐阅读
- php - 如何通过使用 PHP GD 合并两个图像来居中和对齐文本
- vhdl - VHDL 状态机问题 - 重复状态
- go - Save generic struct to redis
- c# - How to filter aspnet core logs?
- php - 在开发工具中获取影响 CSS 样式的随机 html 元素
- angularjs - Passing parameters from input bar
- python - Incorrect post request for nasdaq.com
- python - Issues importing pytorch with conda
- php - Display file info?
- java - Spring Boot + Spring Data + Spring MVC 编码