The input tensor should have dimensions 1 x height x width x 3. Got 1 x 3 x 224 x 224

Problem description

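I want to convert a model trained in PyTorch to a TensorFlow model and use it on a mobile device. To do this, I follow these steps: first, I export the PyTorch-trained model to ONNX format. Then I convert the ONNX model to a TensorFlow model.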

First, export the PyTorch-trained model to ONNX:

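import torch
import torch.onnx
from detectron2.modeling import build_backbone

# cfg is the detectron2 config used for training (defined elsewhere).
# Note: build_backbone constructs a freshly initialized backbone;
# trained weights are not loaded here.
model = build_backbone(cfg)
model.eval()

# PyTorch is channels-first (NCHW), so this dummy input fixes the
# exported model's input shape to 1 x 3 x 224 x 224.
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(model, dummy_input, "drive/Detectron2/model_final.onnx")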
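The dummy input above fixes the exported graph's input to PyTorch's channels-first layout, (batch, channels, height, width). As a minimal sketch of what that means for a caller, assuming NumPy and an RGB image that arrives channels-last as height x width x 3 (the usual layout for decoded images), the hypothetical helper `to_nchw_batch` below builds the 1 x 3 x 224 x 224 input such a model expects:

```python
import numpy as np

def to_nchw_batch(image_hwc: np.ndarray) -> np.ndarray:
    """Turn one H x W x 3 image into a 1 x 3 x H x W float32 batch (hypothetical helper)."""
    if image_hwc.ndim != 3 or image_hwc.shape[2] != 3:
        raise ValueError(f"expected H x W x 3, got {image_hwc.shape}")
    chw = np.transpose(image_hwc, (2, 0, 1))        # H x W x C -> C x H x W
    return chw[np.newaxis, ...].astype(np.float32)  # add the batch axis in front

image = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in for a decoded RGB image
batch = to_nchw_batch(image)
print(batch.shape)  # (1, 3, 224, 224)
```

The channel axis moves from last to second, which is exactly the difference between the shape the mobile error asks for and the shape this export produces.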

Then convert the ONNX model to TensorFlow and on to TFLite:

import onnx
import tensorflow as tf
from onnx_tf.backend import prepare

# ONNX -> TensorFlow. export_graph writes a SavedModel directory
# (despite the .pb suffix in the path).
onnx_model = onnx.load("drive/Detectron2/model_final.onnx")
tf_rep = prepare(onnx_model)
tf_rep.export_graph("drive/Detectron2/tf_rep.pb")

# TFLite conversion
# Before conversion, fix the model input size. This keeps the
# channels-first (NCHW) shape coming from PyTorch.
model = tf.saved_model.load("drive/Detectron2/tf_rep.pb")
serving_fn = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
serving_fn.inputs[0].set_shape([1, 3, 224, 224])
tf.saved_model.save(model, "saved_model_updated", signatures=serving_fn)

# Convert the SavedModel to TFLite
converter = tf.lite.TFLiteConverter.from_saved_model(
    saved_model_dir="saved_model_updated", signature_keys=["serving_default"]
)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

# Save the model
with open("drive/Detectron2/model.tflite", "wb") as f:
    f.write(tflite_model)

# Use the TFLite interpreter to check the input shape
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Get input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Print the input shape the TFLite model expects
input_shape = input_details[0]["shape"]
print(input_shape)

However, when I use the model on a mobile device, I get the following error:

java.lang.AssertionError: Error occurred when initializing ImageSegmenter: The input tensor should have dimensions 1 x height x width x 3. Got 1 x 3 x 224 x 224.
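The message says the segmenter wants channels-last input (1 x height x width x 3), while the converted model still declares the channels-first shape that the script prints. As a rough illustration only (this is not the TFLite Task Library's actual code), the hypothetical functions below mirror the condition stated in the error message and classify a 4-D shape by where its 3-channel axis sits:

```python
from typing import Sequence

def matches_segmenter_input(shape: Sequence[int]) -> bool:
    """Hypothetical mirror of the error's condition: 1 x height x width x 3."""
    return len(shape) == 4 and shape[0] == 1 and shape[3] == 3

def describe_layout(shape: Sequence[int]) -> str:
    """Guess the layout of a 4-D image shape from the position of the 3-channel axis."""
    if len(shape) != 4:
        return "unknown"
    if shape[1] == 3 and shape[3] != 3:
        return "NCHW"
    if shape[3] == 3 and shape[1] != 3:
        return "NHWC"
    return "ambiguous"

for shape in ([1, 3, 224, 224], [1, 224, 224, 3]):
    print(shape, describe_layout(shape), matches_segmenter_input(shape))
```

With these definitions, [1, 3, 224, 224] is reported as NCHW and fails the check, while [1, 224, 224, 3] is NHWC and passes.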

What am I doing wrong?

Tags: python, tensorflow, machine-learning, pytorch, onnx

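Solution

Maybe you could try rearranging the tensor with einops; it is elegant and powerful. In your case, converting from channels-first (NCHW) to channels-last (NHWC) would look like this:
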
import einops

# input_tensor: a 1 x 3 x 224 x 224 (NCHW) array or tensor
input_tensor = einops.rearrange(input_tensor, 'b c h w -> b h w c')
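For reference, this rearrangement is just an axis permutation. Below is a minimal sketch, assuming NumPy is available, showing that `np.transpose(x, (0, 2, 3, 1))` produces the same NCHW-to-NHWC reordering as the einops pattern, and that `(0, 3, 1, 2)` undoes it. Note that this reorders the data you pass in; on its own it does not change the input shape declared inside the converted model.

```python
import numpy as np

def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    """Same permutation as einops 'b c h w -> b h w c'."""
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    """Inverse permutation: 'b h w c -> b c h w'."""
    return np.transpose(x, (0, 3, 1, 2))

x = np.random.rand(1, 3, 224, 224).astype(np.float32)
y = nchw_to_nhwc(x)
print(y.shape)                             # (1, 224, 224, 3)
print(np.array_equal(nhwc_to_nchw(y), x))  # True
```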
