keras.backend.function()的用途是什么 [英] What's the purpose of keras.backend.function()

查看:1666
本文介绍了keras.backend.function()的用途是什么的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

Keras手册并没有说太多:

keras.backend.function(inputs, outputs, updates=None)

Instantiates a Keras function.
Arguments
inputs: List of placeholder tensors.
outputs: List of output tensors.
updates: List of update ops.
**kwargs: Passed to tf.Session.run.
Returns

Tensorflow源代码实际上很短,它表明K.function(...)返回一个Function对象,该对象在被调用时使用以下方法评估 outputs updates 输入.有趣的部分是它如何处理我不关注的更新.任何解释/示例/指针,以帮助理解此K.function(...)表示赞赏!这是来自的相关部分Tensorflow源代码

Tensorflow source code, which is actually quite short, shows that K.function(...) return a Function object which, when called, evaluates the outputs and updates using the inputs. The interesting part is how it handles the updates which I don't follow. Any explanations/examples/pointers to help understanding this K.function(...) is appreciated! Here is the relevant part from Tensorflow source code

class Function(object):
  """Runs a computation graph.
  Arguments:
      inputs: Feed placeholders to the computation graph.
      outputs: Output tensors to fetch.
      updates: Additional update ops to be run at function call.
      name: a name to help users identify what this function does.
  """

  def __init__(self, inputs, outputs, updates=None, name=None,
               **session_kwargs):
    updates = updates or []
    if not isinstance(inputs, (list, tuple)):
      raise TypeError('`inputs` to a TensorFlow backend function '
                      'should be a list or tuple.')
    if not isinstance(outputs, (list, tuple)):
      raise TypeError('`outputs` of a TensorFlow backend function '
                      'should be a list or tuple.')
    if not isinstance(updates, (list, tuple)):
      raise TypeError('`updates` in a TensorFlow backend function '
                      'should be a list or tuple.')
    self.inputs = list(inputs)
    self.outputs = list(outputs)
    with ops.control_dependencies(self.outputs):
      updates_ops = []
      for update in updates:
        if isinstance(update, tuple):
          p, new_p = update
          updates_ops.append(state_ops.assign(p, new_p))
        else:
          # assumed already an op
          updates_ops.append(update)
      self.updates_op = control_flow_ops.group(*updates_ops)
    self.name = name
    self.session_kwargs = session_kwargs

  def __call__(self, inputs):
    if not isinstance(inputs, (list, tuple)):
      raise TypeError('`inputs` should be a list or tuple.')
    feed_dict = {}
    for tensor, value in zip(self.inputs, inputs):
      if is_sparse(tensor):
        sparse_coo = value.tocoo()
        indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                                  np.expand_dims(sparse_coo.col, 1)), 1)
        value = (indices, sparse_coo.data, sparse_coo.shape)
      feed_dict[tensor] = value
    session = get_session()
    updated = session.run(
        self.outputs + [self.updates_op],
        feed_dict=feed_dict,
        **self.session_kwargs)
    return updated[:len(self.outputs)]


def function(inputs, outputs, updates=None, **kwargs):
  """Instantiates a Keras function.
  Arguments:
      inputs: List of placeholder tensors.
      outputs: List of output tensors.
      updates: List of update ops.
      **kwargs: Passed to `tf.Session.run`.
  Returns:
      Output values as Numpy arrays.
  Raises:
      ValueError: if invalid kwargs are passed in.
  """
  if kwargs:
    for key in kwargs:
      if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
          key not in tf_inspect.getargspec(Function.__init__)[0]):
        msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
               'backend') % key
        raise ValueError(msg)
  return Function(inputs, outputs, updates=updates, **kwargs)

推荐答案

我对此功能keras.backend.function有以下理解.我将借助.

I have the following understanding of this function keras.backend.function. I will explain it with the help of a code snippet from this.

代码段的部分如下

final_conv_layer = get_output_layer(model, "conv5_3")
get_output = K.function([model.layers[0].input], [final_conv_layer.output, model.layers[-1].output])
[conv_outputs, predictions] = get_output([img])

在此代码中,有一个模型可以从中提取conv5_3层(第1行).在函数K.function()中,第一个参数输入到该模型,第二个参数设置为2个输出的集合-一个用于卷积,第二个用于最后一层的softmax输出.

In this code, there is a model from which conv5_3 layer is extracted (line 1). In the function K.function(), the first argument is input to this model and second is set of 2 outputs - one for convolution and second for softmax output at the last layer.

根据Keras/Tensorflow手册,此函数运行我们在代码中创建的计算图,从第一个参数获取输入,并根据第二个参数提及的层提取输出数量.因此, conv_outputs final_conv_layer 的输出,而预测model.layers[-1]的输出,即模型的最后一层.

As per the Keras/Tensorflow manual, this function runs the computation graph that we have created in the code, taking input from the first parameter and extracting the number of outputs as per the layers mentioned in the second parameter. Thus, conv_outputs are output of final_conv_layer and predictions are output of model.layers[-1], i.e. the last layer of the model.

这篇关于keras.backend.function()的用途是什么的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆