将预训练的已保存模型从NCHW转换为NHWC以使其与Tensorflow Lite兼容 [英] Converting pretrained saved model from NCHW to NHWC to make it compatible for Tensorflow Lite

查看:591
本文介绍了将预训练的已保存模型从NCHW转换为NHWC以使其与Tensorflow Lite兼容的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经将模型从PyTorch转换为Keras,并使用后端提取了张量流图.由于PyTorch的数据格式为NCHW,因此提取并保存的模型也是如此.将模型转换为TFLite时,由于格式为NCHW,因此无法转换.有没有办法将整个图转换为NHCW?

I have converted a model from PyTorch to Keras and used the backend to extract the tensorflow graph. Since the data format for PyTorch was NCHW, the model extracted and saved is also that. While converting the model to TFLite, due to the format being NCHW, it cannot get converted. Is there a way to convert the whole graph into NHCW?

推荐答案

最好让图的数据格式与TFLite匹配,以加快推理速度.一种方法是手动将转置ops插入到图形中,例如以下示例: 如何将CIFAR10教程转换为NCHW

It is better to have a graph with the data-format matched to TFLite for faster inference. One approach is to manually insert transpose ops into the graph, like this example: How to convert the CIFAR10 tutorial to NCHW

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Session(config=config) as session:

    kernel = tf.ones(shape=[5, 5, 3, 64])
    images = tf.ones(shape=[64,24,24,3])

    imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

    print("conv=",conv.eval())

这篇关于将预训练的已保存模型从NCHW转换为NHWC以使其与Tensorflow Lite兼容的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆