Source code for pydrobert.speech.vis

# Copyright 2021 Sean Robertson

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Visualization functions

Raises
------
ImportError
    This submodule requires :mod:`matplotlib`
"""

from itertools import cycle
from typing import Optional, Sequence, Tuple, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import numpy as np

from matplotlib import pyplot as plt
from matplotlib import ticker
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.figure import Figure
from pydrobert.speech.filters import LinearFilterBank
from pydrobert.speech.compute import FrameComputer, LinearFilterBankFrameComputer
from pydrobert.speech.post import PostProcessor


__all__ = [
    "compare_feature_frames",
    "plot_frequency_response",
]


[docs] def plot_frequency_response( banks: Union[Sequence[LinearFilterBank], LinearFilterBank], axes: Optional[Axes] = None, dft_size: Optional[int] = None, half: Optional[bool] = None, title: Optional[str] = None, x_scale: Literal["hz", "ang", "bins"] = "hz", y_scale: Literal["dB", "power", "real", "imag", "both"] = "dB", cmap: Colormap = None, ) -> Figure: """Plot frequency response of filters in a filter bank Parameters ---------- bank axes An :class:`Axes` object to plot on. Default is to generate a new figure dft_size The size of the Discrete Fourier Transform to plot. Defaults to ``max(max(bank.supports), 2 * bank.sampling_rate // min(bank.supports_hz)`` half Whether to plot the half or full spectrum. Defaults to ``bank.is_real`` title What to call the graph. The default is not to show a title x_scale The frequency coordinate scale along the x axis. Hertz (:obj:`'hz'`) is cycles/sec, angular frequency (:obj:`'ang'`) is radians/sec, and :obj:`'bins'` is the sample index within the DFT y_scale How to express the frequency response along the y axis. Decibels (:obj:`'dB'`) is the log of a ratio of the maximum quantity in the bank. The range between 0 and -20 decibels is displayed. Power spectrum (:obj:`'power'`) is the squared magnitude of the frequency response. :obj:`'real'` is the real part of the response, :obj:`'imag'` is the imaginary part of the response, and :obj:`'both'` displays both :obj:`'real'` and :obj:`'imag'` as separate lines cmap A :class:`Colormap` to pull colours from. Defaults to matplotlib's default colormap Returns ------- fig : matplotlib.figure.Figure The containing figure """ try: len(banks) except AttributeError: # 1 bank banks = [banks] if not all(x.num_filts for x in banks): raise ValueError("Filter banks must have at least one filter to be visualized") if not all(x.sampling_rate == banks[0].sampling_rate for x in banks): raise ValueError("Banks must all have the same sampling rate") rate = banks[0].sampling_rate if cmap is None: cmap = plt.get_cmap() if dft_size is None: dft_size = max( int( max( max(right - left for left, right in bank.supports), 2 * rate / min(right - left for left, right in bank.supports_hz), ) ) for bank in banks ) if half is None: half = all(bank.is_real for bank in banks) if axes is None: fig, axes = plt.subplots() else: fig = axes.get_figure() colours = cmap.colors r_colours = list(cmap.colors) r_colours.reverse() responses_colours = [] for bank, first_colour, second_color in zip( banks, cycle(colours), cycle(r_colours) ): responses_colours.extend( [ ( bank.get_frequency_response(filt_idx, dft_size, half=half), first_colour, second_color, ) for filt_idx in range(bank.num_filts) ] ) if half: x = np.arange( (dft_size + dft_size % 2) // 2 + 1 - dft_size % 2, dtype=np.float32 ) else: x = np.arange(dft_size, dtype=np.float32) if x_scale in ("hz", "Hz", "hertz", "Hertz"): x_title = "Frequency (Hz)" x *= rate x /= dft_size elif x_scale in ("ang", "angle", "angular"): x_title = "Angular Frequency" x *= 2 * np.pi x /= dft_size axes.xaxis.set_major_locator(ticker.MultipleLocator(np.pi)) axes.xaxis.set_minor_locator(ticker.AutoMinorLocator(2)) axes.xaxis.set_major_formatter(ticker.FuncFormatter(_pi_formatter)) axes.xaxis.set_minor_formatter(ticker.FuncFormatter(_pi_formatter)) elif x_scale == "bins": x_title = "DFT Bin" else: raise ValueError("Invalid x_scale: {}".format(x_scale)) if y_scale in ("db", "dB", "decibels"): y_title = "Log Ratio (dB)" # maximum abs. Get ripped max_abs = max(max(np.abs(response)) for response, _, _ in responses_colours) max_abs = np.log10(max(np.finfo(float).eps, max_abs)) for filt_idx in range(len(responses_colours)): response, first_colour, second_colour = responses_colours[filt_idx] response = np.abs(response) response[response <= np.finfo(float).eps] = np.nan response[...] = 20 * (np.log10(response) - max_abs) # looks better than discontinuities response[np.isnan(response)] = -1e10 responses_colours[filt_idx] = response, first_colour, second_colour y_max = 0 y_min = -10 elif y_scale in ("pow", "power"): y_title = "Power" y_min = 0 y_max = 0 for filt_idx in range(len(responses_colours)): response, first_colour, second_colour = responses_colours[filt_idx] response = np.abs(response) ** 2 y_max = max(y_max, max(response)) responses_colours[filt_idx] = response, first_colour, second_colour y_max *= 1.04 elif y_scale in ("real", "imag", "imaginary", "both"): if y_scale == "real": y_title = "Real-value response" elif y_scale == "both": y_title = "Complex response" else: y_title = "Imaginary-value response" y_min = np.inf y_max = -np.inf new_responses_colours = [] for response, first_colour, second_colour in responses_colours: if y_scale == "real": response = np.real(response) elif y_scale == "both": response_b = np.imag(response) response = np.real(response) y_max = max(y_max, max(response_b)) y_min = min(y_min, min(response_b)) new_responses_colours.append((response_b, second_colour, first_colour)) else: response = np.imag(response) y_max = max(y_max, max(response)) y_min = min(y_min, min(response)) new_responses_colours.append((response, first_colour, second_colour)) assert np.isfinite(y_min) and np.isfinite(y_max) y_max *= 0.96 if y_max < 0 else 1.04 y_min *= 0.96 if y_min > 0 else 1.04 del responses_colours responses_colours = new_responses_colours else: raise ValueError("Invalid y_scale: {}".format(y_scale)) axes.set_xlim((0, max(x))) axes.set_ylim((y_min, y_max)) if title: axes.set_title(title) axes.set_ylabel(y_title) axes.set_xlabel(x_title) for response, colour, _ in responses_colours: axes.plot(x, response, color=colour) # if y_scale == 'both': # real_handle = lines.Line2D([], [], color=real_colour, label='Real') # imag_handle = lines.Line2D([], [], color=imag_colour, label='Imag') # axes.legend(handles=[real_handle, imag_handle]) return fig
def _pi_formatter(val, _): num_halfpi = int(np.round(2 * val / np.pi)) if np.isclose(num_halfpi * np.pi / 2, val): if not num_halfpi: return "0" elif num_halfpi == 1: return "\\u03C0 / 2" elif num_halfpi == -1: return "-\\u03C0 / 2" elif num_halfpi == 2: return "\\u03C0" elif num_halfpi == -2: return "-\\u03C0" elif num_halfpi % 2: return "{}\\u03C0 / 2".format(num_halfpi) else: return "{}\\u03C0".format(num_halfpi // 2) else: return ""
[docs] def compare_feature_frames( computers: Union[FrameComputer, Sequence[FrameComputer]], signal: np.ndarray, axes: Optional[int] = None, figure_height: float = None, figure_width: float = None, plot_titles: Tuple[str, ...] = None, positions: Tuple[Union[int, Tuple[int, int]], ...] = None, post_ops: Optional[Union[PostProcessor, Sequence[PostProcessor]]] = None, title: Optional[str] = None, **kwargs ) -> Figure: """Compare features from frame computers via spectrogram-like heat map Direct comparison of :class:`FrameComputer` objects is possible because all subclasses of this abstract data type share a common interpretation of frame boundaries (according to :func:`FrameComputer.frame_style`). Additional keyword args will be passed to the plotting routine. Parameters ---------- computers One or more frame computers to compare signal A 1D array of the raw speech. Assumed to be valid with respect to computer settings (e.g. sample rate). axes By default, this function creates a new figure and subplots. Setting one `axes` value for every `computers` value will plot feature representations from `computers` into each ordered :class:`Axes`. If `axes` do not belong to the same figure, a :class:`ValueError` will be raised figure_height If a new figure is created, this sets the figure height (in inches). This value is determined dynamically according to `figure_width` by default. A :class:`ValueError` will be raised if both `figure_height` and `axes` are set figure_width If a new figure is created, this set the figure width (in inches). This value defaults to 3.33 inches if all subplots are positioned vertically, and to 7 inches if there are at least two columns of plots. A :class:`ValueError` will be raised if both `figure_width` and `axes` are set plot_titles An ordered list of strings specifying the titles of each subplot. The default is to not display subplot titles positions If a new figure is created, `positions` decides how the subplots should be positioned relative to one another. Can contain only ints (describing the position on only the row-axis) or pairs of ints (describing the row-col positions). Positions must be contiguous and start from index 0 or 0,0 (top or top-left). `positions` cannot be specified if `axes` is specified post_ops One or more post-processors to apply (in order) to each computed feature representation. If a simple list of post-processors is provided, each operation is applied to the default axis (the feature coefficient axis). To explicitly set the axis, pairs of ``(op, axis)`` can be specified in the list. No op is allowed to change the shape of the feature representation (e.g. :class:`Deltas`), or a :class:`ValueError` will be thrown title The title of the whole figure. Default is to display no title Returns ------- fig : matplotlib.figure.Figure The containing figure """ try: iter(computers) except TypeError: computers = (computers,) if not len(computers): raise ValueError("Expected at least one computer") if plot_titles is not None: try: iter(plot_titles) except TypeError: plot_titles = [plot_titles] if len(plot_titles) != len(computers): raise ValueError( "Expected {} plot titles, got {}".format( len(computers), len(plot_titles) ) ) else: plot_titles = [None] * len(computers) if positions is not None: if len(computers) == 1 and positions not in (0, (0,), [0]): raise ValueError("Nonzero position specified for only one plot") elif axes is not None: raise ValueError("Cannot specify positions of predefined axes") elif len(positions) != len(computers): raise ValueError( "Expected {} positions, got {}".format(len(computers), len(positions)) ) if any(hasattr(p, "__iter__") for p in positions) and not all( len(p) == 1 for p in positions if hasattr(p, "__iter__") ): # expect 2-dimensional plot positioning if any(not hasattr(p, "__iter__") or len(p) != 2 for p in positions): raise ValueError("Expected all plot positions to be two-dimensional") row_set = set(p[0] for p in positions) col_set = set(p[1] for p in positions) row_len, col_len = max(row_set) + 1, max(col_set) + 1 if row_set != set(r for r in range(row_len)) or col_set != set( c for c in range(col_len) ): raise ValueError("positions not contiguous") gs_args = (row_len, col_len) else: # expect 1-dimensional plot positioning. Using gridspec, # so have to add a column coordinate positions = tuple( (next(iter(p)), 0) if hasattr(p, "__iter__") else p for p in positions ) row_set = set(p[0] for p in positions) row_len = max(row_set) + 1 if row_set != set(r for r in range(row_len)): raise ValueError("positions not contiguous") gs_args = (row_len, 1) elif axes is None: # choose our own positions num_plots = len(computers) row_len = int(np.ceil(num_plots ** 0.5)) col_len = row_len while col_len * row_len != num_plots: if col_len * row_len > num_plots and col_len > 1: row_len += 1 col_len -= 1 else: row_len -= 1 gs_args = (row_len, col_len) positions = tuple(np.ndindex(gs_args)) if figure_width is not None: if axes is not None: raise ValueError("Cannot specify figure width with predefined axes") elif axes is None: figure_width = 7.0 if gs_args[1] > 1 else 3.33 if figure_height is not None: if axes is not None: raise ValueError("Cannot specify figure height with predefined axes") elif axes is None: figure_height = figure_width * 9 / 16 / gs_args[1] * gs_args[0] if post_ops is not None: try: iter(post_ops) except TypeError: post_ops = (post_ops,) if len(post_ops) == 2 and isinstance(post_ops[1], int): post_ops = (post_ops,) else: post_ops = [] if axes is not None: try: iter(axes) except TypeError: axes = (axes,) if len(axes) != len(computers): raise ValueError( "Expected {} axes, got {}".format(len(computers), len(axes)) ) fig = axes[0].get_figure() for ax in axes[1:]: if ax.get_figure() != fig: raise ValueError("Axes do not share the same figure") else: fig = plt.figure(figsize=(figure_width, figure_height)) if len(computers) == 1: axes = (fig.add_subplot(111),) else: axes = [] sharey = all( isinstance(computer, LinearFilterBankFrameComputer) for computer in computers ) gridspec = plt.GridSpec(gs_args[0], gs_args[1]) for position in positions: if axes and sharey: ax = fig.add_subplot( gridspec[position], sharex=axes[0], sharey=axes[0] ) elif axes: ax = fig.add_subplot(gridspec[position], sharex=axes[0]) else: ax = fig.add_subplot(gridspec[position]) axes.append(ax) supremum_seconds = np.infty num_samples = len(signal) for idx, (computer, ax, plot_title) in enumerate(zip(computers, axes, plot_titles)): frame_length = computer.frame_length frame_shift = computer.frame_shift if computer.frame_style == "causal": pad_left = 0 else: # centered pad_left = (frame_length + 1) // 2 - 1 total_len = num_samples + pad_left num_frames = max(0, (total_len - frame_length) // frame_shift + 1) # individual computers may choose to add a final frame by # padding. Since this behaviour is not guaranteed, we only # consider full frames if not num_frames: raise ValueError( "The computer indexed at {} is unable to generate " "a full frame from the signal".format(idx) ) # we use frame shifts to specify bounds (frame length is likely # overlapping), with the exception of the last frame sample_bounds = np.arange(num_frames + 1, dtype=float) * frame_shift if pad_left: # r.h.s. bound half a frame shift to right of center (or # half frame right of center for last frame) # l.h.s. bound half the other way (or 0 for first frame) sample_bounds[1:-1] -= (frame_shift + 1) // 2 - 1 sample_bounds[-1] = sample_bounds[-2] + pad_left else: # l.h.s bound leftmost idx of each frame # r.h.s. is l.h.s. plus frame shift (or frame length for # last frame) sample_bounds[-1] = sample_bounds[-2] + frame_length seconds_bounds = sample_bounds / computer.sampling_rate supremum_seconds = min(supremum_seconds, seconds_bounds[-1]) feat_slice = [slice(None, num_frames), slice(None)] if isinstance(computer, LinearFilterBankFrameComputer): yscale_label = "Frequency (Hz)" bank = computer.bank num_coeffs = bank.num_filts if computer.includes_energy: feat_slice[-1] = slice(1, None) supports_hz = bank.supports_hz assert num_coeffs == len(supports_hz) centers_hz = tuple((left + right) / 2 for left, right in supports_hz) # supports may be overlapping or sparse. Instead of using # supports to directly specify boundaries, we use them as # weights to pick points between center frequencies (except # the first and last filters, which get to extend their # lower and higher bounds to their supports, respectively. feature_bounds = np.empty(num_coeffs + 1) feature_bounds[0] = max(0, supports_hz[0][0]) feature_bounds[-1] = min(computer.sampling_rate / 2, supports_hz[-1][-1]) for high_idx in range(1, num_coeffs): low_c = centers_hz[high_idx - 1] high_c = centers_hz[high_idx] low_s, high_s = supports_hz[high_idx - 1] assert high_c >= low_c split_c = low_c * (high_s / (low_s + high_s)) split_c += high_c * (low_s / (low_s + high_s)) feature_bounds[high_idx] = split_c else: # no idea how to handle. Just plot rectangular coefficients num_coeffs = computer.num_coeffs yscale_label = None feature_bounds = np.arange(num_coeffs + 1) features = computer.compute_full(signal) assert features.shape[0] >= num_frames assert features[feat_slice].shape[-1] == num_coeffs for post_op_idx, post_op in enumerate(post_ops): try: apply_axis = post_op[1] post_op = post_op[0] except TypeError: apply_axis = -1 new_features = post_op.apply(features, axis=apply_axis) if new_features.shape != features.shape: raise ValueError( "The post_op indexed at {} changed the shape of the" "features".format(post_op_idx) ) features = new_features ax.pcolormesh(seconds_bounds, feature_bounds, features[feat_slice].T, **kwargs) if plot_title is not None: ax.set_title(plot_title) ax.set_xlabel("Time (seconds)") if yscale_label: ax.set_ylabel(yscale_label) for ax in axes: ax.set_xlim((0, supremum_seconds)) if title: fig.suptitle(title) return fig