Writing and Registering a Custom Tensorflow Op in Python


Question

I want to write a custom Tensorflow op in Python and register it in the Protobuf registry for operations, as explained here. The Protobuf registration is key: I will not be using this op directly from Python, but if it is registered like a C++ op and loaded into the Python runtime environment, then I can run it in my environment.

I would like the code to look something like this:

import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list

""" Missing the Python equivalent of,                                                                                                                                                                        

  class HDF5QueueOp : public ResourceOpKernel<QueueInterface> {
  public:
      // Implementation
  };

  REGISTER_OP("HDF5Queue")
  .Output("handle: resource")
  .Attr("filename: string")
  .Attr("datasets: list(string)")
  .Attr("overwrite: bool = false")
  .Attr("component_types: list(type) >= 0 = []")
  .Attr("shapes: list(shape) >= 0 = []")
  .Attr("shared_name: string = ''")
  .Attr("container: string = ''")
  .Attr("capacity: int = -1")
  .SetIsStateful()
  .SetShapeFn(TwoElementOutput);

"""

class HDF5Queue(QueueBase):
  def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100,
               shapes=None, names=None, name="hdf5_queue"):
    if not dtypes:
      dtypes = [tf.int64, tf.float32]

    if not shapes:
      shapes = [[1], [1]]

    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id,
                                     stream_columns=stream_columns, capacity=capacity,
                                     component_types=dtypes, shapes=shapes,
                                     name=name, container=None, shared_name=None)
    super(HDF5Queue, self).__init__(dtypes, shapes,
                                    names, queue_ref)

The above is pretty standard for TF. It can be seen, for example, with FIFOQueue, which has a Python wrapper, a Protobuf registration, and a C++ implementation. There is a Python wrapper generated during compilation that I can't link to, but you can see where it's used by running grep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null

The code below dumps the Protobuf message for the TF Graph in JSON format. I would expect the dump to include a block for the HDF5Queue operation, as it does when I write C++ operations.

with tf.Session() as sess:
    queue = HDF5Queue(stream_id=0xa)
    write = queue.enqueue([[1], [1.2]])
    read  = queue.dequeue()
    print(json_format.MessageToJson(tf.train.export_meta_graph()))
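
One way to check whether the dump contains a block for a given op is to search the JSON by op name. The sketch below assumes the layout that the answer's code also relies on (a top-level graphDef holding a list of node entries, each with a name and an op). The sample dictionary and the helper find_nodes_by_op are hand-written for illustration and are not TensorFlow output.

```python
import json


def find_nodes_by_op(meta_graph_json, op_name):
    """Return the node dicts whose 'op' field equals op_name."""
    nodes = meta_graph_json.get("graphDef", {}).get("node", [])
    return [node for node in nodes if node.get("op") == op_name]


# Hand-written stand-in for json.loads(json_format.MessageToJson(...)).
sample = json.loads("""
{
  "graphDef": {
    "node": [
      {"name": "x", "op": "Placeholder"},
      {"name": "hdf5_queue", "op": "HDF5Queue"}
    ]
  }
}
""")

matches = find_nodes_by_op(sample, "HDF5Queue")
print([node["name"] for node in matches])
```

Searching by op name avoids depending on the position of a node in the list, which can change as the graph grows.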

Recommended answer

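This can sort of be done using py_func. The graph stores only a token that refers to a registered Python function, so the function can be swapped by editing that token in the exported graph. Here is an example.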

import base64
import json

import numpy
import tensorflow as tf
from google.protobuf import json_format
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
# Private TF 1.x registry that maps py_func tokens to Python callables.
from tensorflow.python.ops.script_ops import _py_funcs as py_func_registry

graph = tf.Graph()
graph2 = tf.Graph()

def f(x):
    return x

def g(x):
    return 2*x

with graph.as_default():
    x = tf.placeholder(tf.float32, shape=(3,), name='x')
    y = tf.py_func(f, [x], tf.float32, name='y')

    # py_func_registry._funcs.clear() # Optional line to clear the Python function registry
    # Export the current default graph as a JSON-compatible dict.
    msg = json.loads(json_format.MessageToJson(tf.train.export_meta_graph()))

# Change the function being used by py_func: register g and store its token.
# Here node 1 is the py_func node 'y' (node 0 is the placeholder 'x').
# Bytes fields are base64 strings in protobuf JSON, so the token is encoded.
token = py_func_registry.insert(g)
msg['graphDef']['node'][1]['attr']['token']['s'] = base64.b64encode(token.encode('utf-8')).decode('ascii')

with graph2.as_default():
    # Load graph
    meta_graph_def = MetaGraphDef()
    json_format.Parse(json.dumps(msg), meta_graph_def)
    tf.train.import_meta_graph(meta_graph_def)

    sess = tf.Session(graph=graph2)
    # The two results should agree numerically, showing that 'y' now calls g.
    print(sess.run('y:0', feed_dict={'x:0': numpy.array([1, 2, 3])}))
    print(g(numpy.array([1, 2, 3])))
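
The example works because the py_func node does not store the Python function itself, only a token that is looked up in a process-local registry when the op runs. The sketch below is a simplified, pure-Python model of that idea and not TensorFlow's actual implementation: FuncRegistry, encode_token, decode_token, swap_token, run_node, and the sample node are all hypothetical names written for illustration. It also shows why the token is base64-encoded before being written into the JSON, since the protobuf JSON mapping represents bytes fields as base64 strings.

```python
import base64


class FuncRegistry(object):
    """Toy registry mapping string tokens to Python callables (illustrative only)."""

    def __init__(self):
        self._funcs = {}
        self._next_id = 0

    def insert(self, func):
        # Hand out a fresh token and remember which callable it refers to.
        token = "pyfunc_%d" % self._next_id
        self._next_id += 1
        self._funcs[token] = func
        return token

    def call(self, token, *args):
        return self._funcs[token](*args)


def encode_token(token):
    """Encode a token the way a bytes field appears in protobuf JSON."""
    return base64.b64encode(token.encode("utf-8")).decode("ascii")


def decode_token(encoded):
    return base64.b64decode(encoded).decode("utf-8")


def swap_token(node, new_token):
    """Point a serialized node at a different registered function."""
    node["attr"]["token"]["s"] = encode_token(new_token)


def run_node(registry, node, *args):
    """Resolve the node's token in the registry and call the function."""
    return registry.call(decode_token(node["attr"]["token"]["s"]), *args)


registry = FuncRegistry()
token_f = registry.insert(lambda x: x)
token_g = registry.insert(lambda x: 2 * x)

# Hypothetical serialized node, shaped like an entry in graphDef.node.
node = {"name": "y", "op": "PyFunc",
        "attr": {"token": {"s": encode_token(token_f)}}}

before = run_node(registry, node, 3)   # resolves to the identity function
swap_token(node, token_g)
after = run_node(registry, node, 3)    # now resolves to the doubling function
print(before, after)
```

A consequence of this design is that a token only means something in the process that registered the function. The tf.py_func documentation notes that the function body is not serialized in the GraphDef, so a graph edited this way is not portable to another environment.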
