如何在 Tensorflow 中复制 PyTorch 的 nn.functional.unfold 函数? [英] How to replicate PyTorch's nn.functional.unfold function in Tensorflow?

查看:95
本文介绍了如何在 Tensorflow 中复制 PyTorch 的 nn.functional.unfold 函数?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想用tensorflow重写pytorch的torch.nn.functional.unfold函数:

I want to use tensorflow to rewrite the pytorch's torch.nn.functional.unfold function:

#input x:[16, 1, 50, 36]
x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)
#output x:[16, 180, 16]

我尝试使用函数tf.extract_image_patches():

x = tf.extract_image_patches(x,ksizes=[1, 1,5, 98],strides=[1, 1, 3, 1], ratings=[1, 1, 1, 1],padding='VALID')

输入x.shape:[16,1,64,98]

我得到输出 x.shape:[16,1,20,490]

然后我将 X 重塑为 [16,490,20],这是我所期望的.

Then I reshape the X to [16,490,20], that was I expect.

但是当我提供数据时出现错误:

But I get the error when I feed the data:

UnimplementedError (see above for traceback): Only support ksizes across space.
[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]

如何使用 tensorflow 重写 pytorch torch.nn.functional.unfold 函数来更改 X?

How could I use tensorflow to rewrite pytorch torch.nn.functional.unfold function to change the X?

推荐答案

x = tf.reshape(x, [16, 50, 36, 1])
x = tf.extract_image_patches(x, ksizes=[1, 4, 98, 1], strides=[1, 4, 1, 1], rates=[1, 1, 1, 1], padding='VALID')

这篇关于如何在 Tensorflow 中复制 PyTorch 的 nn.functional.unfold 函数?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆