如何在队列中的Tensorflow中解码pfm文件? [英] how to decode pfm files in Tensorflow in a queue?

查看:211
本文介绍了如何在队列中的Tensorflow中解码pfm文件?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经将文件名放入队列中,并且这些文件是* .pfm文件.然后,我编写了一个convert函数readPFM()将* .pfm文件转换为ndarray.

I've made a queue of filenames, and the files are *.pfm file. And I write a convert functionreadPFM() to convert the *.pfm file into ndarray.

我想做的是,当文件从队列中出队时,我将使用该函数将其转换为numpy ndarray.然后将其输入到图形中.但是代码不起作用.

What I want to do is that when a file is dequeued from the queue, I'll use the function to convert it into numpy ndarray. Then it will be fed into the graph. But the code doesn't work.

def disparity(batch_size, path, LR, epochs=2):
    filenames = file_name(path, LR, 'pfm')
    filenames = sorted(filenames)

    filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=epochs)
    reader = tf.WholeFileReader()
    key, img_bytes = reader.read(filename_queue)
    disparity, _ = readPFM(img_bytes)

    return tf.train.batch([disparity], batch_size, dynamic_pad=True)

pfm文件读取功能在这里.

The pfm file read func is here.

def readPFM(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data, scale


def writePFM(file, image, scale=1):
    file = open(file, 'wb')

    color = None

    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    image = np.flipud(image)

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    file.write('PF\n' if color else 'Pf\n')
    file.write('%d %d\n' % (image.shape[1], image.shape[0]))

    endian = image.dtype.byteorder

    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    file.write('%f\n' % scale)

    image.tofile(file)

错误消息表明我的函数无法处理张量,因为它只能处理* .pfm文件.

The error message shows that my function cannot deal with the tensor, as it can only handle the *.pfm files.

有什么解决办法吗?

推荐答案

您不能像在tensorflow中那样使用readPFM函数,需要用tf.py_func包装它.

You can't use your readPFM function as is in tensorflow, you need to wrap it in with a tf.py_func.

# helper function
def decode_pfm(path):
    data, _ = load_pfm(open(path, 'rb'))

    # http://netpbm.sourceforge.net/doc/pfm.html
    # pfm stores the data bottom-to-top, need to reverse
    data = np.flipud(data)
    data = np.expand_dims(data, 2)
    return data

def read_and_decode(path):
    image_decoded = tf.py_func(decode_pfm, [path], tf.float32)

    # py_func does not set the shape, you might need to explictly
    # set it
    image_decoded.set_shape((H, W, channels))

这篇关于如何在队列中的Tensorflow中解码pfm文件?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆