Freezing graph in tensorflow when using tf.image.Dataset


Problem description

I'm using tensorflow.python.tools.freeze_graph to freeze a tensorflow graph in the function below:

# Imports used by this method:
import os
import tensorflow as tf
from tensorflow.python.tools import freeze_graph

def freeze_and_save_graph(self, session, save_dir, name):
    checkpoint_prefix = os.path.join(save_dir, "model")
    checkpoint_state_name = "checkpoint"
    input_graph_name = "input_graph.pbtxt"
    output_graph_name = name

    # Save a checkpoint holding the current variable values.
    # saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=max_checkpoints)
    checkpoint_path = self.saver.save(
        session,
        checkpoint_prefix,
        global_step=0,
        latest_filename=checkpoint_state_name)

    # Write the graph definition as text, then freeze it against the checkpoint.
    tf.train.write_graph(session.graph, save_dir, input_graph_name, as_text=True)
    input_graph_path = os.path.join(save_dir, input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "model_1/output"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(save_dir, output_graph_name)
    clear_devices = False
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_path, clear_devices, "")
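
For context, a hypothetical call site for this method might look like the sketch below; the model object, the export directory, and the file name frozen_model.pb are illustrative assumptions rather than part of the original question.

# Hypothetical usage sketch (TF 1.x assumed); `model` carries the graph and
# the self.saver used above, and its variables are already initialized/trained.
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    model.freeze_and_save_graph(session, "./export", "frozen_model.pb")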

Recently I switched to using tensorflow.image.Dataset to do the preprocessing, like so:

# Build the input pipeline directly from the in-memory training arrays.
data = tf.data.Dataset.from_tensor_slices((images_train, onehot_train))
data = data.map(lambda x, y: (preprocessing_fn(x), y), num_parallel_calls=32)
data = data.shuffle(len(images_train))
data = data.batch(batch_size)
data = data.prefetch(5)

# Initializable iterator over the preprocessed batches.
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer
session.run(init_op)

After making the change, freezing the graph takes forever. The size of input_graph.pbtxt has gone from 500 kB to 150 MB. Having a look, the culprits are two tensors with the same size and shape as my training data, with tensor_content defined. That is, the training data has been saved in the file.
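
One quick way to confirm this is to scan the GraphDef for constants carrying a large tensor_content. This is a minimal diagnostic sketch, assuming the TF 1.x API and the `session` from the snippets above; it is not part of the original question.

# List Const nodes whose embedded tensor_content exceeds ~1 MB; with the
# in-memory Dataset above, the two training-data tensors show up here.
graph_def = session.graph.as_graph_def()
for node in graph_def.node:
    if node.op == "Const":
        n_bytes = len(node.attr["value"].tensor.tensor_content)
        if n_bytes > 1 << 20:
            print(node.name, n_bytes, "bytes")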

How can I save the graph without this data?

Recommended answer

I found the solution: use placeholders instead of constructing the Dataset directly from the data. The changes are:

# Placeholders stand in for the training data in the graph.
image_tensor = tf.placeholder(tf.float32, shape=self.x_image.shape)
onehot_tensor = tf.placeholder(tf.float32, shape=self.y_true.shape)
data = tf.data.Dataset.from_tensor_slices((image_tensor, onehot_tensor))

# The actual arrays are fed only when the iterator is initialized.
session.run(init_op, feed_dict={image_tensor: images_train, onehot_tensor: onehot_train})

Now when it saves the graph, it saves placeholders instead of data.
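
Putting the pieces together, the placeholder-based pipeline might look like the following sketch. It merely combines the snippets above; taking the placeholder shapes from the training arrays rather than from self.x_image / self.y_true is an assumption made for the sake of a self-contained example.

# Sketch of the full placeholder-based pipeline (TF 1.x); images_train,
# onehot_train, preprocessing_fn and batch_size come from the surrounding code.
image_tensor = tf.placeholder(tf.float32, shape=images_train.shape)
onehot_tensor = tf.placeholder(tf.float32, shape=onehot_train.shape)

data = tf.data.Dataset.from_tensor_slices((image_tensor, onehot_tensor))
data = data.map(lambda x, y: (preprocessing_fn(x), y), num_parallel_calls=32)
data = data.shuffle(len(images_train))
data = data.batch(batch_size)
data = data.prefetch(5)

iterator = data.make_initializable_iterator()
next_element = iterator.get_next()

# The data enters only through feed_dict at initialization time, so it is
# never embedded in the GraphDef that gets frozen.
session.run(iterator.initializer,
            feed_dict={image_tensor: images_train, onehot_tensor: onehot_train})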
