首页 > 解决方案 > 将预训练的保存模型从 NCHW 转换为 NHWC,使其与 Tensorflow Lite 兼容

问题描述

我已将模型从 PyTorch 转换为 Keras,并使用后端提取 tensorflow 图。由于 PyTorch 的数据格式是 NCHW,所以提取和保存的模型也是这样。将模型转换为 TFLite 时,由于格式为 NCHW,无法转换。有没有办法将整个图转换为 NHCW?

标签: tensorflowkeras

解决方案


最好有一个与 TFLite 匹配的数据格式的图,以便更快地推理。一种方法是手动将转置操作插入图中,例如: How to convert the CIFAR10 tutorial to NCHW

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Session(config=config) as session:

    kernel = tf.ones(shape=[5, 5, 3, 64])
    images = tf.ones(shape=[64,24,24,3])

    imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

    print("conv=",conv.eval())

推荐阅读