TensorFlow中的选择性零权重? [英] Selectively zero weights in TensorFlow?

查看:143
本文介绍了TensorFlow中的选择性零权重?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我有一个NxM权重变量weights和一个1s和0s mask的恒定NxM矩阵.

Lets say I have an NxM weight variable weights and a constant NxM matrix of 1s and 0s mask.

如果我的网络层是这样定义的(其他层也有类似的定义):

If a layer of my network is defined like this (with other layers similarly defined):

masked_weights = mask*weights
layer1 = tf.relu(tf.matmul(layer0, masked_weights) + biases1)

在训练过程中,该网络的行为是否会像mask中的对应0一样是weights中的零? (即好像这些权重代表的连接已完全从网络中删除)?

Will this network behave as if the corresponding 0s in mask are zeros in weights during training? (i.e. as if the connections represented by those weights had been removed from the network entirely)?

如果没有,我如何在TensorFlow中实现这个目标?

If not, how can I achieve this goal in TensorFlow?

推荐答案

答案是肯定的.实验描绘了下图.

The answer is yes. The experiment depicts the following graph.

实现为:

import numpy as np, scipy as sp, tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 3))
weights = tf.get_variable("weights", [3, 2])
bias = tf.get_variable("bias", [2])
mask = tf.constant(np.asarray([[0, 1], [1, 0], [0, 1]], dtype=np.float32)) # constant mask

masked_weights = tf.multiply(weights, mask)
y = tf.nn.relu(tf.nn.bias_add(tf.matmul(x, masked_weights), bias))
loss = tf.losses.mean_squared_error(tf.constant(np.asarray([[1, 1]], dtype=np.float32)),y)

weights_grad = tf.gradients(loss, weights)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("Masked weights=\n", sess.run(masked_weights))
data = np.random.rand(1, 3)

print("Graident of weights\n=", sess.run(weights_grad, feed_dict={x: data}))
sess.close()

运行上面的代码后,您还将看到渐变也被屏蔽了.在我的示例中,它们是:

After running the code above, you will see the gradients are masked as well. In my example, they are:

Graident of weights
= [array([[ 0.        , -0.40866762],
       [ 0.34265977, -0.        ],
       [ 0.        , -0.35294518]], dtype=float32)]

这篇关于TensorFlow中的选择性零权重?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆