# encoding: utf-8
# pylint: disable=C0103
# pylint: disable=too-many-arguments
"""
Display
=======
.. autosummary::
    :toctree: generated/

    wave_plot
    map_show
    feature_plot
    embedding_plot
    centroids_plot
    plot_centroid
"""
import numpy as np
from pylab import get_cmap
from matplotlib import colors
import matplotlib.cm as cm
from matplotlib.axes import Axes
from matplotlib.ticker import NullFormatter
from librosa.display import TimeFormatter
from . import util
from .exceptions import ParameterError
__all__ = ['wave_plot', 'map_show', 'feature_plot', 'embedding_plot',
'centroids_plot', 'plot_centroid']
def wave_plot(y, sr=22050, x_axis='time', beats=None, beat_labs=None,
              ax=None, **kwargs):
    '''Plot an audio waveform and, optionally, beat annotations.

    Parameters
    ----------
    y : np.ndarray
        audio time series (must be one dimensional)
    sr : number > 0 [scalar]
        sampling rate of `y`
    x_axis : str {'time', 'off', 'none'} or None
        If 'time', the x-axis is given time tick-marks.
        If 'off', 'none' or None, the x-axis ticks are removed.
    beats : np.ndarray or None
        beat positions; when given they are handed to the beat drawing
        helper together with the last time instant of the signal
    beat_labs : list or None
        beat labels, forwarded to the beat drawing helper
    ax : matplotlib.axes.Axes or None
        Axes to plot on instead of the default `plt.gca()`.
    kwargs
        Additional keyword arguments to the `plot` call on the axes.
        Defaults: `color='royalblue'`, `linestyle='-'`, `alpha=0.6`.
        The same keyword arguments are forwarded to the beat helper.

    Returns
    -------
    out
        Value returned by the `plot` call on the axes.

    Raises
    ------
    ValueError
        If `y` has more than one dimension.
    ParameterError
        If `x_axis` is not one of the supported values.  The waveform has
        already been drawn when this is raised.
    '''
    for option, value in (('color', 'royalblue'),
                          ('linestyle', '-'),
                          ('alpha', 0.6)):
        kwargs.setdefault(option, value)

    if y.ndim > 1:
        raise ValueError("`y` must be a one dimensional array. "
                         "Found y.ndim={}".format(y.ndim))

    # instant (in seconds) of every sample, and the last of them
    instants = np.arange(y.size)/sr
    last_instant = np.max(instants)

    # use the given axes or fall back to the current ones
    axes = __check_axes(ax)

    # draw the waveform
    lines = axes.plot(instants, y, **kwargs)

    # draw the beats on top of it
    if beats is not None:
        __plot_beats(beats, last_instant, axes, beat_labs=beat_labs, **kwargs)

    # x axis decoration
    if x_axis == 'time':
        axes.xaxis.set_major_formatter(TimeFormatter(lag=False))
        axes.xaxis.set_label_text('Time (s)')
    elif x_axis in (None, 'off', 'none'):
        axes.set_xticks([])
    else:
        raise ParameterError('Unknown x_axis value: {}'.format(x_axis))

    return lines
def feature_plot(feature, time, x_axis='time', beats=None, beat_labs=None,
                 ax=None, **kwargs):
    '''Plot a feature time series and, optionally, beat annotations.

    Parameters
    ----------
    feature : np.ndarray
        feature time series (must be one dimensional)
    time : np.ndarray
        time instant of the feature values
    x_axis : str {'time', 'off', 'none'} or None
        If 'time', the x-axis is given time tick-marks.
        If 'off', 'none' or None, the x-axis ticks are removed.
    beats : np.ndarray or None
        beat positions; when given they are handed to the beat drawing
        helper together with the maximum of `time`
    beat_labs : list or None
        beat labels, forwarded to the beat drawing helper
    ax : matplotlib.axes.Axes or None
        Axes to plot on instead of the default `plt.gca()`.
    kwargs
        Additional keyword arguments to the `plot` call on the axes.
        Defaults: `color='seagreen'`, `linestyle='-'`, `alpha=0.8`.
        The same keyword arguments are forwarded to the beat helper.

    Returns
    -------
    out
        Value returned by the `plot` call on the axes.

    Raises
    ------
    ValueError
        If `feature` has more than one dimension.
    ParameterError
        If `x_axis` is not one of the supported values.  The curve has
        already been drawn when this is raised.
    '''
    for option, value in (('color', 'seagreen'),
                          ('linestyle', '-'),
                          ('alpha', 0.8)):
        kwargs.setdefault(option, value)

    if feature.ndim > 1:
        raise ValueError("`feature` must be a one dimensional array. "
                         "Found feature.ndim={}".format(feature.ndim))

    # last time instant covered by the feature
    last_instant = np.max(time)

    # use the given axes or fall back to the current ones
    axes = __check_axes(ax)

    # draw the feature curve
    lines = axes.plot(time, feature, **kwargs)

    # draw the beats on top of it
    if beats is not None:
        __plot_beats(beats, last_instant, axes, beat_labs=beat_labs, **kwargs)

    # x axis decoration
    if x_axis == 'time':
        axes.xaxis.set_major_formatter(TimeFormatter(lag=False))
        axes.xaxis.set_label_text('Time (s)')
    elif x_axis in (None, 'off', 'none'):
        axes.set_xticks([])
    else:
        raise ParameterError('Unknown x_axis value: {}'.format(x_axis))

    return lines
def centroids_plot(centroids, n_tatums=4, ax_list=None, **kwargs):
    '''Plot centroids of rhythmic patterns clusters, one axes per centroid.

    Parameters
    ----------
    centroids : np.ndarray
        centroids of the rhythmic patterns clusters; iterating over it
        must yield one (one dimensional) centroid at a time
    n_tatums : int
        Number of tatums (subdivisions) per tactus beat
    ax_list : list of matplotlib.axes.Axes or None, one element per centroid
        Axes to plot on. If None, one subplot per centroid is added
        (stacked vertically) to the current figure.
    kwargs
        Additional keyword arguments forwarded to `plot_centroid`.
        If `color` is given it is used for every centroid; otherwise each
        centroid gets its own color from the clusters colormap.

    Returns
    -------
    ax : list of matplotlib.axes.Axes
        The axes used, one per centroid.
    '''
    # number of centroids
    n_centroids = len(centroids)

    # check list of axes
    ax = __check_axes_list(n_centroids, ax_list=ax_list)

    # get colormap
    cmap, _ = __get_colormap_map(n_centroids)

    # plot each cluster
    for ind, centroid in enumerate(centroids):
        # Work on a copy so that a user supplied `color` takes precedence
        # instead of clashing with the per-cluster color (which used to raise
        # "got multiple values for keyword argument 'color'").
        bar_kwargs = dict(kwargs)
        bar_kwargs.setdefault('color', cmap(ind/n_centroids))
        plot_centroid(centroid, n_tatums=n_tatums, ax=ax[ind], **bar_kwargs)

    return ax
def plot_centroid(centroid, n_tatums=4, ax=None, **kwargs):
    '''Draw the centroid of a rhythmic patterns cluster as a bar plot.

    Parameters
    ----------
    centroid : np.ndarray
        centroid feature values (must be one dimensional)
    n_tatums : int
        Number of tatums (subdivisions) per tactus beat
    ax : matplotlib.axes.Axes or None
        Axes to plot on instead of the default `plt.gca()`.
    kwargs
        Additional keyword arguments to the `bar` call on the axes.
        Defaults: `color='seagreen'`, `alpha=0.8`.

    Returns
    -------
    out
        Value returned by the `bar` call on the axes.

    Raises
    ------
    ValueError
        If `centroid` has more than one dimension.
    '''
    for option, value in (('color', 'seagreen'), ('alpha', 0.8)):
        kwargs.setdefault(option, value)

    if centroid.ndim > 1:
        raise ValueError("`centroid` must be a one dimensional array. "
                         "Found centroid.ndim={}".format(centroid.ndim))

    # one bar per tatum of the centroid, positioned at 1, 2, ...
    total_tatums = centroid.size
    positions = np.arange(total_tatums)+1

    # use the given axes or fall back to the current ones
    axes = __check_axes(ax)
    bars = axes.bar(positions, centroid, **kwargs)

    # ticks and labels
    __decorate_axis_centroid(axes, total_tatums, n_tatums)

    return bars
def __plot_beats(beats, max_time, ax, beat_labs=None, **kwargs):
    '''Draw beat markers on `ax` and label them on a twin x-axis.

    Parameters
    ----------
    beats : np.ndarray
        beat positions, in the units of the x-axis of `ax`
    max_time : number
        upper time limit; only the beats selected through
        `util.find_nearest(beats, max_time)` are drawn
    ax : matplotlib.axes.Axes
        Axes to draw the beat lines on.
    beat_labs : list or None
        beat labels, one per beat. If None the beat ticks are drawn
        without any label.
    kwargs
        Additional keyword arguments to the `axvline` calls.
        `color` and `alpha` are always overridden ('black' and 0.3), since
        the caller forwards the options of the main curve; `linestyle` and
        `linewidth` default to '-' and 2.

    Returns
    -------
    ax2 : matplotlib.axes.Axes
        Twin axes (created with `ax.twiny()`) holding the beat ticks.
    '''
    kwargs['color'] = 'black'
    kwargs.setdefault('linestyle', '-')
    kwargs['alpha'] = 0.3
    kwargs.setdefault('linewidth', 2)

    # consider beats (and labels) below max_time
    # NOTE(review): presumably `find_nearest` returns the index of the beat
    # closest to `max_time`; that beat itself is excluded by the slice.
    ind_beat = util.find_nearest(beats, max_time)
    new_beats = beats[:ind_beat]
    # Without labels fall back to an empty list (ticks with no text); the
    # name used to be left unbound in that case, which crashed further down.
    new_labs = [] if beat_labs is None else beat_labs[:ind_beat]

    # plot beat annotations
    for beat in new_beats:
        ax.axvline(x=beat, **kwargs)

    # set ticks and labels on a twin axis sharing the limits of `ax`
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ax2.set_xticks(new_beats)
    ax2.set_xticklabels(new_labs)

    return ax2
def map_show(data, x_coords=None, y_coords=None, ax=None,
             n_tatums=4, clusters=None, **kwargs):
    '''Display a feature map.

    The map is drawn transposed: each row of `data` (a bar) becomes a
    column of the image and each column of `data` (a tatum) becomes a row.

    Parameters
    ----------
    data : np.ndarray
        Feature map to display, one row per bar and one column per tatum
    x_coords : np.ndarray [shape=data.shape[0]+1] or None
    y_coords : np.ndarray [shape=data.shape[1]+1] or None
        Optional positioning coordinates of the input data (cell edges
        for the default flat shading). If None, edges at 0.5, 1.5, ... are
        used so that bar/tatum number k is centered at k.
    ax : matplotlib.axes.Axes or None
        Axes to plot on instead of the default `plt.gca()`.
    n_tatums : int
        Number of tatums (subdivisions) per tactus beat
    clusters : np.ndarray
        Array indicating cluster number for each pattern of the input data.
        If provided (not None) the clusters are displayed with colors.
    kwargs : additional keyword arguments
        Arguments passed through to `matplotlib.pyplot.pcolormesh`.
        By default, the following options are set:
            - `cmap=gray_r`
            - `rasterized=True`
            - `edgecolors='None'`
            - `shading='flat'`

    Returns
    -------
    axes
        The axis handle for the figure.

    See Also
    --------
    matplotlib.pyplot.pcolormesh
    '''
    # The colormap is given by name: `pcolormesh` resolves it itself, which
    # avoids `matplotlib.cm.get_cmap` (removed in recent matplotlib).
    kwargs.setdefault('cmap', 'gray_r')
    kwargs.setdefault('rasterized', True)
    kwargs.setdefault('edgecolors', 'None')
    kwargs.setdefault('shading', 'flat')

    # number of bars
    bars = data.shape[0]
    # number of tatums in a bar
    tatums = data.shape[1]

    # set the x and y coordinates, honoring the ones given by the caller.
    # Flat shading needs one more edge than cells on each direction; the x
    # grid used to have only `bars` points, which dropped the last bar (or
    # raised an error, depending on the matplotlib version).
    if x_coords is None:
        x_coords = np.arange(bars+1)+0.5
    else:
        x_coords = np.asarray(x_coords)
    if y_coords is None:
        y_coords = np.arange(tatums+1)+0.5
    else:
        y_coords = np.asarray(y_coords)

    # check axes and create it if needed
    axes = __check_axes(ax)

    # plot rhythmic patterns map (grayscale)
    out = axes.pcolormesh(x_coords, y_coords, data.T, **kwargs)
    __set_current_image(ax, out)

    # if clusters are given then show them in colors
    if clusters is not None:
        # check clusters and return number of clusters
        n_clusters = __check_clusters(clusters, bars)
        # matrix to plot clusters' map, with the same shape as `data.T` so
        # that it can be drawn on exactly the same grid
        mapc = __get_cluster_matrix(clusters, tatums)
        # get colormap used to plot clusters
        cmap, norm = __get_colormap_map(n_clusters)
        # plot clusters in colors
        axes.pcolormesh(x_coords, y_coords, mapc, cmap=cmap, norm=norm,
                        alpha=0.6, shading=kwargs['shading'])

    # set axes limits to the extent of the grid
    axes.set_xlim(x_coords.min(), x_coords.max())
    axes.set_ylim(y_coords.min(), y_coords.max())

    # configure tickers and labels
    __decorate_axis_map(axes, tatums=n_tatums)

    return axes
def embedding_plot(data, clusters=None, ax=None, **kwargs):
    '''Display a 2D or 3D embedding of the rhythmic patterns data.

    Parameters
    ----------
    data : np.ndarray
        Low-embedding data points, one row per point and two or three
        columns (the coordinates)
    clusters : np.ndarray
        Array indicating cluster number for each point of the input data.
        If provided (not None) the clusters are displayed with colors.
    ax : matplotlib.axes.Axes or None
        Axes to plot on instead of the default `plt.gca()`.
        NOTE(review): three dimensional data presumably needs 3D axes
        (the z axis is decorated afterwards) -- confirm with callers.
    kwargs : additional keyword arguments
        Arguments passed through to the `scatter` call on the axes.

    Returns
    -------
    axes
        The axis handle for the figure.

    Raises
    ------
    ValueError
        If the points are neither two nor three dimensional.
    '''
    # number of points
    n_points = data.shape[0]

    # clusters (if any) are validated and get a discrete colormap
    colored = clusters is not None
    if colored:
        n_clusters = __check_clusters(clusters, n_points)
        cmap, norm = __get_colormap_map(n_clusters)

    # use the given axes or fall back to the current ones
    axes = __check_axes(ax)

    # only 2D and 3D points can be drawn
    dim = data.shape[1]
    if dim not in (2, 3):
        raise ValueError("`data` points can have two or three dimension to be plotted. "
                         "Found data.shape[1]={}".format(data.shape[1]))

    # one positional argument per coordinate
    columns = [data[:, k] for k in range(dim)]
    if colored:
        axes.scatter(*columns, c=clusters,
                     cmap=cmap, norm=norm, picker=2, **kwargs)
    else:
        axes.scatter(*columns, **kwargs)

    __decorate_axis_embedding(axes, dim)

    return axes
def __check_axes(axes):
    '''Return valid axes: the given ones, or the current ones if None.

    Raises a ValueError if `axes` is neither None nor an instance of
    `matplotlib.axes.Axes`.
    '''
    if axes is None:
        # pyplot is only imported when the default axes are needed
        import matplotlib.pyplot as plt
        return plt.gca()

    if isinstance(axes, Axes):
        return axes

    raise ValueError("`axes` must be an instance of matplotlib.axes.Axes. "
                     "Found type(axes)={}".format(type(axes)))
def __check_axes_list(n_axes, ax_list=None):
    '''Return a list of `n_axes` valid axes.

    If `ax_list` is None, `n_axes` subplots stacked vertically are added to
    the current figure. Otherwise `ax_list` must have `n_axes` elements,
    each of them an instance of `matplotlib.axes.Axes`; a ValueError is
    raised if that is not the case.
    '''
    if ax_list is None:
        # pyplot is only imported when the axes have to be created
        import matplotlib.pyplot as plt
        figure = plt.gcf()
        return [figure.add_subplot(n_axes, 1, row+1) for row in range(n_axes)]

    if n_axes != len(ax_list):
        raise ValueError("`ax_list` must be of correct size to match number of axes `n_axes`.")

    for candidate in ax_list:
        if not isinstance(candidate, Axes):
            raise ValueError("`axes` must be an instance of matplotlib.axes.Axes. "
                             "Found type(axes)={}".format(type(candidate)))

    return ax_list
def __check_clusters(clusters, bars):
    '''Validate the clusters array and count the clusters.

    Parameters
    ----------
    clusters : np.ndarray
        cluster number of each pattern; must be one dimensional
    bars : int
        expected number of elements of `clusters`

    Returns
    -------
    n_clusters : int
        number of different values found in `clusters`

    Raises
    ------
    ValueError
        If `clusters` is not a numpy array, is not one dimensional or
        does not have `bars` elements.
    '''
    if not isinstance(clusters, np.ndarray):
        raise ValueError("`clusters` must be an instance of numpy.ndarray. "
                         "Found type(clusters)={}".format(type(clusters)))

    if clusters.ndim != 1:
        raise ValueError("`clusters` must be a one dimensional array. "
                         "Found clusters.ndim={}".format(clusters.ndim))

    # A wrong length used to fall through and fail with an unbound local
    # variable; report it explicitly instead.
    if clusters.size != bars:
        raise ValueError("`clusters` must have one element per pattern. "
                         "Found clusters.size={} but expected {}".format(clusters.size, bars))

    # count number of clusters
    return np.unique(clusters).size
def __get_cluster_matrix(clusters, n_tatums):
    '''Build the matrix used to draw the clusters over the map.

    The cluster numbers, shifted by 0.5, are stacked `n_tatums` times so
    that every row of the result holds one value per element of `clusters`.
    '''
    shifted = clusters+0.5
    return np.tile(shifted, (n_tatums, 1))
def __get_colormap_map(n_clusters):
    '''Get the colormap and the norm used to draw the clusters.

    Returns the 'tab10' colormap requested with `n_clusters` entries and a
    `BoundaryNorm` whose boundaries are 0, 1, ..., `n_clusters`.
    '''
    # fixed set of colors, one per cluster
    palette = get_cmap('tab10', n_clusters)
    # consecutive integer boundaries delimiting each cluster
    edges = range(n_clusters+1)
    return palette, colors.BoundaryNorm(edges, palette.N)
def __set_current_image(ax, img):
    '''Helper to set the current image in pyplot mode.

    When `ax` is given the user is assumed to be working with the object
    API, so pyplot is left untouched. Only when `ax` is None is `img`
    registered as the pyplot current image.
    '''
    if ax is not None:
        return

    import matplotlib.pyplot as plt
    plt.sci(img)
def __decorate_axis_map(axis, tatums=4):
    '''Configure axis ticks and labels for feature map plot.

    Parameters
    ----------
    axis : matplotlib.axes.Axes
        Axes holding the map; its current upper y limit (truncated to an
        integer) is taken as the total number of tatums.
    tatums : int
        Number of tatums (subdivisions) per tactus beat

    A y tick is placed every `tatums` tatums (starting at 0.5), labeled
    with the beat number (1, 2, ...), and thick black horizontal grid lines
    are drawn at those ticks.
    '''
    # ticks at beats
    ylims = axis.get_ylim()
    all_tatums = int(ylims[1])
    ticks_beats = [x+0.5 for x in range(0, all_tatums, tatums)]
    # One label per tick. Counting the labels separately (all_tatums / tatums)
    # gave fewer labels than ticks when the total is not a multiple of
    # `tatums`, and a mismatch is rejected by matplotlib.
    labels_beats = list(range(1, len(ticks_beats)+1))
    axis.yaxis.set_ticks(ticks_beats)
    axis.set_yticklabels(labels_beats)
    axis.tick_params(labelsize=10)

    # horizontal grid lines marking the beats
    axis.yaxis.grid()
    for line in axis.get_ygridlines():
        line.set_linestyle('-')
        line.set_linewidth(2)
        line.set_color('black')

    axis.set_ylabel('beats')
def __decorate_axis_centroid(axis, c_tatums=16, n_tatums=4, beat_ticks=True):
    '''Configure axis ticks and labels for centroid plot.

    Parameters
    ----------
    axis : matplotlib.axes.Axes
        Axes to decorate
    c_tatums : int
        Number of tatums (subdivisions) in centroid
    n_tatums : int
        Number of tatums (subdivisions) per tactus beat
    beat_ticks : bool
        If `True`, then labels are shown only at ticks corresponding to beats
    '''
    positions = np.arange(c_tatums)

    # one tick per tatum (numbered from 1), drawn at the top
    axis.xaxis.set_ticks(positions + 1)
    axis.xaxis.set_ticks_position('top')

    # no labels on the y axis, which is fixed to the [0, 1] range
    axis.yaxis.set_major_formatter(NullFormatter())
    axis.set_ylim(0, 1)

    if not beat_ticks:
        return

    # beat number at the first tatum of each beat, blank elsewhere
    labels = []
    for pos in positions:
        if (pos % n_tatums) == 0:
            labels.append(int(pos / n_tatums) + 1)
        else:
            labels.append(' ')
    axis.set_xticklabels(labels)
def __decorate_axis_embedding(axes, dim):
    '''Configure axis ticks and labels for embedding plot.

    The tick labels of the x and y axes are removed; for three dimensional
    embeddings (`dim == 3`) those of the z axis are removed as well.
    '''
    names = ('zaxis', 'xaxis', 'yaxis') if dim == 3 else ('xaxis', 'yaxis')
    for name in names:
        getattr(axes, name).set_major_formatter(NullFormatter())