Parseval's theorem in torchaudio

1 minute read

Published:

In this post, we'll numerically verify Parseval's theorem on a real audio signal using torchaudio.

import torchaudio
import torch
# Download a sample wav file, load it, and print its metadata and shape
asset_key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
path_wav = torchaudio.utils.download_asset(asset_key)

waveform, sample_rate = torchaudio.load(path_wav)

print(torchaudio.info(path_wav))
print(waveform.shape)
AudioMetaData(sample_rate=16000, num_frames=54400, num_channels=1, bits_per_sample=16, encoding=PCM_S)
torch.Size([1, 54400])
class Unfold:
    """Callable that slices the last dimension of a tensor into sliding frames.

    Args:
        window_size: Number of elements in each frame.
        hop_size: Step between the starts of consecutive frames.
    """

    def __init__(self, window_size, hop_size):
        self.hop_size = hop_size
        self.window_size = window_size

    def __call__(self, x):
        """Return the frames of ``x`` taken along its last dimension.

        The frames are stacked in a new trailing dimension of length
        ``window_size`` (see ``torch.Tensor.unfold``).
        """
        size, step = self.window_size, self.hop_size
        return x.unfold(-1, size, step)


class Hann:
    """Callable that multiplies its input by a periodic Hann window.

    Args:
        window_size: Length of the Hann window that is built once at
            construction time and stored in ``self.window``.
    """

    def __init__(self, window_size):
        self.window = torch.hann_window(window_size, periodic=True)

    def __call__(self, x):
        """Return ``x`` scaled element-wise by the stored window (broadcast)."""
        taper = self.window
        return x * taper
    
class RFFT:
    """One-sided DFT of real-valued frames, computed as two matrix products.

    The sine and cosine bases for a fixed ``n_fft`` are built once in the
    constructor. Calling the instance projects the last dimension of the
    input (which must have length ``n_fft``) onto ``n_fft // 2 + 1``
    frequency bins, with the same sign convention as ``torch.fft.rfft``::

        X[f] = sum_t x[t] * (cos(2*pi*t*f/n_fft) - 1j * sin(2*pi*t*f/n_fft))

    Args:
        n_fft: Transform length, i.e. the size of the last input dimension.
    """

    def __init__(self, n_fft):
        wsin, wcos = self.get_rfft(n_fft)
        self.wsin = wsin
        self.wcos = wcos

    @staticmethod
    def get_rfft(n_fft: int):
        """Build the sine and cosine DFT bases.

        Args:
            n_fft: Transform length.

        Returns:
            Tuple ``(wsin, wcos)``, each of shape ``(n_fft, n_fft // 2 + 1)``,
            holding ``sin`` and ``cos`` of ``2*pi*t*f/n_fft``.
        """
        # The number of FFT bins for a real signal
        freq_bins = n_fft // 2 + 1

        # Integer time (column) and frequency (row) indices; broadcasting
        # their product yields the (n_fft, freq_bins) grid of t * f.
        vtime = torch.arange(n_fft).reshape(-1, 1)
        vfreq = torch.arange(freq_bins).reshape(1, -1)

        # sin/cos are periodic in t * f with period n_fft, so reduce the
        # product exactly in integer arithmetic first. This keeps the angle
        # inside [0, 2*pi) instead of evaluating sin/cos at large float32
        # arguments, where precision is lost.
        phase_index = (vtime * vfreq) % n_fft
        arg = 2.0 * torch.pi * phase_index / n_fft  # (n_fft, freq_bins)

        wsin = torch.sin(arg)
        wcos = torch.cos(arg)
        return wsin, wcos

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the one-sided complex spectrum of ``x``.

        Args:
            x: Real tensor whose last dimension has length ``n_fft``.

        Returns:
            Complex tensor with the last dimension replaced by
            ``n_fft // 2 + 1`` frequency bins.
        """
        # Match the input's dtype and device so the matmul is valid wherever
        # the input lives.
        wcos = self.wcos.to(device=x.device, dtype=x.dtype)
        wsin = self.wsin.to(device=x.device, dtype=x.dtype)

        # exp(-1j * a) = cos(a) - 1j * sin(a): the cosine projection is the
        # real part and the negated sine projection is the imaginary part.
        real = x @ wcos
        imag = -(x @ wsin)
        return torch.view_as_complex(torch.stack([real, imag], -1))
    
class Compose:
    """Chain callables so that each one receives the previous one's output.

    Args:
        *transforms: Callables applied in the order given.
    """

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, x):
        """Feed ``x`` through every transform in order and return the result."""
        result = x
        for transform in self.transforms:
            result = transform(result)
        return result
# Framing and transform parameters
window_size = 512
hop_size = 256
n_fft = 512

# Stage objects built from the parameters above
unfold = Unfold(window_size, hop_size)
window = Hann(window_size)
rfft = RFFT(n_fft)


def sum_pow2(x):
    """Return the sum of squared elements of ``x`` over its last axis."""
    return x.pow(2.).sum(-1)


def abs_pow2(x):
    """Return the element-wise squared magnitude of a complex tensor."""
    return x.real ** 2 + x.imag ** 2

def add_negatives(x):
    """Double every bin except the first and last along the last axis.

    A one-sided spectrum stores each interior bin once although it also
    occurs at the mirrored negative frequency; doubling those bins accounts
    for the omitted half. The first and last bins are left as they are,
    which is correct when the last bin is the Nyquist bin, i.e. for an even
    transform length.

    Args:
        x: Tensor of per-bin values, with bins along the last axis.

    Returns:
        A new tensor with the interior bins doubled. ``x`` itself is not
        modified.
    """
    # Work on a copy so the caller's tensor is not mutated as a side effect.
    out = x.clone()
    out[..., 1:-1].mul_(2)
    return out

def mean_sum(x):
    """Sum ``x`` over its last axis and divide by the module-level ``n_fft``."""
    total = x.sum(-1)
    return total / n_fft

Parseval’s theorem:

\[\sum_{t=0}^{n-1} |x_t|^2 = \frac{1}{n} \sum_{f=0}^{n-1} |X_f|^2\]
# Both pipelines start with the same framing and windowing stages
framing = (unfold, window)

# Left-hand side: sum of squared windowed samples in each frame
feature_time = Compose(*framing, sum_pow2)

# Right-hand side: squared spectrum magnitudes, summed and divided by n_fft
feature_freq = Compose(*framing, rfft, abs_pow2, add_negatives, mean_sum)

left_side = feature_time(waveform)
right_side = feature_freq(waveform)

torch.all(torch.isclose(left_side, right_side))
tensor(True)