python - 在 GPU 上加速 TF/Keras LSTM 文本生成?
问题描述
用于文本生成的 tensorflow 官方示例 ( https://github.com/tensorflow/docs/blob/master/site/en/tutorials/text/text_generation.ipynb ) 在如下定义的循环中运行。文本生成感觉很慢,根据 NVTOP 的说法,它只使用了可用 GPU 资源的一小部分 (15-20%)。
关于如何加快文本生成的任何建议?快速浏览一下 cprofiler 会发现 90% 的时间都花在了单行上predictions = model(input_eval)
,所以我认为其他地方不会有很多收获。
此外,Tensorflow/Keras 文档https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict建议调用该函数,如下所示...
此方法专为大规模输入的性能而设计。对于一批适合的少量输入,建议直接使用call 以加快执行速度,例如 model(x) 或 model(x, training=False)
关于如何加快文本生成的任何建议?是否可以通过同时生成多条线来更好地使用 GPU?
def generate_text(model, start_string):
# Evaluation step (generating text using the learned model)
# Number of characters to generate
num_generate = 1000
# Converting our start string to numbers (vectorizing)
input_eval = [char2idx[s] for s in start_string]
input_eval = tf.expand_dims(input_eval, 0)
# Empty string to store our results
text_generated = []
# Low temperatures results in more predictable text.
# Higher temperatures results in more surprising text.
# Experiment to find the best setting.
temperature = 1.0
# Here batch size == 1
model.reset_states()
for i in range(num_generate):
predictions = model(input_eval)
# remove the batch dimension
predictions = tf.squeeze(predictions, 0)
# using a categorical distribution to predict the character returned by the model
predictions = predictions / temperature
predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
# We pass the predicted character as the next input to the model
# along with the previous hidden state
input_eval = tf.expand_dims([predicted_id], 0)
text_generated.append(idx2char[predicted_id])
return (start_string + ''.join(text_generated))
解决方案
为了加快处理速度,我有两个建议,
由于您有 GPU 支持,您可能需要设置
unroll=True
图层GRU
。根据 KerasGRU
文档,设置unroll=True
通过使用一些额外的内存来减少一些计算。由于您的 GPU 消耗非常少,您可能需要使用unroll=True
. 使用此设置,您可能会注意到2x
速度提升(取决于具体情况)。但是,如果输入序列太长,您应该避免使用展开。我注意到您链接
GRU
的文本生成架构在层之前使用层Dense
。给GRU
定一个参数return_sequences=True
。这会导致该GRU
层将不必要的输出值传递给后续Dense
层,并且需要更多的计算。一般return_sequences=True
只有在模型的下一层也是RNN层时才应该设置。因此,请尝试设置参数return_sequences=False
。这也可以提高性能。
最后,model(x, training=False)
确实有效。我相信通过维护这三个问题,您可能会注意到性能的显着提升。
推荐阅读
- java - 未找到模块 org.graalvm.sdk,com.oracle.truffle.regex 需要
- python - 为 DataFrame 中的每个组有效地枚举 bin 中的行
- css - 没有 CSS 框架的输入/标签位置操作,只有 CSS3
- flutter - 在 Text 小部件中使用来自 initState() 的值
- jooq - 使用 ParserCLI 的问题
- amazon-web-services - Jhipster aws 子生成器在全新安装时抛出 Invalid DB engine
- java - 用户值不会存储在数组中
- survey - 我正在尝试列出每个年龄段最常见的职业名称(“职业”)
- json - 使用 curl 命令更新 url(必填字段为空)字段时出错
- javascript - 通过 querySelectorAll 分隔具有相同类名的函数