python - 如何在 Keras 中使用 float16 微调 resnet50?
问题描述
我试图在半精度模式下微调 resnet50,但没有成功。似乎模型的某些部分与float16
. 这是我的代码:
dtype='float16'
K.set_floatx(dtype)
K.set_epsilon(1e-4)
model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
我得到这个错误:
Traceback (most recent call last):
File "train_resnet.py", line 40, in <module>
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
return base_fun(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
return resnet50.ResNet50(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
output = self.call(inputs, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
epsilon=self.epsilon)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
epsilon=epsilon)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
data_format=tf_data_format)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
param_name=input_name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32
解决方案
这是一个报告的错误和升级Keras==2.2.5
解决了这个问题。
推荐阅读
- spring - 从 Rest Api 调用 Spring Stomp Websocket 不起作用
- django - 三引号中的 Django 原始 sql 插入查询:Django 将 ajax 请求数据中的空值解释为 None 列
- python - 我想在 colab 中使用 imshow(),但它不起作用
- python-3.x - Concox GT800 / Wetrack800 GPS设备数据解码
- java - 如何填充可变参数构造器取消 applicationContext bean?
- json - 如何以编程方式查找 LUKS 标头的开始和结束字节?
- ios - 仍然收到警告“App Store 团队已弃用 API 使用”
- angularjs - 离线持久性在 AngularJS 的 Cloud Firestore 中不起作用(如何调试?)
- java - 为什么这个嵌套的for循环会无限循环(java)?
- matlab - MATLAB中的视频播放器,按住滑块控件