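# NOTE: the lines below are web-page residue from the file viewer, kept for provenance only.
#   NeoPy's picture | EXP | 05aac64 verified | raw | history blame contribute delete | 936 Bytes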
import torch
import librosa
import torch.nn as nn
class RMSEnergyExtractor(nn.Module):
    """Compute the frame-wise RMS energy of a single waveform via librosa.

    The module has no learnable parameters. The computation is delegated to
    ``librosa.feature.rms`` on the CPU, so the result is never connected to
    the autograd graph of the input.

    Args:
        frame_length: Forwarded to ``librosa.feature.rms`` as ``frame_length``.
        hop_length: Forwarded to ``librosa.feature.rms`` as ``hop_length``.
        center: Forwarded to ``librosa.feature.rms`` as ``center``.
        pad_mode: Forwarded to ``librosa.feature.rms`` as ``pad_mode``.
    """

    def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode="reflect"):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the RMS energy of ``x``, placed on the same device as ``x``.

        Args:
            x: A 2-D tensor whose first dimension has size 1.

        Returns:
            The librosa RMS result with its second-to-last axis squeezed out,
            moved to ``x.device``. For a mono input librosa documents the
            result as ``(1, n_frames)``, so this is expected to be 1-D.

        Raises:
            AssertionError: If ``x`` is not 2-D or its first dimension is not 1.
        """
        assert x.ndim == 2, f"expected a 2-D tensor, got {x.ndim} dimension(s)"
        assert x.shape[0] == 1, f"expected first dimension of size 1, got {x.shape[0]}"

        device = x.device
        # Tensors on an "ocl" device are made contiguous before the host copy,
        # as the original code did (presumably a backend workaround).
        if str(device).startswith("ocl"):
            x = x.contiguous()

        # detach() first: Tensor.numpy() refuses tensors that require grad, and
        # the value leaves autograd anyway once it passes through NumPy.
        samples = x.detach().squeeze(0).cpu().numpy()

        frames = librosa.feature.rms(
            y=samples,
            frame_length=self.frame_length,
            hop_length=self.hop_length,
            center=self.center,
            pad_mode=self.pad_mode,
        )

        # contiguous() is a no-op on an already-contiguous tensor, so a single
        # code path serves both "ocl" and every other device.
        return torch.from_numpy(frames).contiguous().squeeze(-2).to(device)