tensorflow.nn.conv2d 中 NCHW 格式的过滤器形状 [英] Filter shape in tensorflow.nn.conv2d with NCHW format

查看:23
本文介绍了tensorflow.nn.conv2d 中 NCHW 格式的过滤器形状的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

遵循 Tensorflow 的最佳性能实践,我使用的是 NCHW 数据格式,但我不确定要在 tensorflow.nn.conv2d.

Following Tensorflow's best practices for performance, I am using NCHW data format, but I am not sure about the filter shape to be used in tensorflow.nn.conv2d.

文档说对 NHWC 格式使用 [filter_height, filter_width, in_channels, out_channels],但不清楚如何处理 NCHW.

The doc says to use [filter_height, filter_width, in_channels, out_channels] for NHWC format, but is not clear about what to do with NCHW.

是否应该使用相同的形状?

Should the same shape be used ?

推荐答案

使用相同的过滤器形状应该可行.函数参数的唯一变化是步幅.例如,假设您希望您的架构同时使用这两种格式,这也是推荐的:

Using the same filter shape should work. The only change to the function arguments is the stride. As an example let's say you wanted your architecture to work with both formats, which is also recommended:

# input -> Tensor in NCHW format
if use_nchw:
    result = tf.nn.conv2d(
        input=input,
        filter=filter,
        strides=[1, 1, stride, stride],
        data_format='NCHW')
else:
    input_t = tf.transpose(input, [0, 2, 3, 1]) # NCHW to NHWC

    result = tf.nn.conv2d(
        input=input_t,
        filter=filter,
        strides=[1, stride, stride, 1])

    result = tf.transpose(result, [0, 3, 1, 2]) # NHWC to NCHW  

这篇关于tensorflow.nn.conv2d 中 NCHW 格式的过滤器形状的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆