python - 使用 MirroredStrategy 时出现 AssertionError:isinstance(x, dataset_ops.DatasetV2)
问题描述
我正在尝试使用MirroredStrategy来使用两个 Titan Xp GPU 来拟合我的顺序模型。我tensorflow 2.0
在 ubuntu 16.04 上使用 alpha。
我成功运行了 tensorflow 文档中的代码片段:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mse', optimizer='sgd')
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
model.evaluate(dataset)
但是,当我尝试对我的数据进行训练时,这是一个稀疏的形状矩阵(使用亚当优化器和二元交叉熵):
Shape X_train: (91422, 65545)
Shape y_train: (91422, 1)
我在 _distribution_standardize_user_data 中收到断言错误
assert isinstance(x, dataset_ops.DatasetV2)
在代码中, training.pyTensorFlow
中的第 2166 行似乎导致了这个断言错误。
有人可以向我解释我的数据可能存在什么问题吗?
解决方案
使用dataset= strategy.experimental_distribute_dataset(train_dataset)
with时出现类似错误model.fit(dataset)
。
我删除strategy.experimental_distribute_dataset
. 它工作正常。它类似于TF 文档,他们说keras.Model.fit()
自动处理所有事情,只有当我们想要使用tf.GradientTape()
.
您可以通过MNIST 的官方教程了解更多信息
推荐阅读
- ios - 在 iOS 浏览器中流式传输 AWS S3 HLS 视频
- python - 如何在 PYTHON 中的列表字典中迭代值
- flutter - 颤振动画没有持续时间
- javascript - Js:将箭头函数名称分配给变量
- ios - 共享扩展 - 自 iOS 14 以来首次尝试未显示在共享菜单中的应用
- r - 检查文件是否有 x 列,参数长度为零
- postgresql - Postgres 服务器是否只在一台机器上存储数据?
- c++ - FFmpeg - Libavcodec - 无法编译,未找到文件错误,但文件存在
- python - Heroku 部署因缺少 SQL 表而失败
- javascript - Rails:为 link_to 显示红色的工具提示