tensorflow - tf.nn.rnn_cell.GRUCell 建立在 CPU 设备上
问题描述
我现在正在训练一个 2 层 seq2seq 模型并使用 gru_cell。
def create_rnn_cell():
encoDecoCell = tf.contrib.rnn.GRUCell(emb_dim)
encoDecoCell = tf.contrib.rnn.DropoutWrapper(
encoDecoCell,
input_keep_prob=1.0,
output_keep_prob=0.7
)
return encoDecoCell
encoder_mutil = tf.contrib.rnn.MultiRNNCell(
[create_rnn_cell() for _ in range(num_layers)],
)
query_encoder_emb = tf.contrib.rnn.EmbeddingWrapper(
encoder_mutil,
embedding_classes=vocab_size,
embedding_size=word_embedding
)
Timeline 对象用于获取图中每个节点的执行时间,我发现 GRU_cell(包括 MatMul)内的大多数操作都发生在 CPU 设备上,这使得它非常慢。我安装了tf-1.8的gpu版本。对此有何评论?我在这里错过了什么吗?我猜 tf.variable_scope 有问题,因为我对训练数据使用了不同的存储桶。这就是我在不同bucktes之间重用变量的方式:
for i, bucket in enumerate(buckets):
with tf.variable_scope(name_or_scope="RNN_encoder", reuse=True if i > 0 else None) as var_scope:
query_output, query_state = tf.contrib.rnn.static_rnn(query_encoder_emb,inputs=self.query[:bucket[0]],dtype=tf.float32)
解决方案
我发现了问题。在 EmbeddingWrapper 的源代码中,使用了 CPU。 tf.contrib.rnn.EmbeddingWrapper 我重写了这个函数,现在它可以在 GPU 上运行并且速度更快。所以如果你想使用 tf.contrib.rnn.EmbeddingWrapper 要小心。
推荐阅读
- javascript - 如果 div 包含某些内容,则更改父 div 的 css
- python - 运行 VSCode nodebug 时,当前目录未添加到 sys.path
- forms - Symfony 形式的值对象约束
- mongodb - MongoTemplate 和/或 MongoRepository 是否支持 Mongo 4 事务?
- python - 如何使用 pandas 读取 .txt 数据列并提供给 TF
- spring - 使用 JAXB 编组 XML 元素内的属性
- javascript - 循环结果数组的最后一个索引
- java - EnumSet.spliterator 没有特征 Spliterator.NONNULL
- html - Div 图像背景在打印时被子表隐藏
- c++ - 使用结构作为缓冲架