python - 输入张量的尺寸应为 1 x 高 x 宽 x 3。得到 1 x 3 x 224 x 224
问题描述
我想将 Pytorch 训练的模型转换为 tensorflow 模型并在移动设备上使用该模型。为此,我遵循以下步骤;首先,我将 pytorch 训练的模型转换为 onnx 格式。然后我将 onnx 格式转换为 tensorflow 模型。
首先 pytorch 将模型训练到 onnx;
import torch
import torch.onnx
from detectron2.modeling import build_model
from detectron2.modeling import build_backbone
from torch.autograd import Variable
model= build_backbone(cfg)
model.eval()
dummy_input = torch.randn(1,3,224, 224,requires_grad=True)
torch.onnx.export(model,dummy_input,"drive/Detectron2/model_final.onnx")
然后onnx转tflite模型;
import onnx
import warnings
from onnx_tf.backend import prepare
model = onnx.load("drive/Detectron2/model_final.onnx")
tf_rep = prepare(model)
tf_rep.export_graph("drive/Detectron2/tf_rep.pb")
import tensorflow as tf
## TFLite Conversion
# Before conversion, fix the model input size
model = tf.saved_model.load("drive/Detectron2/tf_rep.pb")
model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs[0].set_shape([1, 3,224, 224])
tf.saved_model.save(model, "saved_model_updated", signatures=model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY])
# Convert
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_updated', signature_keys=['serving_default'])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
# Save the model.
with open('drive/Detectron2/model.tflite', 'wb') as f:
f.write(tflite_model)
## TFLite Interpreter to check input shape
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test the model on random input data.
input_shape = input_details[0]['shape']
print(input_shape)
但是当我在移动设备上使用该模型时,出现以下错误;
java.lang.AssertionError: Error occurred when initializing ImageSegmenter: The input tensor should have dimensions 1 x height x width x 3. Got 1 x 3 x 224 x 224.
我在哪里做错了?
解决方案
也许你可以尝试einops
张量转换。它优雅而强大。在您的情况下,代码应该是
import einops
input_tensor = einops.rearrange(input_tensor,'b c w h -> b w h c')
推荐阅读
- php - 如何接收用 angularjs 填充并由 ajax 发送到服务器端(PHP)的 json 对象?
- elasticsearch - 如何列出elasticsearch6中的所有字段
- java - 是否可以使用 lombok 从基类实例构造派生类实例?
- c++ - QDomNode.clear() 不会删除 xml 标记的值
- search - 在 DAX 中搜索关键字
- c++ - 使用两个不同大小的一维数组制作二维数组
- c# - 将空共享字符串写入单元格 openxml
- cron - 多周的一天的 Cron 格式
- select - SQL Server - 加入动态透视表中的选择列
- javascript - D3.js:如何更改单个节点中文本的位置