首页 > 解决方案 > Tensorflow - 加载预训练的 inception_v3 以与 tf.estimator 一起使用

问题描述

我正在使用 tf.estimator.Estimator 来训练我的模型。我现在想尝试一个预训练的 inception_v3 来代替我当前的模型。

我偶然发现了这个存储库,它似乎拥有我可能需要的所有模型:https ://github.com/tensorflow/models/tree/master/research/slim

加载图表和训练似乎工作得很好。现在它还想使用我从这里下载的预训练权重:http: //download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz

我尝试了在网上找到的多种变体,但到目前为止似乎没有任何效果。大多数时候我收到以下错误:

Tensor name "model/InceptionV3/AuxLogits/Conv2d_1b_1x1/BatchNorm/beta" not found in checkpoint files

这就是我现在构建图表的方式:

sys.path.insert(0, "./models/research/slim")
from nets import nets_factory

# ...

params.network_name = "inception_v3",
params.pre_logits   = "PreLogits",         
params.checkpoint   = "inception_v3.ckpt"

# ...

network_fn = nets_factory.get_network_fn(params.network_name,
                                         num_classes=(1001),
                                         weight_decay=0.00004,
                                         is_training=is_training)

logits, end_points = network_fn(images)        

pre_logits = end_points[params.pre_logits]
pre_logits = tf.layers.flatten(inputs=pre_logits)

# ...

loss = ...

train_op = optimizer.minimize(loss, global_step=global_step)
model_fn = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn, 
                                   params=params, 
                                   config=config)

# ...

我尝试了 tensorflow 版本 v.1.6.0 和 v.1.8.0。两者似乎都有相同的问题。

有任何想法吗?

非常感谢!

标签: pythontensorflowtf-slim

解决方案


推荐阅读