When restoring from a checkpoint, how can I change the data type of the parameters?

Question

I have a pre-trained TensorFlow checkpoint in which all parameters are of the float32 data type.

How can I load the checkpoint parameters as float16? Or is there a way to modify the data type stored in the checkpoint?

Following is my code snippet, which tries to load a float32 checkpoint into a float16 graph; it fails with a type mismatch error.

import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float

Recommended answer

Looking a bit into how savers work, it seems you can customize how they are constructed by passing a builder object (the `builder` argument of tf.train.Saver). For example, you could write a builder that loads the values as tf.float32 and then casts them to the actual type of each variable:

import tensorflow as tf
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
  # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    # Collect (checkpoint name, slice spec, variable dtype) for every tensor.
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    names, slices, dtypes = zip(*restore_specs)
    # Read every tensor as float32, the type it was saved with...
    restore_dtypes = [tf.float32 for _ in dtypes]
    with tf.device("cpu:0"):
      restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
      # ...then cast each one to the dtype of the variable it is restored into.
      return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]

Note that this assumes all restored variables are stored as tf.float32 in the checkpoint. If necessary, you can adapt the builder to your use case, e.g. by passing the source type (or types) in the constructor. With this, you just need to use the above builder in the second saver to get your example to work:

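import os
import tensorflow as tf

# Saver.save() raises an error if the checkpoint's parent directory does not exist.
os.makedirs("ckpt", exist_ok=True)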

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "ckpt/tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "ckpt/tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))

Output:

Value to save:
[[ 0.50589913  0.33701038 -0.11597633]
 [ 0.27372625  0.27724823  0.49825498]
 [ 1.0897961  -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506    0.337   -0.11597]
 [ 0.2737   0.2773   0.4983 ]
 [ 1.09    -0.296   -0.9175 ]]
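The restored output agrees with the saved output to roughly three or four significant digits, which is about what float16 (11 significant bits) can represent; part of the difference also comes from evaluating the dense layer itself in float16. As a rough illustration using NumPy (an assumption here, it is not part of the answer), casting float32 values to float16, analogous to the `tf.cast` in the builder, should keep the relative rounding error within 2**-11:

```python
import numpy as np

# Sample float32 values. These reuse numbers from the "Value to save" output
# purely as illustrative inputs; they are layer outputs, not the parameters.
saved = np.array([0.50589913, 0.33701038, -0.11597633, 1.0897961], dtype=np.float32)

# float32 -> float16 cast (round to nearest), analogous to tf.cast in the builder.
restored = saved.astype(np.float16)

# Relative rounding error, computed in float32 (float16 values convert exactly).
rel_err = np.abs(restored.astype(np.float32) - saved) / np.abs(saved)

print(restored)
print("max relative error:", rel_err.max(), "bound:", 2.0 ** -11)
```

If the loss of precision matters for your model, this is a property of float16 itself, not of the restore mechanism.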

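The answer's note suggests passing the source type (or types) in the builder's constructor. Below is a minimal sketch of that idea. `CastingSaverBuilder` and its `source_dtypes` argument are hypothetical names, not part of TensorFlow, and like the answer the code relies on private `tensorflow.python` modules, so it may break across TensorFlow versions:

```python
import tensorflow as tf
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder


class CastingSaverBuilder(BaseSaverBuilder):
  """Hypothetical generalization of CastFromFloat32SaverBuilder.

  source_dtypes: either a single tf.DType used for every tensor in the
  checkpoint, or a dict mapping checkpoint tensor names to their stored dtype.
  Names missing from the dict are read with the variable's own dtype.
  """

  def __init__(self, source_dtypes):
    super().__init__()
    self._source_dtypes = source_dtypes

  def _stored_dtype(self, name, target_dtype):
    # Look up the dtype the tensor has inside the checkpoint.
    if isinstance(self._source_dtypes, dict):
      return self._source_dtypes.get(name, target_dtype)
    return self._source_dtypes

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    del preferred_shard, restore_sequentially  # not needed for one bulk read
    names, slices, target_dtypes = [], [], []
    for saveable in saveables:
      for spec in saveable.specs:
        names.append(spec.name)
        slices.append(spec.slice_spec)
        target_dtypes.append(spec.dtype)
    stored_dtypes = [self._stored_dtype(n, dt)
                     for n, dt in zip(names, target_dtypes)]
    with tf.device("cpu:0"):
      # Read each tensor with the dtype it has in the checkpoint...
      restored = io_ops.restore_v2(filename_tensor, names, slices, stored_dtypes)
      # ...and cast it to the dtype of the variable it is restored into.
      return [tf.cast(r, dt) for r, dt in zip(restored, target_dtypes)]
```

In the example above, `tf.train.Saver(assign, builder=CastingSaverBuilder(tf.float32))` should behave like `CastFromFloat32SaverBuilder`. With a dict such as `{'foo:0': tf.float32}`, only that tensor is cast; every other tensor is read with its variable's dtype, so it must already have that dtype in the checkpoint.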
