将tf.contrib.layers.batch_norm迁移到Tensorflow 2.0 [英] Migrate tf.contrib.layers.batch_norm to Tensorflow 2.0

查看:1175
本文介绍了将tf.contrib.layers.batch_norm迁移到Tensorflow 2.0的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在将TensorFlow代码迁移到Tensorflow 2.1.0.

I'm migrating a TensorFlow code to Tensorflow 2.1.0.

这是原始代码:

conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)

这就是我所做的:

conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)

我的问题是我不知道如何处理 tf.contrib.layers.batch_norm .

My problem is that I don't know what to do with tf.contrib.layers.batch_norm.

如何将 tf.contrib.layers.batch_norm 迁移到Tensorflow 2.x?

How can I migrate tf.contrib.layers.batch_norm to Tensorflow 2.x?

更新:
使用评论建议,我认为我已正确迁移:

UPDATE:
Using the comment suggestion, I think I have migrated correctly:

conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)

但是我不确定 decay 是否像 momentum 一样,而且我不知道如何在中设置 updates_collections BatchNormalization 方法.

But I'm not sure if decay is like momentum and I don't know how to set updates_collections in the BatchNormalization method.

推荐答案

在使用我将要进行微调的训练模型时,我遇到了这个问题.像OP那样用 tf.keras.layers.BatchNormalization 替换 tf.contrib.layers.batch_norm 确实给了我一个错误,其修复方法如下所述.

I encountered this problem when working with a trained model that I was going to fine tune. Just replacing tf.contrib.layers.batch_norm with tf.keras.layers.BatchNormalization like OP did gave me an error whose fix is described below.

旧代码如下:

tf.contrib.layers.batch_norm(
    tensor,
    scale=True,
    center=True,
    is_training=self.use_batch_statistics,
    trainable=True,
    data_format=self._data_format,
    updates_collections=None,
)

,更新后的工作代码如下:

and the updated working code looks like this:

tf.keras.layers.BatchNormalization(
    name="BatchNorm",
    scale=True,
    center=True,
    trainable=True,
)(tensor)

我不确定我删除的所有关键字参数是否都会出现问题,但是一切似乎都可以正常工作.请注意 name ="BatchNorm" 参数.图层使用不同的命名架构,因此我不得不使用 inspect_checkpoint.py 工具查看模型,并找到恰好是 BatchNorm 的图层名称.

I'm unsure if all the keyword arguments I removed are going to be a problem but everything seems to work. Note the name="BatchNorm" argument. The layers use a different naming schema so I had to use the inspect_checkpoint.py tool to look at the model and find the layer names which happened to be BatchNorm.

这篇关于将tf.contrib.layers.batch_norm迁移到Tensorflow 2.0的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆