How to initialize the weights of a network with the weights of another network?


Question

I want to combine two networks into one network while keeping the weights of the original networks.

I saved the weights in their numpy form using:

weights = {}
for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    weights[i.name] = i.eval()  # needs an active default session

I can't find a way to load these weights into the new network's variables. Is there a way to load the weights into all of the variables?

I tried the following but get an error:

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    i.initializer = weights[i.name]

Error:

AttributeError: can't set attribute

Answer

You can write two small helper functions, one to dump the weights and one to restore them:

def save_to_dict(sess, collection=tf.GraphKeys.GLOBAL_VARIABLES):
    # dump every variable in the collection as {name: numpy array}
    return {v.name: sess.run(v) for v in tf.get_collection(collection)}


def load_from_dict(sess, data):
    # assign the saved value to every variable whose name appears in the dict
    for v in tf.global_variables():
        if v.name in data:
            sess.run(v.assign(data[v.name]))
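
As for the error in the question: initializer is a read-only property of tf.Variable (it returns the variable's init op), so assigning to it raises "AttributeError: can't set attribute". The value has to be written through an assign op, as load_from_dict does above, or through the variable's load method. A minimal equivalent sketch, assuming an active session sess and the weights dictionary from the question:

for v in tf.global_variables():
    if v.name in weights:
        v.load(weights[v.name], sess)  # writes the saved value into the variable's memory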

The trick is to simply iterate over all variables and check whether they exist in the dictionary, like so:

import tensorflow as tf
import numpy as np


def save_to_dict(sess, collection=tf.GraphKeys.GLOBAL_VARIABLES):
    return {v.name: sess.run(v) for v in tf.get_collection(collection)}


def load_from_dict(sess, data):
    for v in tf.global_variables():
        if v.name in data:
            sess.run(v.assign(data[v.name]))


def network(x):
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc0')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc1')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc2')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc3')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc4')
    return x


element = np.random.randn(8, 10)
weights = None

# first session
with tf.Session() as sess:

    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    y = network(x)
    sess.run(tf.global_variables_initializer())

    # first evaluation
    expected = sess.run(y, {x: element})

    # dump as dict
    weights = save_to_dict(sess)

# destroy session and graph
tf.reset_default_graph()

# second session
with tf.Session() as sess:

    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    y = network(x)
    sess.run(tf.global_variables_initializer())

    # use randomly initialized parameters
    actual = sess.run(y, {x: element})
    assert np.sum(np.abs(actual - expected)) > 0  # should NOT match

    # load previous parameters
    load_from_dict(sess, weights)

    actual = sess.run(y, {x: element})
    assert np.sum(np.abs(actual - expected)) == 0  # should match

Because the loading goes through a plain Python dictionary, you can simply drop some parameters, change the weight values before loading, or even rename parameters, as sketched below.
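
For example, a sketch of renaming and dropping entries before calling load_from_dict; the name_map below is purely hypothetical and assumes the layers were renamed from fc* to dense* in the new graph:

# hypothetical mapping from old variable names to the names used in the new graph
name_map = {'fc0/kernel:0': 'dense0/kernel:0',
            'fc0/bias:0': 'dense0/bias:0'}
remapped = {name_map.get(k, k): v for k, v in weights.items()}

# drop parameters that should keep their fresh random initialization
remapped.pop('fc4/kernel:0', None)
remapped.pop('fc4/bias:0', None)

load_from_dict(sess, remapped)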
