# Copyright 2023 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch compatibility module
This submodule is intended to provide PyTorch implementations of the components critical
to feature computation. It is not meant to comprehensively reproduce all functionality
in PyTorch. Each PyTorch module here contains a class method which initializes the
PyTorch module with some analogous Numpy instance discussed elsewhere. For example,
assuming `stft_frame_computer` is an instance of a
:class:`pydrobert.speech.STFTFrameComputer`, one may instantiate a
:class:`PyTorchSTFTFrameComputer` via
>>> pytorch_stft_frame_computer = PyTorchSTFTFrameComputer.from_stft_frame_computer(
... stft_frame_computer)
"""
import math
from typing import Collection, List, Optional, Sequence, Tuple
import warnings
import torch
from . import config
from .pre import Dither, Preemphasize
from .post import PostProcessor
from .compute import STFTFrameComputer, SIFrameComputer
try:
from typing import Self, Literal
except ImportError:
from typing_extensions import Self, Literal
# Public names exported by a star-import of this module: the functional
# implementations, the ``torch.nn.Module`` classes, and the short aliases of the
# frame-computer classes.
__all__ = [
"pytorch_dither",
"pytorch_preemphasize",
"pytorch_stft_frame_computer",
"PyTorchDither",
"PyTorchPostProcessorWrapper",
"PyTorchPreemphasize",
"PyTorchShortIntegrationFrameComputer",
"PyTorchShortTimeFourierTransformFrameComputer",
"PyTorchSIFrameComputer",
"PyTorchSTFTFrameComputer",
]
def check_in(name: str, val: str, choices: Collection[str]):
    """Ensure a string argument is one of a fixed set of options

    Parameters
    ----------
    name
        Name of the argument, used only in the error message
    val
        The value to validate
    choices
        The accepted values

    Raises
    ------
    ValueError
        If `val` is not in `choices`. The message lists the choices in sorted order.
    """
    if val in choices:
        return
    listed = "', '".join(sorted(choices))
    raise ValueError(f"Expected {name} to be one of '{listed}'; got '{val}'")
def check_positive(name: str, val, nonnegative=False):
    """Ensure a numeric argument is positive (or, optionally, non-negative)

    Parameters
    ----------
    name
        Name of the argument, used only in the error message
    val
        The value to validate
    nonnegative
        If true, zero is accepted as well

    Raises
    ------
    ValueError
        If `val` is negative, or if it is zero while `nonnegative` is false
    """
    if nonnegative:
        invalid, kind = val < 0, "non-negative"
    else:
        invalid, kind = (val < 0 or val == 0), "positive"
    if invalid:
        raise ValueError(f"Expected {name} to be {kind}; got {val}")
[docs]
def pytorch_preemphasize(sig: torch.Tensor, coeff: float = 0.97) -> torch.Tensor:
    """Functional implementation of PyTorchPreemphasize

    Computes ``out[t] = sig[t] - coeff * sig[t - 1]``, treating the sample before the
    start of the signal as zero so that ``out[0] == sig[0]``.

    Parameters
    ----------
    sig
        The signal. It is concatenated with a length-one zero vector, so it is
        expected to be one-dimensional.
    coeff
        Preemphasis coefficient

    Returns
    -------
    out
        A tensor with the same length as `sig`
    """
    # torch.cat (rather than its newer alias torch.concatenate) matches the rest of
    # this module and is available in older PyTorch releases.
    padded = torch.cat([sig.new_zeros(1), sig])
    return padded[1:] - coeff * padded[:-1]
[docs]
class PyTorchPreemphasize(torch.nn.Module):
    """PyTorch implementation of Preemphasize

    A thin module wrapper which stores the coefficient and applies
    :func:`pytorch_preemphasize` to its input.

    Parameters
    ----------
    coeff
        Preemphasis coefficient
    """

    __constants__ = ("coeff",)

    coeff: float

    def __init__(self, coeff: float = 0.97) -> None:
        super().__init__()
        self.coeff = coeff

    @classmethod
    def from_preemphasize(cls, preemphasize: Preemphasize) -> Self:
        """Build the module using the coefficient of an existing Preemphasize"""
        coeff = preemphasize.coeff
        return cls(coeff)

    def forward(self, sig: torch.Tensor) -> torch.Tensor:
        """Return the preemphasized version of `sig`"""
        out = pytorch_preemphasize(sig, self.coeff)
        return out
[docs]
def pytorch_dither(sig: torch.Tensor, coeff: float = 1.0) -> torch.Tensor:
    """Functional implementation of PyTorchDither

    Parameters
    ----------
    sig
        The signal to add noise to
    coeff
        Scale applied to the standard-normal noise before it is added

    Returns
    -------
    out
        `sig` plus noise drawn with the same shape, dtype, and device as `sig`
    """
    noise = torch.randn_like(sig)
    return sig + coeff * noise
[docs]
class PyTorchDither(torch.nn.Module):
    """PyTorch implementation of Dither

    Add random, normally-distributed noise to every coefficient of a signal

    Parameters
    ----------
    coeff
        The standard deviation of the noise. Must be non-negative.

    Notes
    -----
    Noise is added on every call to :func:`forward`; the module's training flag is
    not consulted, so dithering is applied during evaluation as well as training.
    """

    __constants__ = ("coeff",)

    coeff: float

    def __init__(self, coeff: float = 1.0):
        # validate before the module is initialized so a bad value leaves nothing
        # half-constructed
        check_positive("coeff", coeff, nonnegative=True)
        super().__init__()
        self.coeff = coeff

    @classmethod
    def from_dither(cls, dither: Dither) -> Self:
        """Build the module using the coefficient of an existing Dither"""
        coeff = dither.coeff
        return cls(coeff)

    def forward(self, sig: torch.Tensor) -> torch.Tensor:
        """Return `sig` with freshly-sampled noise added"""
        out = pytorch_dither(sig, self.coeff)
        return out
@torch.jit.script_if_tracing
def pytorch_stft_frame_computer(
    sig: torch.Tensor,
    filters: List[torch.Tensor],
    offsets: List[int],
    frame_length: int,
    frame_shift: int,
    centered: bool = True,
    window: Optional[torch.Tensor] = None,
    dft_size: Optional[int] = None,
    use_log: bool = True,
    use_power: bool = False,
    include_energy: bool = False,
    kaldi_shift: bool = False,
    is_real: bool = True,
    eps: float = config.LOG_FLOOR_VALUE,
) -> torch.Tensor:
    """Functional implementation of PyTorchShortTimeFourierTransformFrameComputer

    Slices `sig` into overlapping frames, takes the one-sided DFT of each frame, and
    reduces the filtered spectrum to one coefficient per filter and frame.

    Parameters
    ----------
    sig
        One-dimensional signal tensor
    filters
        The truncated frequency response of each filter, one tensor per filter
    offsets
        For each filter, the index in the spectrum at which it begins. Must have the
        same length as `filters`.
    frame_length
        The number of samples in a frame
    frame_shift
        The number of samples between the starts of subsequent frames
    centered
        If false, the signal is not padded on the left. Otherwise it is padded on
        the left by an amount decided by `kaldi_shift`.
    window
        If specified, a tensor of shape ``(frame_length,)`` multiplied into each
        frame before the transform
    dft_size
        The size of the DFT. Must be at least `frame_length`. If unspecified, a power
        of two is computed from `frame_length`.
    use_log
        Whether to floor the coefficients at `eps` and take their logarithm
    use_power
        Whether to accumulate squared magnitudes instead of magnitudes
    include_energy
        Whether to insert a frame-wise energy coefficient at index 0. The energy is
        computed from the frames before windowing.
    kaldi_shift
        Selects the amount of left padding when `centered` is true
    is_real
        If true, each filter's contribution is doubled
    eps
        The floor applied before the logarithm when `use_log` is true

    Returns
    -------
    feats
        A tensor of shape ``(num_frames, len(filters) + int(include_energy))``

    Raises
    ------
    RuntimeError
        If `dft_size` is smaller than `frame_length`, `filters` and `offsets` differ
        in length, `sig` is not one-dimensional, or `window` has the wrong shape
    """
    if dft_size is None:
        dft_size_ = int(2 ** math.ceil(math.log(frame_length, 2)))
    elif dft_size < frame_length:
        raise RuntimeError(f"expected dft_size gte {frame_length}; got {dft_size}")
    else:
        dft_size_ = dft_size
    num_filts = len(filters)
    if num_filts != len(offsets):
        raise RuntimeError(
            f"filters ({num_filts}) has different length than offsets "
            f"({len(offsets)})"
        )
    if sig.ndim != 1:
        raise RuntimeError(f"Expected x to be 1-dimensional; got {sig.ndim}")
    if window is not None and window.shape != (frame_length,):
        raise RuntimeError(
            f"Expected window to have shape {(frame_length,)}; got {window.shape}"
        )
    sig_len = sig.size(0)
    if sig_len < frame_length // 2 + 1:
        # Too short to yield a frame. The empty result must have the same number of
        # columns as a non-empty one, which includes the energy column if requested.
        return sig.new_empty((0, num_filts + int(include_energy)))
    zero = sig.new_zeros(1)
    if not centered:
        pad_left = 0
    elif kaldi_shift:
        pad_left = frame_length // 2 - frame_shift // 2
    else:
        pad_left = (frame_length + 1) // 2 - 1
    num_frames = max(0, (sig_len + frame_shift // 2) // frame_shift)
    total_len = (num_frames - 1) * frame_shift - pad_left + frame_length
    pad_right = max(0, total_len - sig_len)
    if pad_left or pad_right:
        # symmetric padding
        sig = torch.cat(
            [sig[:pad_left].flip(0), sig, sig[sig_len - pad_right :].flip(0)]
        )
    # as_strided interprets its strides against the underlying storage, not against
    # the view's own strides. When no padding was needed, sig is still the caller's
    # tensor and may be a non-contiguous view, so make it contiguous first (a no-op
    # for the freshly concatenated tensor above).
    sig = sig.contiguous().as_strided((num_frames, frame_length), (frame_shift, 1))
    y: List[torch.Tensor] = []
    if include_energy:
        # root-mean-square of each (unwindowed) frame, squared if use_power
        energy = torch.linalg.norm(sig, 2, 1) / math.sqrt(frame_length)
        if use_power:
            energy = energy.square()
        y.append(energy)
    if window is not None:
        sig = sig * window
    spect = torch.fft.rfft(sig, dft_size_, 1, "backward")
    del sig
    half_len = spect.size(1)
    # NOTE(review): mod is the parity of half_len, not of dft_size_; confirm against
    # the reference implementation that this is intended for odd dft_size_.
    mod = half_len % 2
    for si, filt in zip(offsets, filters):
        val, consumed, conj, filt_len = zero, 0, False, len(filt)
        # A filter may run past the end of the one-sided spectrum. It is consumed in
        # segments, alternating between the spectrum as stored and a conjugated,
        # reversed slice of it, until every filter coefficient has been applied.
        while consumed < filt_len:
            if conj:
                seg_len = max(min(si + filt_len - consumed, half_len - 2 + mod) - si, 0)
                seg = spect[..., -2 + mod - si - seg_len : -2 + mod - si].conj().flip(1)
                si -= half_len - 2 + mod
            else:
                seg_len = max(0, min(si + filt_len - consumed, half_len) - si)
                seg = spect[..., si : si + seg_len]
                si -= half_len
            seg = seg * filt[consumed : consumed + seg_len]
            if use_power:
                val_f = torch.linalg.norm(seg, 2, 1).square()
            else:
                val_f = seg.abs().sum(1)
            if is_real:
                val_f = val_f * 2
            val = val + val_f
            conj = not conj
            consumed += seg_len
            si = max(0, si)
        y.append(val)
    y_ = torch.stack(y, 1)
    if use_log:
        y_ = y_.clamp_min(eps).log()
    return y_
[docs]
class PyTorchShortTimeFourierTransformFrameComputer(torch.nn.Module):
    """PyTorch implementation of STFTFrameComputer

    This module is a port of
    :class:`pydrobert.speech.compute.ShortTimeFourierTransformFrameComputer` to PyTorch.
    When called, the output should be nearly identical to a call to
    :func:`ShortTimeFourierTransformFrameComputer.compute_full`, except
    :class:`torch.Tensor` inputs and outputs are expected.

    The arguments below mirror the configuration of an already-initialized
    :class:`STFTFrameComputer`.

    The filters and window are registered as parameters and are therefore
    learnable/adjustable. Be sure to disable gradients with :func:`torch.no_grad` if
    a fixed feature representation is desirable.

    Parameters
    ----------
    offsets_and_truncated_filters
        A sequence of pairs ``(offset, truncated_filter)``. `truncated_filter` is a
        non-empty, one-dimensional tensor of the non-zero frequency response of a
        single filter in the bank. `offset` is the non-negative index in the
        short-time spectrum at which the `truncated_filter` begins.
    frame_length
        The (positive) number of audio samples constituting a frame.
    frame_shift
        The (positive) number of audio samples between subsequent frames.
    frame_style
        If ``'causal'``, the first frame begins at sample ``0``. Otherwise, the
        first frame is centered around sample ``0`` with the exact behaviour dictated
        by the `kaldi_shift` flag.
    window
        If specified, a tensor of shape ``(frame_length,)`` containing the windowing
        function. If unspecified, implicit rectangular windowing will be performed
        (with no gradient).
    dft_size
        The size of the spectrum to compute for each frame. Must be greater than
        or equal to `frame_length`. If unspecified, the first power of two at or beyond
        the frame length will be chosen.
    use_log
        Whether to take the logarithm of the resulting representation
    use_power
        Take the power spectrum instead of the magnitude spectrum
    include_energy
        Whether to add a coefficient at index 0 corresponding to the frame-wise energy
        of the signal
    kaldi_shift
        Dictates how to center frames when `frame_style` is :obj:`'centered'`. If
        :obj:`True`, the k-th frame will be computed using the signal between ``signal[
        k * frame_shift - frame_length // 2 + frame_shift // 2:k * frame_shift +
        (frame_length + 1) // 2 + frame_shift // 2]``. These are the frame bounds for
        Kaldi [povey2011]_. Otherwise, the k-th frame is ``signal[ k * frame_shift -
        (frame_length + 1) // 2 + 1: k * frame_shift + frame_length // 2 + 1]``.
    is_real
        Whether the filters are real in the time domain. If :obj:`True`, coefficients
        will be doubled (pre-log) to account for Hermitian symmetry.

    Raises
    ------
    ValueError
        If any filter is not a non-empty vector, any offset is negative,
        `frame_length` or `frame_shift` is not positive, `frame_style` is not
        recognized, `window` has the wrong shape, or `dft_size` is too small
    """

    __constants__ = (
        "centered",
        "dft_size",
        "frame_length",
        "frame_shift",
        "offsets",
        "include_energy",
        "use_log",
        "use_power",
    )

    centered: bool
    dft_size: int
    frame_length: int
    frame_shift: int
    offsets: Tuple[int, ...]
    include_energy: bool
    use_log: bool
    use_power: bool
    kaldi_shift: bool
    is_real: bool

    def __init__(
        self,
        offsets_and_truncated_filters: Sequence[Tuple[int, torch.Tensor]],
        frame_length: int,
        frame_shift: int,
        frame_style: Literal["centered", "causal"] = "centered",
        window: Optional[torch.Tensor] = None,
        dft_size: Optional[int] = None,
        use_log: bool = True,
        use_power: bool = False,
        include_energy: bool = False,
        kaldi_shift: bool = False,
        is_real: bool = False,
    ) -> None:
        # All validation happens before the module is initialized.
        start_indices = []
        responses = []
        for idx, (offset, filt) in enumerate(offsets_and_truncated_filters):
            if filt.ndim != 1:
                raise ValueError(f"filter {idx} is not a vector")
            if not filt.size(0):
                raise ValueError(f"filter {idx} is empty")
            check_positive(f"filter {idx} offset", offset, True)
            start_indices.append(offset)
            responses.append(filt)
        check_positive("frame_length", frame_length)
        check_positive("frame_shift", frame_shift)
        check_in("frame_style", frame_style, {"causal", "centered"})
        if window is not None and window.shape != (frame_length,):
            raise ValueError(
                f"Expected window.shape to be ({frame_length},); got {window.shape}"
            )
        if dft_size is None:
            dft_size = 2 ** math.ceil(math.log(frame_length, 2))
        elif dft_size < frame_length:
            raise ValueError(
                f"Expected dft_size to be gte {frame_length}; got {dft_size}"
            )

        super().__init__()

        # plain (constant) configuration
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.offsets = tuple(start_indices)
        self.centered = frame_style == "centered"
        self.dft_size = dft_size
        self.use_log = use_log
        self.use_power = use_power
        self.kaldi_shift = kaldi_shift
        self.is_real = is_real
        self.include_energy = include_energy

        # learnable state: the filters first, then the (optional) window
        self.filters = torch.nn.ParameterList(responses)
        if window is not None:
            self.window = torch.nn.Parameter(window)
        else:
            self.register_parameter("window", None)

    def forward(self, signal: torch.Tensor) -> torch.Tensor:
        """Compute the features of `signal` with the stored configuration"""
        filters = list(self.filters)
        return pytorch_stft_frame_computer(
            signal,
            filters,
            self.offsets,
            self.frame_length,
            self.frame_shift,
            self.centered,
            self.window,
            self.dft_size,
            self.use_log,
            self.use_power,
            self.include_energy,
            self.kaldi_shift,
            self.is_real,
        )


# short alias of the class above
PyTorchSTFTFrameComputer = PyTorchShortTimeFourierTransformFrameComputer
[docs]
class PyTorchPostProcessorWrapper(torch.nn.Module):
    """A PyTorch wrapper around a PostProcessor

    This module merely casts incoming tensors to a :class:`numpy.ndarray`, runs
    :func:`pydrobert.speech.post.PostProcessor.apply` on the result, then casts it
    back into a tensor with the device and dtype of the input.

    Most :class:`PostProcessor` classes have been reimplemented in
    :mod:`pydrobert.torch` with a bona fide PyTorch implementation, which should
    be preferred.

    Parameters
    ----------
    postprocessor
        The object whose ``apply`` method is called on the array

    Notes
    -----
    The computation happens outside of PyTorch, so the output is a new tensor which
    is not connected to the autograd graph of the input.
    """

    __constants__ = ("postprocessor",)

    postprocessor: PostProcessor

    def __init__(self, postprocessor: PostProcessor):
        super().__init__()
        self.postprocessor = postprocessor

    @classmethod
    def from_postprocessor(cls, postprocessor: PostProcessor) -> Self:
        """Build the module around an existing PostProcessor"""
        return cls(postprocessor)

    @torch.jit.unused
    def _postprocessor_appy(self, sig: torch.Tensor) -> torch.Tensor:
        if sig.device.type != "cpu":
            warnings.warn(
                "PyTorchPostProcessorWrapper being used on non-cpu tensor. Will "
                "send to cpu for computations, then back"
            )
        # Tensor.numpy() refuses tensors that require grad, which is the usual state
        # of features produced by a module with learnable parameters. Detach first.
        array = sig.detach().cpu().numpy()
        return torch.tensor(
            self.postprocessor.apply(array),
            device=sig.device,
            dtype=sig.dtype,
        )

    def forward(self, sig: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped post-processor to `sig`"""
        return self._postprocessor_appy(sig)
[docs]
class PyTorchShortIntegrationFrameComputer(torch.nn.Module):
    """PyTorch implementation of SIFrameComputer

    This module is a port of
    :class:`pydrobert.speech.compute.ShortIntegrationFrameComputer` to PyTorch. When
    called, the output should be nearly identical to a call to
    :func:`ShortIntegrationFrameComputer.compute_full`, except :class:`torch.Tensor`
    inputs and outputs are expected.

    Parameters
    ----------
    si_frame_computer
        The object whose ``compute_full`` method is called on the signal

    Warnings
    --------
    This module is currently a mere wrapper around a
    :class:`ShortIntegrationFrameComputer` instance. While we plan on reimplementing the
    computer with bona fide PyTorch operations at a later date, for now, relying on the
    factory function :func:`from_si_frame_computer` is the best way to ensure forward
    compatibility. In addition, the module state dict cannot be saved nor loaded to
    ensure forward compatibility.

    Notes
    -----
    The computation happens outside of PyTorch, so the output is a new tensor which
    is not connected to the autograd graph of the input.
    """

    si_frame_computer: SIFrameComputer

    def __init__(self, si_frame_computer: SIFrameComputer):
        super().__init__()
        self.si_frame_computer = si_frame_computer

    @classmethod
    def from_si_frame_computer(cls, si_frame_computer: SIFrameComputer) -> Self:
        """Build the module around an existing SIFrameComputer"""
        return cls(si_frame_computer)

    def state_dict(self, *args, **kwargs):
        # Accept whatever arguments torch.nn.Module.state_dict is called with (e.g.
        # the keywords a parent module passes to its children) so that callers always
        # get NotImplementedError rather than a TypeError about the signature.
        raise NotImplementedError

    def load_state_dict(self, *args, **kwargs):
        raise NotImplementedError

    @torch.jit.unused
    def _compute_full(self, sig: torch.Tensor) -> torch.Tensor:
        # Tensor.numpy() refuses tensors that require grad, so detach first.
        array = sig.detach().cpu().numpy()
        return torch.tensor(
            self.si_frame_computer.compute_full(array),
            device=sig.device,
            dtype=sig.dtype,
        )

    def forward(self, sig: torch.Tensor) -> torch.Tensor:
        """Compute the features of `sig` with the wrapped frame computer"""
        return self._compute_full(sig)


# short alias of the class above
PyTorchSIFrameComputer = PyTorchShortIntegrationFrameComputer