How to slice a batch and apply an operation on each slice in TensorFlow


Problem description


I am a beginner with TensorFlow, and I am trying to implement a function that takes a batch as input. It has to slice the batch into several pieces, apply some operations to them, then concatenate the results to build a new tensor to return. From my reading, I found there are some implemented functions, such as input_slice_producer and batch_join, but I couldn't get them to work. I attached what I found as a solution below, but it's slow, not proper, and incapable of detecting the current batch size. Does anyone know a better way of doing this?

def model(x):
    W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")
    x_size = x.get_shape().as_list()[0]
    # x is a batch of a bigger input of shape [None, 6], so I couldn't
    # get the proper size of the batch when feeding it
    if x_size is None:
        x_size = batch_size
    # initialize y_res with the result for the first slice
    dummy_x = tf.slice(x, [0, 0], [1, 6])
    y_res = tf.reduce_sum(tf.mul(dummy_x, W_1))
    # go through all remaining slices and concatenate them to get the result
    for i in range(1, x_size):
        dummy_x = tf.slice(x, [i, 0], [1, 6])
        result = tf.reduce_sum(tf.mul(dummy_x, W_1))
        y_res = tf.concat(0, [y_res, result])

    return y_res


Answer


The TensorFlow function tf.map_fn(fn, elems) allows you to apply a function (fn) to each slice of a tensor (elems). For example, you could express your program as follows:

def model(x):
    W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")

    def fn(x_slice):
        return tf.reduce_sum(tf.mul(x_slice, W_1))

    return tf.map_fn(fn, x)
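tf.map_fn slices elems along its first dimension, applies fn to each slice, and stacks the results back into a single tensor. Since the snippet above targets the old TensorFlow 0.x API, here is a NumPy sketch (with a made-up 4x6 batch) of that same slice-and-stack behavior:

```python
import numpy as np

# Hypothetical batch of shape [4, 6], mirroring the [None, 6] input in the question.
x = np.arange(24, dtype=np.float32).reshape(4, 6)

def fn(x_slice):
    # x_slice is one row of shape [6]; any per-slice op goes here.
    return np.sum(x_slice)

# tf.map_fn(fn, x) behaves like applying fn to each row and stacking the results:
y = np.stack([fn(row) for row in x])
# y has shape [4]: one result per slice, with no Python loop in the TF graph.
```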

也可能实现您的操作更简洁地在 tf上使用广播.mul() 运算符,它使用 NumPy广播语义,以及axis 参数/python/math_ops.html#reduce_sum rel = noreferrer> tf.reduce_sum()

It may also be possible to implement your operation more concisely using broadcasting on the tf.mul() operator, which uses NumPy broadcasting semantics, and the axis argument to tf.reduce_sum().
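As a sketch of that broadcasting idea (in NumPy, whose semantics tf.mul follows), assuming the intended per-slice operation is an elementwise product with the weights followed by a sum: the loop-free broadcast version and the per-slice loop agree, with made-up shapes matching the question.

```python
import numpy as np

# Hypothetical batch [4, 6] and weights [6, 1], as in the question.
x = np.arange(24, dtype=np.float32).reshape(4, 6)
W_1 = np.ones((6, 1), dtype=np.float32)

# Per-slice loop, as in the question: one scalar result per row.
looped = np.array([np.sum(row * W_1.ravel()) for row in x])

# Broadcast version: [4, 6] * [1, 6] -> [4, 6], then sum along axis 1.
broadcast = np.sum(x * W_1.T, axis=1)

assert np.allclose(looped, broadcast)  # same result, no Python loop
```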
