tensorflow - Running inference with a pre-trained TensorFlow object detection model through the Estimator interface
Problem description
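I'm trying to load a pre-trained TensorFlow object detection model from the TensorFlow Object Detection repo as a tf.estimator.Estimator and use it to make predictions.
I can load the model and run inference with Estimator.predict(), but the output is garbage. Other ways of loading the model, e.g. as a Predictor, and running inference work fine.
Any help loading the model correctly as an Estimator and calling predict() would be greatly appreciated. My current code: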
Loading and preparing the image
import numpy as np
import requests
import tensorflow as tf
from io import BytesIO
from PIL import Image

def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(list(image.getdata())).reshape((im_height, im_width, 3)).astype(np.uint8)

image_url = 'https://i.imgur.com/rRHusZq.jpg'

# Load image
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))

# Format original image size
# Note: PIL's image.size is (width, height), and it is taken before the resize below
im_size_orig = np.array(list(image.size) + [1])
im_size_orig = np.expand_dims(im_size_orig, axis=0)
im_size_orig = np.int32(im_size_orig)

# Resize image to a quarter of its size (PIL expects a tuple of ints)
image = image.resize((image.size[0] // 4, image.size[1] // 4))

# Format image
image_np = load_image_into_numpy_array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_np_expanded = np.float32(image_np_expanded)

# Stick into feature dict
x = {'image': image_np_expanded, 'true_image_shape': im_size_orig}

# Stick into input function
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=x,
    y=None,
    shuffle=False,
    batch_size=128,
    queue_capacity=1000,
    num_epochs=1,
    num_threads=1,
)
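As an aside, load_image_into_numpy_array builds a Python list of every pixel before reshaping, which is slow for large photos and fails for images that are not plain RGB. A minimal sketch of an equivalent conversion (the name load_image_into_numpy_array_fast is hypothetical) reads the pixel buffer directly:

```python
import numpy as np
from PIL import Image


def load_image_into_numpy_array_fast(image):
    """Convert a PIL image to a (height, width, 3) uint8 array.

    convert('RGB') drops an alpha channel or expands grayscale so the
    result always has three channels; np.asarray then reads the pixel
    buffer directly instead of going through a Python list.
    """
    return np.asarray(image.convert('RGB'), dtype=np.uint8)


# Small synthetic example: an image 4 px wide and 3 px tall.
rng = np.random.default_rng(0)
demo_pixels = rng.integers(0, 256, size=(3, 4, 3), dtype=np.uint8)
demo_image = Image.fromarray(demo_pixels)
print(load_image_into_numpy_array_fast(demo_image).shape)  # (3, 4, 3)
```

For an RGB image this yields the same array as the original helper, so it can be swapped in without changing the rest of the pipeline.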
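Side note: train_and_eval_dict also seems to contain an input_fn for prediction:
train_and_eval_dict['predict_input_fn']
However, calling it actually returns a tf.estimator.export.ServingInputReceiver, and I'm not sure what to do with that. This may be the source of my problem, since quite a lot of preprocessing happens before the model actually sees the image.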
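The preprocessing suspicion in the side note can be made concrete. If the Estimator's model_fn consumes the 'image' feature as-is (i.e. it expects an image that has already been through the model's own preprocessing), the image would have to be resized and rescaled by hand before being fed in. The sketch below assumes a fixed 300x300 resizer and the MobileNet-style mapping of pixel values from [0, 255] to [-1, 1]; both are assumptions to check against pipeline.config and the feature extractor, the function name preprocess_for_ssd_mobilenet is hypothetical, and PIL's resampling filter will not exactly match TensorFlow's.

```python
import numpy as np
from PIL import Image

# Assumptions (verify in pipeline.config): fixed 300x300 input size and
# MobileNet-style scaling of pixel values from [0, 255] to [-1, 1].
INPUT_HEIGHT = 300
INPUT_WIDTH = 300


def preprocess_for_ssd_mobilenet(image):
    """Return (batched_image, true_image_shape) for a PIL image.

    batched_image: float32 array of shape [1, 300, 300, 3] in [-1, 1].
    true_image_shape: int32 array of shape [1, 3] holding
    (height, width, channels) of the resized image.
    """
    # PIL's resize takes (width, height)
    resized = image.convert('RGB').resize((INPUT_WIDTH, INPUT_HEIGHT))
    pixels = np.asarray(resized, dtype=np.float32)
    scaled = (2.0 / 255.0) * pixels - 1.0
    batched = np.expand_dims(scaled, axis=0)
    true_image_shape = np.array([[INPUT_HEIGHT, INPUT_WIDTH, 3]], dtype=np.int32)
    return batched, true_image_shape
```

Under these assumptions the two returned arrays would take the place of image_np_expanded and im_size_orig in the feature dict x.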
Loading the model as an Estimator
The model was downloaded from the TF model zoo, and the loading code was adapted from existing Object Detection API code.
import os
import tensorflow as tf
from object_detection import model_hparams
from object_detection import model_lib

model_dir = './pretrained_models/tensorflow/ssd_mobilenet_v1_coco_2018_01_28/'
pipeline_config_path = os.path.join(model_dir, 'pipeline.config')
config = tf.estimator.RunConfig(model_dir=model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
    run_config=config,
    hparams=model_hparams.create_hparams(None),
    pipeline_config_path=pipeline_config_path,
    train_steps=None,
    sample_1_of_n_eval_examples=1,
    sample_1_of_n_eval_on_train_examples=5)
estimator = train_and_eval_dict['estimator']
Running inference
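# estimator.predict() returns a generator; the checkpoint is only restored
# and the graph only run when the first element is requested
output_dict1 = next(estimator.predict(predict_input_fn))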
This prints a number of log messages, one of which is:
INFO:tensorflow:Restoring parameters from ./pretrained_models/tensorflow/ssd_mobilenet_v1_coco_2018_01_28/model.ckpt
So it looks like the pre-trained weights are being loaded. However, the resulting detections are garbage.
Loading the same model as a Predictor
import os
from tensorflow.contrib import predictor

model_dir = './pretrained_models/tensorflow/ssd_mobilenet_v1_coco_2018_01_28'
saved_model_dir = os.path.join(model_dir, 'saved_model')
predict_fn = predictor.from_saved_model(saved_model_dir)
Running inference
output_dict2 = predict_fn({'inputs': image_np_expanded})
The results look good.
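To compare output_dict1 and output_dict2 numerically rather than by eye, a small helper can keep only the confident detections. This is a sketch that assumes both dicts use the Object Detection API keys 'detection_boxes', 'detection_scores' and 'detection_classes'; the helper name confident_detections and the 0.5 threshold are hypothetical choices. The Predictor output carries a leading batch dimension of 1, while Estimator.predict() yields one dict per example without it, so a leading batch axis is dropped when present.

```python
import numpy as np


def confident_detections(output_dict, min_score=0.5):
    """Return (boxes, scores, classes) for detections with score >= min_score."""
    scores = np.asarray(output_dict['detection_scores'])
    boxes = np.asarray(output_dict['detection_boxes'])
    classes = np.asarray(output_dict['detection_classes'])
    if scores.ndim == 2:  # batched output of shape [1, N]: take the first example
        scores, boxes, classes = scores[0], boxes[0], classes[0]
    keep = scores >= min_score
    return boxes[keep], scores[keep], classes[keep].astype(np.int64)
```

Calling it on both dicts with the same threshold makes it easy to see whether the Estimator produces any plausible detections at all.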
Solution
When you load the model as an Estimator from a checkpoint file, this is the restore function associated with SSD models, from ssd_meta_arch.py:
def restore_map(self,
                fine_tune_checkpoint_type='detection',
                load_all_detection_checkpoint_vars=False):
  """Returns a map of variables to load from a foreign checkpoint.

  See parent class for details.

  Args:
    fine_tune_checkpoint_type: whether to restore from a full detection
      checkpoint (with compatible variable names) or to restore from a
      classification checkpoint for initialization prior to training.
      Valid values: `detection`, `classification`. Default 'detection'.
    load_all_detection_checkpoint_vars: whether to load all variables (when
      `fine_tune_checkpoint_type='detection'`). If False, only variables
      within the appropriate scopes are included. Default False.

  Returns:
    A dict mapping variable names (to load from a checkpoint) to variables in
    the model graph.
  Raises:
    ValueError: if fine_tune_checkpoint_type is neither `classification`
      nor `detection`.
  """
  if fine_tune_checkpoint_type not in ['detection', 'classification']:
    raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
        fine_tune_checkpoint_type))
  if fine_tune_checkpoint_type == 'classification':
    return self._feature_extractor.restore_from_classification_checkpoint_fn(
        self._extract_features_scope)

  if fine_tune_checkpoint_type == 'detection':
    variables_to_restore = {}
    for variable in tf.global_variables():
      var_name = variable.op.name
      if load_all_detection_checkpoint_vars:
        variables_to_restore[var_name] = variable
      else:
        if var_name.startswith(self._extract_features_scope):
          variables_to_restore[var_name] = variable
    return variables_to_restore
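As you can see, even if the config file sets from_detection_checkpoint: True, only the variables within the feature extractor scope are restored. To restore all variables, you have to set
load_all_detection_checkpoint_vars: True
in the config file.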
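The effect of the flag can be checked without TensorFlow. The sketch below is a pure-Python model of the 'detection' branch of restore_map that works on plain variable-name strings instead of tf.Variable objects and returns the selected names rather than a dict. The function name select_variables_to_restore, the scope string 'FeatureExtractor' and the variable names are all illustrative assumptions, not taken from the actual checkpoint.

```python
def select_variables_to_restore(variable_names, extract_features_scope,
                                load_all_detection_checkpoint_vars=False):
    """Mimic the 'detection' branch of restore_map on variable names."""
    if load_all_detection_checkpoint_vars:
        # Every variable is restored
        return list(variable_names)
    # Otherwise only names under the feature extractor scope survive
    return [name for name in variable_names
            if name.startswith(extract_features_scope)]


# Hypothetical variable names, for illustration only.
names = [
    'FeatureExtractor/MobilenetV1/Conv2d_0/weights',
    'FeatureExtractor/MobilenetV1/Conv2d_1_depthwise/depthwise_weights',
    'BoxPredictor_0/BoxEncodingPredictor/weights',
    'BoxPredictor_0/ClassPredictor/weights',
]
print(len(select_variables_to_restore(names, 'FeatureExtractor')))        # 2
print(len(select_variables_to_restore(names, 'FeatureExtractor', True)))  # 4
```

With the default flag, the box-predictor variables are left out of the restore map, which matches the behaviour described above.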
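So the behaviour above is explained. When the model is loaded as an Estimator, only the variables from the feature extractor scope are restored, while the weights in the box predictor's scope are not, so the Estimator unsurprisingly gives random predictions.
When the model is loaded as a Predictor, all weights are loaded, so the predictions are sensible.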