tensorflow - Kaggle TPU:无法连接到所有地址
问题描述
在尝试在 kaggle 上使用 TPU 拟合我的模型时,我遇到了一些问题。
Tpu 已经初始化:
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
print(f'Running on TPU {tpu.master()}')
except ValueError:
tpu = None
if tpu:
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
strategy = tf.distribute.get_strategy()
AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')
但是当我尝试拟合我的模型时,会出现此错误:
{{function_node __inference_train_function_64094}} failed to connect to all addresses
GRPC error information:{"created":"@1609444822.190891136","description":"Failed to pick
subchannel","file":"third_party/grpc/src/core/ext/filters/client_channel/client_channel.cc",
file_line":3959,"referenced_errors": [{"created":"@1609444822.190889693"
,"description":"failed to connect to all addresses", […]
[[{{node MultiDeviceIteratorGetNextFromShard}}]] [[RemoteCall][[IteratorGetNextAsOptional]]
解决方案
您必须在策略范围内创建模型和优化器:
with strategy.scope():
model = create_model()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['sparse_categorical_accuracy'])
推荐阅读
- r - ggplot2:无效输入:使用 scale_x_datetime 时的 time_trans
- reactjs - ReactJS - 显示列名
- mule-studio - 在dataweave中转换日期时间格式
- rust - Cargo 的 build 和 rustc 命令有什么区别?
- r - 即使使用 DPLYR 包中的 SELECT 列存在,也无法对列进行子集化
- mysql - 加入和分组 concat mysql 查询未按预期工作
- graphql - 如何转换从@apollo/react-hooks 接收到的数据
- java - 超过时间限制,我正在尝试使用链表使用 java 堆栈删除字符串中的重复项
- node.js - 断言对象的每个键都是一个数组
- ruby-on-rails - 确定日期/时间字符串中是否存在时间值