首页 > 解决方案 > tf.keras 手动设备放置

问题描述

迁移到 TF2.0 我正在尝试使用该tf.keras方法来解决问题。在标准 TF 中,我可以with tf.device(...)用来控制操作的位置。

例如,我可能有一个类似的模型


model = tf.keras.Sequential([tf.keras.layers.Input(..),
                             tf.keras.layers.Embedding(...),
                             tf.keras.layers.LSTM(...),
                             ...])

假设我想让网络直到Embedding(包括)在 CPU 上以及从那里开始在 GPU 上,我将如何去做?(这只是一个例子,这些层可能与嵌入无关)

如果解决方案涉及子类化tf.keras.Model也可以,我不介意不使用Sequential

标签: pythontensorflowkerastensorflow2.0

解决方案


您可以使用 Keras 功能 API:

inputs = tf.keras.layers.Input(..)
with tf.device("/GPU:0"):
    model = tf.keras.layers.Embedding(...)(inputs)
outputs = tf.keras.layers.LSTM(...)(model)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

推荐阅读