How to access Tensor shape within .map function?

Problem description

I have a dataset of audio clips of varying lengths, and I want to crop all of them to 5-second windows (240000 samples at a sample rate of 48000). After loading the .tfrecord, I decode each clip with:

audio, sr = tf.audio.decode_wav(image_data)

This returns a Tensor whose length is the length of the audio. If that length is less than 240000, I would like to repeat the audio content until it reaches 240000. So, inside a tf.data.Dataset.map() function, I currently do this on ALL clips:

audio = tf.tile(audio, [5])

since five repetitions is what it takes to pad my shortest audio to the desired length.
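As a quick check of the arithmetic (an illustrative sketch; the helper name repetitions_needed and the sample lengths are made up for this note), a clip of n samples needs ceil(240000 / n) copies to reach the target, so a fixed factor of 5 is enough only when every clip has at least 48000 samples:

```python
TARGET_LEN = 240_000  # 5 seconds at 48,000 samples per second


def repetitions_needed(n_samples: int) -> int:
    """Smallest number of copies of a clip that reaches TARGET_LEN samples."""
    # Integer ceiling division, avoiding floating-point rounding.
    return -(-TARGET_LEN // n_samples)


for n in (48_000, 100_000, 123_456, 300_000):
    reps = repetitions_needed(n)
    print(f"{n} samples -> {reps} repetition(s), {n * reps} samples before trimming")
```

Clips that are already long enough need a single copy and only have to be trimmed.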

But for efficiency I want to apply the operation only to the elements that need it:

if audio.shape[0] < 240000:
  pad_num = tf.math.ceil(240000 / audio.shape[0]) #i.e. if the audio is 120000 long, the audio will repeat 2 times
  audio = tf.tile(audio, [pad_num])

But I can't access the shape property, since it is dynamic and varies across the audio clips. I've tried tf.shape(audio), audio.shape, and audio.get_shape(), but I get values like None for the shape, which doesn't allow me to do the comparison.

Is it possible to do this?
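For context, audio.shape and audio.get_shape() report the static shape, which is fixed when map traces the function and holds None for any dimension that varies between elements, whereas tf.shape(audio) returns a tensor that is evaluated for each element at run time. The following minimal sketch illustrates the difference; the function name describe and the toy data are made up for illustration, and it assumes TensorFlow 2.4 or later for the output_signature argument:

```python
import numpy as np
import tensorflow as tf

static_lengths_seen = []  # filled while map traces the function


def describe(x):
    # Static shape: known at trace time; None for a varying dimension.
    static_lengths_seen.append(x.shape[0])
    # Dynamic shape: a scalar int32 tensor computed per element at run time.
    return tf.shape(x)[0]


ds = tf.data.Dataset.from_generator(
    lambda: iter([np.zeros(3, np.float32), np.zeros(7, np.float32)]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32),
)
lengths = [int(n) for n in ds.map(describe)]
print("static lengths seen while tracing:", static_lengths_seen)
print("dynamic lengths per element:", lengths)
```

Because the dynamic length is a tensor, it has to be compared and combined with TensorFlow operations rather than with a plain Python if on audio.shape[0].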

Recommended answer

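You can use a function like this. It reads the length with tf.shape(audio)[0], which is evaluated at run time for each element, whereas audio.shape[0] is the static shape and is None when the length varies: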

import tensorflow as tf

def enforce_length(audio):
    # Target shape
    AUDIO_LEN = 240_000
    # Current shape
    current_len = tf.shape(audio)[0]
    # Compute number of necessary repetitions
    num_reps = AUDIO_LEN // current_len
    num_reps += tf.dtypes.cast((AUDIO_LEN % current_len) > 0, num_reps.dtype)
    # Do repetitions
    audio_rep = tf.tile(audio, [num_reps])
    # Trim to required size
    return audio_rep[:AUDIO_LEN]

# Test
examples = tf.data.Dataset.from_generator(lambda: iter([
    tf.zeros([100_000], tf.float32),
    tf.zeros([300_000], tf.float32),
    tf.zeros([123_456], tf.float32),
]), output_types=tf.float32, output_shapes=[None])
result = examples.map(enforce_length)
for item in result:
    print(item.shape)

Output:

(240000,)
(240000,)
(240000,)
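If the goal is to skip the tiling entirely for clips that are already long enough, one possible variant wraps it in tf.cond. This is a sketch under stated assumptions rather than part of the answer above: the name tile_if_short and the target_len parameter are hypothetical, the audio is assumed to be a non-empty 1-D tensor, and whether it is measurably faster than always tiling has not been benchmarked here.

```python
import numpy as np
import tensorflow as tf

AUDIO_LEN = 240_000  # target length in samples


def tile_if_short(audio, target_len=AUDIO_LEN):
    # Dynamic length of this element (assumed to be greater than zero).
    current_len = tf.shape(audio)[0]

    def tile():
        # Ceiling division: smallest repeat count that reaches target_len.
        num_reps = (target_len + current_len - 1) // current_len
        return tf.tile(audio, [num_reps])

    # tf.cond selects a branch from a tensor predicate, so the tiling
    # is only performed for clips shorter than target_len.
    padded = tf.cond(current_len < target_len, tile, lambda: audio)
    # Trim to the required size.
    return padded[:target_len]


examples = tf.data.Dataset.from_generator(
    lambda: iter([np.zeros(100_000, np.float32), np.zeros(300_000, np.float32)]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32),
)
shapes = [tuple(item.shape) for item in examples.map(tile_if_short)]
print(shapes)
```

The function in the answer already computes a repetition count of 1 for long clips, so this variant changes only whether tf.tile is called at all for them.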
