python - 在 GPU 上使用 tensorflow 训练模型,使用 Adadelta 优化器不起作用。但是当我用 Adam 替换 Adadelta 时,它似乎没有问题。
问题描述
我正在尝试在 GPU 上使用 adadelta 优化器在 tensorflow(python2 上的 v1.9.0)上训练模型。它显示以下错误。
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'embedding_matrix_de/read': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/device:GPU:0'
Colocation Debug Info:
Colocation group had the following types and devices:
UnsortedSegmentSum: GPU CPU
Unique: GPU CPU
Shape: GPU CPU
Cast: GPU CPU
StridedSlice: GPU CPU
GatherV2: GPU CPU
SparseApplyAdadelta: CPU
Const: GPU CPU
Identity: CPU
VariableV2: GPU CPU
Colocation members and user-requested devices:
embedding_matrix_de (VariableV2)
embedding_matrix_de/read (Identity)
embedding_lookup/axis (Const)
embedding_lookup (GatherV2)
gradients/embedding_lookup_grad/Shape (Const)
gradients/embedding_lookup_grad/ToInt32 (Cast)
embedding_matrix_de/Adadelta (VariableV2)
embedding_matrix_de/Adadelta_1 (VariableV2)
Adadelta/update_embedding_matrix_de/Unique (Unique)
Adadelta/update_embedding_matrix_de/Shape (Shape)
Adadelta/update_embedding_matrix_de/strided_slice/stack (Const)
Adadelta/update_embedding_matrix_de/strided_slice/stack_1 (Const)
Adadelta/update_embedding_matrix_de/strided_slice/stack_2 (Const)
Adadelta/update_embedding_matrix_de/strided_slice (StridedSlice)
Adadelta/update_embedding_matrix_de/UnsortedSegmentSum (UnsortedSegmentSum)
Adadelta/update_embedding_matrix_de/SparseApplyAdadelta (SparseApplyAdadelta)
[[Node: embedding_matrix_de/read = Identity[T=DT_FLOAT, _class=["loc:@embedding_matrix_de"]](embedding_matrix_de)]]
当我用亚当替换 adadelta 时,没有问题。下面给出了一些代码。
....
embedding_matrix_decode = tf.get_variable(
name="embedding_matrix_de",
shape=[trainVocabSize, embedding_size],
dtype=tf.float32)
....
optimizer = tf.train.AdadeltaOptimizer()
....
解决方案
我在 Tensorflow 2.1.1 中遇到了同样的问题。Adadelta 优化器似乎不支持 GPU 和 TPU。
推荐阅读
- ubuntu - bash:/generic_send_tcp:在 Ubuntu 18.04 中没有这样的文件或目录
- android - 在 Android 服务中接收两次 onTaskRemoved() 回调
- node.js - 如何通过谷歌云 http 功能上传大于 10mb 的文件。? 任何替代选择?
- python - TensorFlow Probability:如何获得预测的准确性?
- c# - 已经定义了一个使用相同参数类型调用的成员
- javascript - 如何验证 JSON 模式中的自定义类型?
- jquery - Bootstrap 模态中的 YouTube 嵌入缩略图模糊
- java - CXF客户端如何支持多个证书
- r - 按行重塑矩阵
- angular6 - 分页记录很少时向上滚动