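python - How do I set up a custom test step in Keras?
Problem description
I trained my model with an input (image) size of [None, 400, 400, 3], but I want to test it with a different input size, for example [None, 512, 512, 3]. Here is my custom training implementation: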
my_model = customModel(rgb_mean=self.args.rgbn_mean)
ckpt_manager = tf.train.Checkpoint(optimizer=optimizer, model=my_model)
for epoch in range(self.args.max_epochs):
    # training
    for step, (x, y) in enumerate(train_data):
        with tf.GradientTape() as tape:
            pred = my_model(x, training=True)
            preds, last_logits, loss = pre_process_binary_cross_entropy(
                loss_bc, pred, y, self.args, use_tf_loss=False)
        if step % 100 == 0 and loss < global_loss:
            # tfk.Model.save_weights(my_model,os.path.join(checkpoint_dir,"saved_model.h5"),
            #                        save_format=ckpt_save_mode)
            # # tfk.models.save_model(my_model,os.path.join(checkpoint_dir,"1saved_model.h5"),
            # #                       save_format=ckpt_save_mode)
            # tfk.models.save_model(my_model,checkpoint_dir)
            ckpt_manager.save(checkpoint_dir)
And here is my custom test implementation:
root = tf.train.Checkpoint(optimizer=optimizer,
                           model=my_model)
ckpt_manager = tf.train.CheckpointManager(root, checkpoint_dir, max_to_keep=10)
root.restore(ckpt_manager.latest_checkpoint)
for step, x in enumerate(test_data):
    preds = my_model(x, training=False)
When I test the model with images resized to 400x400 it works fine, but when I test with other sizes such as 512x512 or 720x1280, I get this log:
Traceback (most recent call last):
File "C:/Users/xavie/Documents/Codes/GitHub/efge/main.py", line 76, in <module>
main(args=arg)
File "C:/Users/xavie/Documents/Codes/GitHub/efge/main.py", line 70, in main
model.test()
File "C:\Users\xavie\Documents\Codes\GitHub\efge\run_model.py", line 198, in test
preds = my_model(x,training=False)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 968, in __call__
outputs = self.call(cast_inputs, *args, **kwargs)
File "C:\Users\xavie\Documents\Codes\GitHub\efge\model.py", line 90, in call
output = self.batchnorm1(output, training=training)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 964, in __call__
self._maybe_build(inputs)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2416, in _maybe_build
self.build(input_shapes) # pylint:disable=not-callable
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 401, in build
experimental_autocast=False)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 577, in add_weight
caching_device=caching_device)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 724, in _add_variable_with_custom_getter
name=name, shape=shape)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 791, in _preload_simple_restoration
checkpoint_position=checkpoint_position, shape=shape)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py", line 75, in __init__
self.wrapped_value.set_shape(shape)
File "C:\Users\xavie\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1107, in set_shape
(self.shape, shape))
ValueError: Tensor's shape (200,) is not compatible with supplied shape (256,)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm1.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm2.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm3.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm4.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.axis
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.moving_mean
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.batchnorm5.moving_variance
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).model.predConv3.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).model.predConv3.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm1.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm1.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.conv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm2.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm2.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.conv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm3.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm3.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm4.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm4.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv2.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm5.gamma
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.batchnorm5.beta
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.dconv3.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv1.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv1.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv2.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv2.conv1.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv3.conv1.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).model.predConv3.conv1.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Here is the model:
class customModel(tfk.Model):
    def __init__(self, data_format='channels_first', weight_decay=1e4, rgb_mean=None):
        super(customModel, self).__init__()
        self.weight_decay = weight_decay
        self.rgbn_mean = rgb_mean
        axis = -1 if data_format == "channels_last" else 1
        # data_format=data_format,
        self.conv1 = tfk.layers.Conv2D(filters=16, kernel_size=(7, 7),
                                       padding="same", use_bias=False,
                                       kernel_initializer=weight_init,
                                       kernel_regularizer=l2(weight_decay),
                                       strides=(2, 2))  # [8,200,200,16] when the input is 400
        self.batchnorm1 = tfk.layers.BatchNormalization(axis=axis)

    def call(self, x, training=False):
        x = x - self.rgbn_mean[:-1]
        output = self.conv1(x, training=training)
        output = self.batchnorm1(output, training=training)
        output = tf.nn.relu(output)
        return output
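What am I doing wrong, and how can I fix it? Please help, I am new to Keras :( PS: I have tried the different ways Keras offers for saving the model, but I still cannot test with a different image size.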
Solution
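One plausible reading of the traceback, offered as a sketch rather than a confirmed diagnosis: the model is constructed with the default data_format='channels_first', so the BatchNormalization layer gets axis=1, while the Conv2D layer (its data_format argument is commented out) still produces channels-last output of shape [N, H, W, C]. Under that assumption, batchnorm1 normalizes over the height axis, so its variables have one entry per output row instead of one per channel. That would explain why the checkpoint holds a (200,) tensor (400 / 2) and a 512x512 input asks for (256,). The pure-Python arithmetic below illustrates this; the helper names conv_same_output_size, batchnorm_param_shape and conv1_output_shape are hypothetical and only mirror the shape rules, they are not Keras APIs.

```python
import math


def conv_same_output_size(size, stride):
    """Spatial size after a convolution with padding="same"."""
    return math.ceil(size / stride)


def batchnorm_param_shape(input_shape, axis):
    """gamma, beta and the moving statistics have one entry per index along `axis`."""
    return (input_shape[axis],)


def conv1_output_shape(height, width, filters=16, stride=2):
    """Output shape of conv1 in channels-last layout (N, H, W, C).

    Assumption: Conv2D runs with its default data_format, as in the question.
    """
    return (None,
            conv_same_output_size(height, stride),
            conv_same_output_size(width, stride),
            filters)


for side in (400, 512):
    shape = conv1_output_shape(side, side)
    print(side, shape,
          "axis=1 ->", batchnorm_param_shape(shape, 1),
          "axis=-1 ->", batchnorm_param_shape(shape, -1))
```

With axis=1 the parameter shape follows the image height, (200,) for 400 and (256,) for 512, whereas with axis=-1 it stays (16,) for every input size. If this reading is right, a likely fix is to make the two layers agree, for example by building the model with data_format="channels_last" (or passing axis=-1 to BatchNormalization) and then retraining, since the existing checkpoint stores the batch-norm variables with the height-dependent shape.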