python - 使用 Tensorflow-Hub 中的 ELMo 时显着增加内存消耗
问题描述
我目前正在尝试比较数百万个文档的相似性。对于 CPU 上的第一次测试,我将它们减少到每个大约 50 个字符,并尝试一次获得 10 个字符的 ELMo 嵌入,如下所示:
ELMO = "https://tfhub.dev/google/elmo/2"
for row in file:
split = row.split(";", 1)
if len(split) > 1:
text = split[1].replace("\n", "")
texts.append(text[:50])
if i == 300:
break
if i % 10 == 0:
elmo = hub.Module(ELMO, trainable=False)
executable = elmo(
texts,
signature="default",
as_dict=True)["elmo"]
vectors = execute(executable)
texts = []
i += 1
然而,即使是这个小例子,在大约 300 个句子(甚至不保存向量)之后,程序也会消耗高达 12GB 的 RAM。这是一个已知问题(我发现的其他问题提出了类似的问题,但不是那么极端)还是我犯了错误?
解决方案
我想这是针对没有 Eager 模式的 TensorFlow 1.x(否则使用 hub.Module 可能会遇到更大的问题)。
在该编程模型中,您需要首先在 TensorFlow 图中表达您的计算,然后为每批数据重复执行该图。
构建模块
hub.Module()
并将其应用于将输入张量映射到输出张量都是图构建的一部分,并且应该只发生一次。输入数据的循环应该只调用 session.run() 来提供输入并从固定图中获取输出数据。
幸运的是,已经有一个实用函数可以为您完成所有这些工作:
import numpy as np
import tensorflow_hub as hub
# For demo use only. Extend to your actual I/O needs as you see fit.
inputs = (x for x in ["hello world", "quick brown fox"])
with hub.eval_function_for_module("https://tfhub.dev/google/elmo/2") as f:
for pystr in inputs:
batch_in = np.array([pystr])
batch_out = f(batch_in)
print(pystr, "--->", batch_out[0])
就原始 TensorFlow 而言,这对您的作用大致如下:
module = Module(ELMO_OR_WHATEVER)
tensor_in = tf.placeholder(tf.string, shape=[None]) # As befits `module`.
tensor_out = module(tensor_in)
# This kind of session handles init ops for you.
with tf.train.SingularMonitoredSession() as sess:
for pystr in inputs:
batch_in = np.array([pystr])
batch_out = sess.run(tensor_out, feed_dict={tensor_in: batch_in}
print(pystr, "--->", batch_out[0])
如果您的需求过于复杂with hub.eval_function_for_module ...
,您可以构建这个更明确的示例。
注意 hub.Module 是如何在循环中既没有构造也没有调用的。
PS:厌倦了担心构建图表和运行会话?那么 TF2 和 Eager Execution 适合您。查看https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_text_classification.ipynb
推荐阅读
- gcc - 找不到 -ll collect2:错误:ld 返回 1 退出状态
- reactjs - UseEffect 或在功能组件中抛出逻辑
- pycharm - 无法在 Windows 10 PC 中打开我的 PyCharm 应用程序?
- c++ - 单个文件中的多个命名空间相互引用 C++
- python-3.x - Modin df iterrows 非常缓慢。有什么办法可以加快速度吗?
- php - 如何将表格的内容发送到我的电子邮件?使用 HTML 和 PHP
- elasticsearch - 知道弹性搜索中的开始日期字段时如何计算持续时间字段?
- postgresql - 我将如何按大多数加权匹配排序?
- php - 我想删除 json 格式
- java - 使用 BigInteger 时如何在范围内选择一个随机数