Tensorflow Estimator API:总结 [英] Tensorflow Estimator API: Summaries

查看:34
本文介绍了Tensorflow Estimator API:总结的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我无法使用 Tensorflow 的 Estimator API 进行摘要.

I can't achieve to make summaries work with the Estimator API of Tensorflow.

Estimator 类非常有用,原因有很多:我已经实现了自己的类,它们非常相似,但我正在尝试切换到这个类.

The Estimator class is very useful for many reasons: I have already implemented my own classes which are really similar but I am trying to switch to this one.

这是代码示例:

import tensorflow as tf
import tensorflow.contrib.layers as layers
import tensorflow.contrib.learn as learn
import numpy as np

 # To reproduce the error: docker run --rm -w /algo -v $(pwd):/algo tensorflow/tensorflow bash -c "python sample.py"

def model_fn(x, y, mode):
    logits = layers.fully_connected(x, 12, scope="dense-1")
    logits = layers.fully_connected(logits, 56, scope="dense-2")
    logits = layers.fully_connected(logits, 4, scope="dense-3")

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y), name="xentropy")

    return {"predictions":logits}, loss, tf.train.AdamOptimizer(0.001).minimize(loss)


def input_fun():
    """ To be completed for a 4 classes classification problem """

    feature = tf.constant(np.random.rand(100,10))
    labels = tf.constant(np.random.random_integers(0,3, size=(100,)))

    return feature, labels

estimator = learn.Estimator(model_fn=model_fn, )

trainingConfig = tf.contrib.learn.RunConfig(save_checkpoints_secs=60)

estimator = learn.Estimator(model_fn=model_fn, model_dir="./tmp", config=trainingConfig)

# Works
estimator.fit(input_fn=input_fun, steps=2)

# The following code does not work

# Can't initialize saver

# saver = tf.train.Saver(max_to_keep=10) # Error: No variables to save

# The following fails because I am missing a saver... :(

hooks=[
        tf.train.LoggingTensorHook(["xentropy"], every_n_iter=100),
        tf.train.CheckpointSaverHook("./tmp", save_steps=1000, checkpoint_basename='model.ckpt'),
        tf.train.StepCounterHook(every_n_steps=100, output_dir="./tmp"),
        tf.train.SummarySaverHook(save_steps=100, output_dir="./tmp"),
]

estimator.fit(input_fn=input_fun, steps=2, monitors=hooks)

如您所见,我可以创建一个 Estimator 并使用它,但我可以实现在拟合过程中添加钩子.

As you can see, I can create an Estimator and use it but I can achieve to add hooks to the fitting process.

日志挂钩工作正常,但其他挂钩需要张量和我无法提供的保护程序.

The logging hooks works just fine but the others require both tensors and a saver which I can't provide.

张量在模型函数中定义,因此我无法将它们传递给 SummaryHook 并且 Saver 无法初始化,因为没有张量保存...

The tensors are defined in the model function, thus I can't pass them to the SummaryHook and the Saver can't be initialized because there is no tensor to save...

我的问题有解决方案吗?(我猜是的,但 tensorflow 文档中缺少这部分的文档)

Is there a solution to my problem? (I am guessing yes but there is a lack of documentation of this part in the tensorflow documentation)

  • 如何初始化我的保护程序?还是应该使用其他对象,例如 Scaffold?
  • 如何将 summaries 传递给 SummaryHook,因为它们是在我的模型函数中定义的?
  • How can I initialized my saver? Or should I use other objects such as Scaffold?
  • How can I pass summaries to the SummaryHook since they are defined in my model function?

提前致谢.

PS:我已经看过 DNNClassifier API,但我想将 estimator API 用于卷积网络和其他网络.我需要为任何估算器创建摘要.

推荐答案

预期用例是让 Estimator 为您保存摘要.RunConfig 中有选项 用于配置摘要写作.RunConfigs 在 构建时通过估算器.

The intended use case is that you let the Estimator save summaries for you. There are options in RunConfig for configuring summary writing. RunConfigs get passed when constructing the Estimator.

这篇关于Tensorflow Estimator API:总结的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆