Parseval’s theorem in torchaudio
Published:
In this post, we’ll verify Parseval’s theorem numerically with torchaudio.
import torchaudio
import torch
# Fetch a sample speech recording from the torchaudio tutorial assets and decode it.
asset_name = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
path_wav = torchaudio.utils.download_asset(asset_name)
waveform, sample_rate = torchaudio.load(path_wav)
# Report the file metadata, then the (channels, frames) shape of the decoded tensor.
print(torchaudio.info(path_wav))
print(waveform.shape)
AudioMetaData(sample_rate=16000, num_frames=54400, num_channels=1, bits_per_sample=16, encoding=PCM_S)
torch.Size([1, 54400])
class Unfold:
    """Slice the last dimension of a tensor into overlapping frames.

    An input of shape ``(..., T)`` becomes ``(..., num_frames, window_size)``,
    with consecutive frames starting ``hop_size`` samples apart (via
    ``Tensor.unfold``); trailing samples that do not fill a whole frame are
    not included.
    """

    def __init__(self, window_size, hop_size):
        # Frame length and frame advance, both in samples.
        self.window_size, self.hop_size = window_size, hop_size

    def __call__(self, x):
        return x.unfold(dimension=-1, size=self.window_size, step=self.hop_size)
class Hann:
    """Multiply frames by a periodic Hann window of length ``window_size``."""

    def __init__(self, window_size):
        # Created once on the default device with the default dtype.
        self.window = torch.hann_window(window_length=window_size, periodic=True)

    def __call__(self, x):
        """Apply the window along the last dim of ``x``.

        The last dim of ``x`` must have ``window_size`` elements; leading dims
        broadcast. The window is moved to ``x``'s device so that inputs living
        on an accelerator work too; dtype promotion is left to PyTorch (e.g. a
        float64 ``x`` yields a float64 result).
        """
        return x * self.window.to(x.device)
class RFFT:
    """One-sided DFT of real frames, implemented as two matrix products.

    For an input whose last dimension has exactly ``n_fft`` samples this
    matches ``torch.fft.rfft``::

        X[f] = sum_t x[t] * exp(-2j * pi * t * f / n_fft),  f = 0 .. n_fft // 2

    i.e. ``Re X = x @ cos`` and ``Im X = -(x @ sin)``.
    """

    def __init__(self, n_fft):
        wsin, wcos = self.get_rfft(n_fft)
        self.wsin = wsin  # (n_fft, n_fft // 2 + 1)
        self.wcos = wcos  # (n_fft, n_fft // 2 + 1)

    @staticmethod
    def get_rfft(n_fft: int):
        """Return the DFT basis matrices ``(wsin, wcos)``.

        Each has shape ``(n_fft, n_fft // 2 + 1)``; entry ``[t, f]`` is the
        sin / cos of ``2 * pi * t * f / n_fft``. Matrices use the default dtype.

        Raises:
            ValueError: if ``n_fft`` is smaller than 1.
        """
        if n_fft < 1:
            raise ValueError(f"n_fft must be a positive integer, got {n_fft}")
        # The number of FFT bins for a real signal
        freq_bins = n_fft // 2 + 1
        # Exact integer phase index t*f reduced modulo n_fft: keeps the angle in
        # [0, 2*pi) instead of accumulating float32 rounding on large products.
        vtime = torch.arange(n_fft).unsqueeze(1)  # (n_fft, 1)
        vfreq = torch.arange(freq_bins).unsqueeze(0)  # (1, freq_bins)
        phase = (vtime * vfreq) % n_fft  # (n_fft, freq_bins)
        arg = (2.0 * torch.pi / n_fft) * phase.to(torch.float64)
        dtype = torch.get_default_dtype()
        return torch.sin(arg).to(dtype), torch.cos(arg).to(dtype)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the complex one-sided spectrum of ``x`` along its last dim.

        ``x`` must be a real floating tensor whose last dim has ``n_fft``
        elements; the output replaces it with ``n_fft // 2 + 1`` bins.
        """
        wcos = self.wcos.to(dtype=x.dtype, device=x.device)
        wsin = self.wsin.to(dtype=x.dtype, device=x.device)
        # exp(-i*theta) = cos(theta) - i*sin(theta)
        real = x @ wcos
        imag = -(x @ wsin)
        return torch.view_as_complex(torch.stack([real, imag], -1))
class Compose:
    """Chain callables so that ``Compose(f, g, h)(x) == h(g(f(x)))``."""

    def __init__(self, *transforms):
        # Stored as the positional-argument tuple; applied left to right.
        self.transforms = transforms

    def __call__(self, x):
        out = x
        for step in self.transforms:
            out = step(out)
        return out
# params
window_size = 512  # samples per analysis frame
hop_size = 256  # frame advance in samples (50% overlap)
n_fft = 512  # DFT length; must equal window_size for RFFT's matrix product
unfold = Unfold(window_size, hop_size)
window = Hann(window_size)


def sum_pow2(x):
    """Return the energy of each frame: the sum of squared samples over the last dim."""
    return x.pow(2.).sum(-1)


rfft = RFFT(n_fft)
def abs_pow2(x):
    """Return the squared magnitude |x|^2 of a complex tensor as re^2 + im^2."""
    return x.real ** 2 + x.imag ** 2


def add_negatives(x, n=None):
    """Fold in the energy of the negative-frequency bins a one-sided spectrum omits.

    For a real signal of length ``n``, bins ``f`` and ``n - f`` have equal
    magnitude, so every bin except DC (and, when ``n`` is even, Nyquist) is
    counted twice. Returns a new tensor; ``x`` is not modified.

    Args:
        x: one-sided power spectrum with ``n // 2 + 1`` bins in the last dim.
        n: DFT length; defaults to the module-level ``n_fft``. It is needed
            because lengths ``2k`` and ``2k + 1`` give the same bin count but
            only the even one has a Nyquist bin.
    """
    if n is None:
        n = n_fft
    # Even n: the last bin is Nyquist (its own mirror), so leave it single.
    stop = -1 if n % 2 == 0 else None
    out = x.clone()
    out[..., 1:stop] *= 2
    return out


def mean_sum(x, n=None):
    """Sum over the last dim divided by the DFT length (the 1/n factor of Parseval).

    ``n`` defaults to the module-level ``n_fft``.
    """
    if n is None:
        n = n_fft
    return x.sum(-1) / n
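Parseval’s theorem for the discrete Fourier transform \(X_f = \sum_{t=0}^{n-1} x_t e^{-2\pi i t f / n}\) of an \(n\)-sample frame states:
\[\sum_{t=0}^{n-1} |x_t|^2 = \frac{1}{n} \sum_{f=0}^{n-1} |X_f|^2\]
Because the signal is real, \(|X_{n-f}| = |X_f|\). The one-sided RFFT spectrum therefore only needs its interior bins counted twice. The DC bin is counted once, and for even \(n\) so is the Nyquist bin; `add_negatives` applies this weighting.
# functions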
# Left side of Parseval: per-frame energy of the windowed samples.
feature_time = Compose(
    unfold,
    window,
    sum_pow2,
)
# Right side: per-frame energy from the one-sided spectrum, with the mirrored
# negative-frequency bins folded back in and the sum scaled by 1/n_fft.
feature_freq = Compose(
    unfold,
    window,
    rfft,
    abs_pow2,
    add_negatives,
    mean_sum,
)
left_side = feature_time(waveform)  # one energy value per frame
right_side = feature_freq(waveform)
# Print instead of leaving a bare expression, so the result is shown when the
# file runs as a script, not only when a notebook echoes the last expression.
print(torch.isclose(left_side, right_side).all())
tensor(True)