Manually convert PyTorch weights to tf.keras weights for a convolutional layer

Question

I'm trying to convert a PyTorch model to a tf.keras model, including the weights, but the outputs of the two libraries don't match.

Here I define two convolutional layers, which should be identical:

import numpy as np
import tensorflow as tf
import torch

torch_layer = torch.nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    dilation=1,
    groups=1,
    bias=False,
    padding_mode='zeros'
)

tf_layer = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=(7, 7),
    strides=(2, 2),
    padding='same',
    dilation_rate=(1, 1),
    groups=1,
    activation=None,
    use_bias=False
)
# define model to specify input channel size
tf_model = tf.keras.Sequential([tf.keras.layers.Input((256, 256, 3), batch_size=1), tf_layer])

Now I take the torch weights and convert them to tf.keras format:

# torch layout: out_channels, in_channels, kernel_height, kernel_width
torch_weights = np.random.rand(64, 3, 7, 7)
# tf.keras layout: kernel_height, kernel_width, in_channels, out_channels
tf_weights = np.transpose(torch_weights, (2, 3, 1, 0))

# assign weights
torch_layer.weight = torch.nn.Parameter(torch.Tensor(torch_weights))
tf_model.layers[0].set_weights([tf_weights])
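The transpose (2, 3, 1, 0) maps PyTorch's (out_channels, in_channels, kernel_height, kernel_width) kernel to the (kernel_height, kernel_width, in_channels, out_channels) layout that tf.keras Conv2D expects. A quick NumPy sketch (the names w_torch and w_keras are illustrative only) checks the element mapping. This suggests the weight conversion is not what causes the mismatch:

```python
import numpy as np

# PyTorch Conv2d kernel layout: (out_channels, in_channels, kernel_h, kernel_w)
w_torch = np.random.rand(64, 3, 7, 7)
# tf.keras Conv2D kernel layout: (kernel_h, kernel_w, in_channels, out_channels)
w_keras = np.transpose(w_torch, (2, 3, 1, 0))

assert w_keras.shape == (7, 7, 3, 64)

# Spot-check the index mapping w_torch[o, i, h, w] == w_keras[h, w, i, o]
rng = np.random.default_rng(0)
for _ in range(100):
    o, i, h, w = (rng.integers(n) for n in w_torch.shape)
    assert w_torch[o, i, h, w] == w_keras[h, w, i, o]

# The inverse permutation (3, 2, 0, 1) recovers the PyTorch layout exactly
assert np.array_equal(np.transpose(w_keras, (3, 2, 0, 1)), w_torch)
```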

Now I define the inputs, but the outputs differ (the shapes are the same, the values are not). What am I doing wrong?

torch_inputs = np.random.rand(1, 3, 256, 256)
tf_inputs = np.transpose(torch_inputs, (0, 2, 3, 1))

torch_output = torch_layer(torch.Tensor(torch_inputs))
tf_output = tf_model.layers[0](tf_inputs)

Answer

In TensorFlow, set_weights is mainly intended for values that come from get_weights, so it is safer to assign to the layer's variable directly (tf_layer.kernel.assign(...)) to avoid mistakes.

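Besides, 'same' padding in TensorFlow is a little complicated; for details, see my SO answer. The amount of padding depends on input_shape, kernel_size and strides, and when the total padding is odd, TensorFlow puts the extra row/column at the bottom/right. In your example (256x256 input, 7x7 kernel, stride 2) the total padding is 5 per spatial dimension, so it translates to torch.nn.ZeroPad2d((2, 3, 2, 3)) in PyTorch (argument order: left, right, top, bottom), while padding=(3, 3) pads 3 pixels on every side. That one-pixel offset is why the shapes match but the values do not.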
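TensorFlow chooses 'same' padding per spatial dimension so that the output size is ceil(input / stride), and it gives the smaller half of the total padding to the top/left. The helper below is a sketch of that rule. tf_same_padding is a hypothetical name, not a TensorFlow API. It shows where the (2, 3, 2, 3) above comes from:

```python
import math

def tf_same_padding(in_size, kernel_size, stride, dilation=1):
    """Return (pad_before, pad_after) that padding='same' applies along one
    spatial dimension, following TensorFlow's documented 'SAME' rule."""
    effective_k = (kernel_size - 1) * dilation + 1
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + effective_k - in_size, 0)
    pad_before = pad_total // 2          # smaller half goes first (top/left)
    pad_after = pad_total - pad_before   # extra pixel, if any, goes last (bottom/right)
    return pad_before, pad_after

# Question's setting: 256x256 input, 7x7 kernel, stride 2
top, bottom = tf_same_padding(256, 7, 2)
left, right = tf_same_padding(256, 7, 2)
print((top, bottom))  # (2, 3)

# torch.nn.ZeroPad2d expects (left, right, top, bottom)
zero_pad_arg = (left, right, top, bottom)
print(zero_pad_arg)   # (2, 3, 2, 3)
```

For an odd size such as 255, or for stride 1, the total comes out to 6 and the split is (3, 3), which is exactly what PyTorch's padding=(3, 3) does. With the even 256 input and stride 2 it is not.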

Example code: from TensorFlow to PyTorch

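import numpy as np
import tensorflow as tf
import torch

np.random.seed(88883)

# Initialize the layers; ZeroPad2d((left, right, top, bottom)) reproduces TF's asymmetric 'same' padding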
torch_layer = torch.nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    bias=False
)
torch_model = torch.nn.Sequential(
              torch.nn.ZeroPad2d((2,3,2,3)),
              torch_layer
              )

tf_layer = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=(7, 7),
    strides=(2, 2),
    padding='same',
    use_bias=False
)

# Set the same weights in both layers: torch (out, in, kh, kw) -> tf (kh, kw, in, out)
torch_weights = np.random.rand(64, 3, 7, 7)
tf_weights = np.transpose(torch_weights, (2, 3, 1, 0))

with torch.no_grad():
  torch_layer.weight = torch.nn.Parameter(torch.Tensor(torch_weights))

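# Call the layer once on a dummy NHWC input so that tf_layer.kernel is created
tf_layer(np.zeros((1, 256, 256, 3)))
tf_layer.kernel.assign(tf_weights.astype(np.float32))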

# Prepare the same input in NCHW (torch) and NHWC (tf) layouts, then run inference
torch_inputs = torch.Tensor(np.random.rand(1, 3, 256, 256))
tf_inputs = np.transpose(torch_inputs.numpy(), (0, 2, 3, 1))

with torch.no_grad():
  torch_output = torch_model(torch_inputs)
tf_output = tf_layer(tf_inputs)

np.allclose(tf_output.numpy(), np.transpose(torch_output.numpy(), (0, 2, 3, 1)))  # True
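As a framework-free cross-check that padding alone explains the mismatch, here is a sketch using a naive NumPy cross-correlation (conv2d_nchw is a hypothetical helper; the small 16x16 input just keeps the loop fast). Padding (2, 3) mimics TensorFlow's 'same' for this size, and padding (3, 3) mimics torch.nn.Conv2d(padding=(3, 3)). The shapes agree but the values do not, which is the symptom in the question:

```python
import numpy as np

def conv2d_nchw(x, w, stride, pad):
    """Naive 2D cross-correlation (what both frameworks call convolution).

    x:   input of shape (N, C_in, H, W)
    w:   kernel of shape (C_out, C_in, KH, KW)
    pad: zero padding as (top, bottom, left, right)
    """
    top, bottom, left, right = pad
    x = np.pad(x, ((0, 0), (0, 0), (top, bottom), (left, right)))
    _, _, h, wd = x.shape
    _, _, kh, kw = w.shape
    out_h = (h - kh) // stride + 1
    out_w = (wd - kw) // stride + 1
    out = np.empty((x.shape[0], w.shape[0], out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, :, i * stride:i * stride + kh, j * stride:j * stride + kw]
            # Sum over C_in, KH, KW -> result of shape (N, C_out)
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
w = rng.random((4, 3, 7, 7))
x_even = rng.random((1, 3, 16, 16))

# TF 'same' for a 16x16 input, 7x7 kernel, stride 2 pads (2, 3); PyTorch padding=(3, 3) pads (3, 3)
same_even = conv2d_nchw(x_even, w, stride=2, pad=(2, 3, 2, 3))
sym_even = conv2d_nchw(x_even, w, stride=2, pad=(3, 3, 3, 3))
print(same_even.shape, sym_even.shape)   # (1, 4, 8, 8) (1, 4, 8, 8)
print(np.allclose(same_even, sym_even))  # False: same shape, different values
```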

From PyTorch to TensorFlow

torch_layer = torch.nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False
)

tf_layer=tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=(7, 7),
    strides=(2, 2),
    padding='valid',
    use_bias=False
    )

tf_model = tf.keras.Sequential([
           tf.keras.layers.ZeroPadding2D((3, 3)),
           tf_layer
           ])
