Source code for leaspy.io.logs.visualization.plotting

import os
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import itertools

from leaspy.exceptions import (
    LeaspyIndividualParamsInputError,
    LeaspyInputError,
    LeaspyTypeError,
)

from ...outputs import IndividualParameters
from leaspy.models import LogisticMultivariateMixtureModel


__all__ = ["Plotting"]


# TODO: outdated - the class below is flagged as deprecated since 1.2 in its own docstring
class Plotting:
    """
    .. deprecated:: 1.2

    Class defining some plotting tools.

    Parameters
    ----------
    model : :class:`~leaspy.models.BaseModel`
        The used model.
    output_path : :obj:`str`, (optional)
        Folder where plots will be saved.
        If None, default to current working directory.
    palette : :obj:`str` (palette name) or :class:`matplotlib.colors.Colormap` (`ListedColormap` or `LinearSegmentedColormap`)
        The palette to use.
    max_colors : :obj:`int` > 0 or None, optional (default 10)
        Only used if palette is a string.
        If None, the number of features of the model is used.
    """

    def __init__(self, model, output_path=".", palette="tab10", max_colors=10):
        self.model = model

        # ---- Graphical options
        self.color_palette = None
        self.standard_size = (8, 4)
        self.linestyle = {
            "average_model": "-",
            "individual_model": "-",
            "individual_data": "-",
        }
        self.linewidth = {
            "average_model": 5,
            "individual_model": 2,
            "individual_data": 2,
        }
        self.alpha = {"average_model": 0.5, "individual_model": 1, "individual_data": 1}

        # The docstring promises that None means "current working directory",
        # and `os.path.join(None, ...)` would otherwise fail when saving.
        self.output_path = "." if output_path is None else output_path

        self.set_palette(palette, max_colors)

    def set_palette(self, palette, max_colors=None):
        """
        Set palette of plots

        Parameters
        ----------
        palette : :obj:`str` (palette name) or :class:`matplotlib.colors.Colormap` (`ListedColormap` or `LinearSegmentedColormap`)
            The palette to use.
        max_colors : :obj:`int` > 0, optional (default, corresponding to model nb of features)
            Only used if palette is a string

        Raises
        ------
        :exc:`.LeaspyInputError`
            If `palette` is a name, `max_colors` is None and the model has no dimension yet.
        """
        if isinstance(palette, mpl.colors.Colormap):
            self.color_palette = palette
            return

        if max_colors is None:
            # We need the model dimension to size the palette: fail only when it is missing.
            if self.model.dimension is None:
                raise LeaspyInputError(
                    "Initialize model first please, with a not None dimension"
                )
            max_colors = self.model.dimension

        self.color_palette = mpl.colormaps[palette].resampled(max_colors)

    def colors(self, at=None):
        """
        Wrapper over color_palette iterator to get colors

        Parameters
        ----------
        at : any legit color_palette arg (int, float or iterable of any of these) or None (default)
            if None returns all colors of palette upto model dimension

        Returns
        -------
        colors : single color tuple (RGBA) or np.array of RGBA colors (number of colors x 4)
        """
        if at is None:
            # cycle over the palette when the model has more features than colors
            at = [i % self.color_palette.N for i in range(self.model.dimension)]
        return self.color_palette(at)

    def _raise_if_model_not_init(self):
        """Raise a :exc:`.LeaspyInputError` if the model is not initialized."""
        # /!\ Break if model is not initialized
        if not self.model.is_initialized:
            raise LeaspyInputError("Please initialize the model before plotting")

    def _handle_kwargs_begin(self, kwargs, all_features_list=None):
        """
        Extract kwargs corresponding to plot information and remove associated keys (in-place).

        Parameters
        ----------
        kwargs : dict
            Keyword arguments of the public plotting method. The keys
            'features', 'colors' (or its legacy alias 'color'), 'labels', 'ax'
            and, when a new figure is created, 'figsize' are popped from it.
        all_features_list : list[:obj:`str`], optional
            Reference list of features used to index the requested ones.
            If None, the features of the (initialized) model are used.

        Returns
        -------
        tuple
            (ax, features, features_ix, labels, colors)

        Raises
        ------
        :exc:`.LeaspyInputError`
            If there are fewer colors than features or if the number of labels
            differs from the number of features.
        """
        # get features from initialized model if not set
        if all_features_list is None:
            self._raise_if_model_not_init()
            all_features_list = self.model.features

        # ---- Get requested features (may be a subset)
        features = kwargs.pop("features", all_features_list)
        features_ix = [all_features_list.index(ft) for ft in features]

        # ---- Colors
        # 'colors' is the documented keyword; 'color' is what was historically read,
        # so both are consumed and 'colors' wins when both are given.
        user_colors = kwargs.pop("colors", None)
        legacy_colors = kwargs.pop("color", None)
        if user_colors is None:
            user_colors = legacy_colors
        colors = self.colors(features_ix) if user_colors is None else user_colors
        if len(colors) < len(features):
            raise LeaspyInputError(
                f"Please choose a palette with at least {len(features)} colors."
            )
        # TODO: reindex default colors if subset of features?

        # ---- Labels
        labels = kwargs.pop("labels", features)
        if len(labels) != len(features):
            raise LeaspyInputError(
                f"Dimensions mismatch between features ({len(features)}) and labels ({len(labels)})."
            )

        # ---- Ax
        ax = kwargs.pop("ax", None)
        if ax is None:
            _, ax = plt.subplots(
                1, 1, figsize=kwargs.pop("figsize", self.standard_size)
            )

        # ---- Handle ylim
        if "logistic" in self.model.name:
            ax.set_ylim(0, 1)

        return ax, features, features_ix, labels, colors

    def _handle_kwargs_end(self, ax, kwargs, colors, labels):
        """
        Finalize a plot: legend, y-label, optional title and optional saving.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
        kwargs : dict
            Remaining keyword arguments; 'title' and 'save_as' are read from it.
        colors : list
            Colors of plotted features, in order.
        labels : list[:obj:`str`]
            Labels of plotted features, in order.
        """
        # ---- Legend (one proxy line per plotted feature)
        custom_lines = [
            Line2D([0], [0], color=colors[i], lw=4) for i in range(len(labels))
        ]
        ax.legend(custom_lines, labels, title="Features")
        ax.set_ylabel("Normalized score")

        # ---- Title (overrides the default one set by the caller)
        title = kwargs.get("title")
        if title is not None:
            ax.set_title(title)

        # ---- Save
        save_as = kwargs.get("save_as")
        if save_as is not None:
            # Save the figure owning `ax`: it may differ from pyplot's current
            # figure when the user provided their own axes.
            ax.figure.savefig(os.path.join(self.output_path, save_as))

    def average_trajectory(self, **kwargs):
        """
        Plot the population average trajectories.

        They are parametrized by the population parameters derived during the calibration.

        Parameters
        ----------
        **kwargs
            * alpha: :obj:`float`, default 0.5
                Matplotlib's transparency option. Must be in [0, 1].
            * linestyle: {'-', '--', '-.', ':', '', (offset, on-off-seq), ...}
                Matplotlib's linestyle option.
            * linewidth: :obj:`float`
                Matplotlib's linewidth option.
            * features: list[:obj:`str`]
                Name of features (if set it must be a subset of model features)
                Default: all model features.
            * colors: list[:obj:`str`]
                Contains matplotlib compatible colors.
                At least as many as number of features.
            * labels: list[:obj:`str`]
                Used to rename features in the plot.
                Exactly as many as number of features.
                Default: raw variable name of each feature
            * ax: matplotlib.axes.Axes
                Axes object to modify, instead of creating a new one.
            * figsize: tuple of int
                The figure's size.
            * save_as: :obj:`str`, default None
                Path to save the figure.
            * title: :obj:`str`
                Default: "Average trajectories"
            * n_tpts: :obj:`int`
                Number of timepoints in plot (default: 100)
            * n_std_left, n_std_right: :obj:`float` (default: 3 and 6 resp.)
                Time window around `tau_mean`, expressed as times of max(`tau_std`, 4)

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        # ---- Input manager
        plot_kws = self._plot_kwargs("average", kwargs)
        ax, _, features_ix, labels, colors = self._handle_kwargs_begin(kwargs)

        # ---- Get timepoints
        mean_time = self.model.parameters["tau_mean"].item()
        std_time = max(self.model.parameters["tau_std"].item(), 4)
        timepoints = mean_time + std_time * np.linspace(
            -kwargs.get("n_std_left", 3),
            kwargs.get("n_std_right", 6),
            kwargs.get("n_tpts", 100),
        )
        timepoints = torch.tensor(timepoints, dtype=torch.float32).unsqueeze(0)

        # ---- Compute average trajectory
        mean_trajectory = (
            self.model.compute_mean_traj(timepoints).cpu().detach().numpy()
        )
        time_axis = timepoints[0, :].cpu().detach().numpy()

        # ---- plot it for each dimension (legend is built afterwards)
        for ft_ix, ft_color in zip(features_ix, colors):
            ax.plot(
                time_axis,
                mean_trajectory[0, :, ft_ix],
                c=ft_color,
                **plot_kws["model"],
            )

        # ---- Title & labels
        ax.set_title("Average trajectories")
        ax.set_xlabel("Age")

        self._handle_kwargs_end(ax, kwargs, colors, labels)

        return ax

    def _plot_kwargs(self, case, kwargs):
        """
        Build the matplotlib line options for a given kind of plot.

        Parameters
        ----------
        case : {'average', 'obs', 'recons', 'cluster'}
            Kind of plot.
        kwargs : dict
            User keyword arguments (read-only here), overriding instance defaults.

        Returns
        -------
        dict
            Dictionary with a 'model' and/or an 'obs' entry, each one being a dict
            of keyword arguments for `matplotlib` plot calls.

        Raises
        ------
        :exc:`.LeaspyInputError`
            If `case` is unknown.
        """
        if case == "average":
            return {
                "model": dict(
                    alpha=kwargs.get("alpha", self.alpha["average_model"]),
                    linestyle=kwargs.get("linestyle", self.linestyle["average_model"]),
                    linewidth=kwargs.get("linewidth", self.linewidth["average_model"]),
                )
            }

        if case == "obs":
            return {
                "obs": dict(
                    alpha=kwargs.get("alpha", self.alpha["individual_data"]),
                    linestyle=kwargs.get(
                        "linestyle", self.linestyle["individual_data"]
                    ),
                    linewidth=kwargs.get(
                        "linewidth", self.linewidth["individual_data"]
                    ),
                    marker=kwargs.get("marker", "o"),
                    markersize=kwargs.get("markersize", 3),
                )
            }

        if case == "recons":
            # both observations & model will be displayed
            p_obs = dict(
                marker=kwargs.get("marker", "o"),  # None not to display obs
                markersize=kwargs.get("markersize", 4),
                alpha=kwargs.get("obs_alpha", self.alpha["individual_data"]),
                linestyle=kwargs.get("obs_ls", ""),
                linewidth=kwargs.get("obs_lw", self.linewidth["individual_data"]),
            )
            p_model = dict(
                alpha=kwargs.get("alpha", self.alpha["individual_model"]),
                linestyle=kwargs.get("linestyle", self.linestyle["individual_model"]),
                linewidth=kwargs.get("linewidth", self.linewidth["individual_model"]),
            )
            return {"obs": p_obs, "model": p_model}

        if case == "cluster":
            # no linestyle here: it is used to distinguish clusters
            return {
                "model": dict(
                    alpha=kwargs.get("alpha", self.alpha["average_model"]),
                    linewidth=kwargs.get("linewidth", self.linewidth["average_model"]),
                )
            }

        raise LeaspyInputError(
            "case must be in {'average', 'obs', 'recons', 'cluster'}"
        )

    @staticmethod
    def _get_ip_df_torch(individual_parameters):
        """
        Convert individual parameters to both dataframe and PyTorch formats.

        Parameters
        ----------
        individual_parameters : IndividualParameters, :class:`pandas.DataFrame` or tuple
            A tuple is unpacked into `IndividualParameters.from_pytorch`.

        Returns
        -------
        tuple
            (ip_df, ip_torch)

        Raises
        ------
        :exc:`.LeaspyTypeError`
            If the type of `individual_parameters` is not supported.
        :exc:`.LeaspyIndividualParamsInputError`
            If the index of the dataframe is not ['ID'].
        """
        if isinstance(individual_parameters, IndividualParameters):
            ip_df = individual_parameters.to_dataframe()
            ip_torch = individual_parameters.to_pytorch()
        elif isinstance(individual_parameters, pd.DataFrame):
            ip_df = individual_parameters
            ip_torch = IndividualParameters.from_dataframe(
                individual_parameters
            ).to_pytorch()
        elif isinstance(individual_parameters, tuple):
            ip_df = IndividualParameters.from_pytorch(
                *individual_parameters
            ).to_dataframe()
            ip_torch = individual_parameters
        else:
            raise LeaspyTypeError(
                "`individual_parameters` should be an IndividualParameters object, "
                "a pandas.DataFrame or a tuple (PyTorch format)."
            )

        if list(ip_df.index.names) != ["ID"]:
            raise LeaspyIndividualParamsInputError(
                "Individual parameters index is not ['ID'] "
                f"as expected but {list(ip_df.index.names)}"
            )

        return ip_df, ip_torch

    def _plot_patients_generic(
        self,
        case,
        data,
        patients_idx="all",
        individual_parameters=None,
        reparametrized_ages=False,
        **kwargs,
    ):
        """
        Shared implementation of the patient-level plots.

        Parameters
        ----------
        case : {'obs', 'recons'}
            Kind of plot, forwarded to :meth:`._plot_kwargs`.
        data : :class:`~leaspy.io.data.data.Data`
        patients_idx : 'all' (default), :obj:`str` or list[:obj:`str`]
            Patients to display (by their ID).
        individual_parameters : optional
            Any format supported by :meth:`._get_ip_df_torch`.
        reparametrized_ages : :obj:`bool` (default False)
            Whether to plot against reparametrized ages.
        **kwargs
            Plot options, cf. public methods.

        Returns
        -------
        :class:`matplotlib.axes.Axes`

        Raises
        ------
        :exc:`.LeaspyInputError`
            If there is nothing to plot, if features mismatch between data and model,
            or if individual parameters are needed but missing.
        """
        ip_df, ip_torch = None, None
        if individual_parameters is not None:
            self._raise_if_model_not_init()
            ip_df, ip_torch = self._get_ip_df_torch(individual_parameters)

        # ---- Input manager
        plot_kws = self._plot_kwargs(case, kwargs)
        with_model = "model" in plot_kws  # plot reconstruction of model as well
        with_obs = "obs" in plot_kws and plot_kws["obs"].get("marker") is not None
        if not (with_model or with_obs):  # (or both !)
            raise LeaspyInputError(
                "Nothing to plot... nor model values nor observations."
            )

        # ---- Patients sublist
        if "patient_IDs" in kwargs:
            warnings.warn(
                "Keyword argument <patient_IDs> is deprecated! "
                "Use <patients_idx> instead.",
                DeprecationWarning,
            )
            # consume the legacy key so that it is not forwarded any further
            patients_idx = kwargs.pop("patient_IDs")

        if isinstance(patients_idx, str):
            if patients_idx == "all":
                patients_idx = list(data.iter_to_idx.values())
            else:
                patients_idx = [patients_idx]

        # features check
        if self.model.is_initialized and data.headers != self.model.features:
            raise LeaspyInputError(
                "Features provided mismatch between data and model: "
                f"{data.headers} != {self.model.features}"
            )

        ax, features, features_ix, labels, colors = self._handle_kwargs_begin(
            kwargs, data.headers
        )

        # Data to dataframe (only selected patients)
        df = data.to_dataframe()
        # needed because of IndividualParameters converting ID int -> str
        df["ID"] = df["ID"].astype(str)
        df = df.set_index("ID").loc[patients_idx]

        if reparametrized_ages:
            if ip_df is None:
                raise LeaspyInputError(
                    "You want to plot reparametrized ages (`reparametrized_ages=True`) but you did not provide any individual parameters "
                    "to do so (please use `individual_parameters` argument)."
                )

            df = df.join(ip_df)
            tau_mean = self.model.parameters["tau_mean"]

            # `numel` (unlike `size(0)`) also works when `tau_mean` is a 0-dim tensor
            if tau_mean.numel() == 1:
                t0 = tau_mean.item()
                df["TIME_reparam"] = np.exp(df["xi"]) * (df["TIME"] - df["tau"]) + t0
            else:
                # one reference time per cluster
                tau_means_per_cluster = {
                    c: tau_mean[c].item() for c in range(self.model.n_clusters)
                }
                df["tau_mean_cluster"] = df["cluster_label"].map(tau_means_per_cluster)
                df["TIME_reparam"] = (
                    np.exp(df["xi"]) * (df["TIME"] - df["tau"])
                    + df["tau_mean_cluster"]
                )

        # ---- Plot
        # plot observations (with reparametrized times or not)
        if with_obs:
            self._plot_observations(
                ax, df, features, colors, reparametrized_ages, plot_kws["obs"]
            )

        # plot reconstruction as well (model values)
        if with_model:
            if ip_torch is None:
                raise LeaspyInputError(
                    "Individual reconstruction need valid individual parameters."
                )
            self._plot_model_trajectories(
                ax,
                df,
                self.model,
                ip_torch,
                features_ix,
                colors,
                reparametrized_ages,
                plot_kws["model"],
                **kwargs,
            )

        # ---- Title & labels
        if with_obs:
            title = "Observations"
            if with_model:
                title += " and individual trajectories"
        else:  # only with_model
            title = "Individual trajectories"
        ax.set_title(title)
        ax.set_xlabel("Reparametrized age" if reparametrized_ages else "Age")

        self._handle_kwargs_end(ax, kwargs, colors, labels)

        return ax

    @staticmethod
    def _plot_observations(ax, df, features, colors, reparametrized_ages, plot_kws):
        """
        Internal routine: plot individual observations

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
        df : :class:`pandas.DataFrame`
            Data to plot
        features : list[:obj:`str`]
            Which features to plot (subset of model features / data features)
        colors : list
            List of colors (associated to features selected), in order
        reparametrized_ages : bool
            Should we plot trajectories in reparam age or not?
        plot_kws : dict
            Plot kwargs
        """
        time_col = "TIME_reparam" if reparametrized_ages else "TIME"

        df_with_time = df.set_index(df[time_col].rename("T"), append=True).sort_index()
        # selected features only
        df_with_time = df_with_time[features].dropna(how="all")

        for _, ind_df in df_with_time.groupby("ID"):
            for (_, s_ind_ft), ft_color in zip(ind_df.items(), colors):
                s_ind_ft = s_ind_ft.dropna()
                # TODO? use a cycle of markers to better distinguish individuals?
                ax.plot(
                    s_ind_ft.reset_index("T")["T"],
                    s_ind_ft,
                    c=ft_color,  # legend is done afterwards
                    **plot_kws,
                )

    @staticmethod
    def _plot_model_trajectories(
        ax,
        df,
        model,
        individual_parameters,
        features_ix,
        colors,
        reparametrized_ages,
        plot_kws,
        **kwargs,
    ):
        """
        Internal routine: plot individual trajectories estimated by model

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
        df : :class:`pandas.DataFrame`
            Data (TODO: could be the MultiIndex [ID,TIME] instead...)
            When reparametrized ages are wanted and the model has several
            `tau_mean` values, it must hold a 'tau_mean_cluster' column.
        model
            The model used to compute trajectories.
        individual_parameters : tuple[list, dict]
            <!> in pytorch dict format: tuple(indices:list, dict{ip_name: vals})
        features_ix : list[int]
            Which features to plot (order of features from model)
        colors : list
            List of colors (associated to features selected), in order
        reparametrized_ages : bool
            Should we plot trajectories in reparam age or not?
        plot_kws : dict
            Plot kwargs
        **kwargs
            * "factor_past", "factor_future": float (default 0.5)
                past/future padding to plot (as fraction of total follow-up duration of subjects)
            * "n_tpts": int (default 100)
                nb of tpts in trajectory
        """
        ip_indices, ip_torch = individual_parameters

        # for the mixture model make sure we remove the probabilities and
        # the cluster labels from the individual parameters
        valid_keys = None
        if isinstance(model, LogisticMultivariateMixtureModel):
            valid_keys = set(model.individual_variables_names)

        factor_past = kwargs.get("factor_past", 0.5)
        factor_future = kwargs.get("factor_future", 0.5)
        n_tpts = kwargs.get("n_tpts", 100)

        for ind_id, ind_df in df.groupby("ID"):
            ind_ix = ip_indices.index(ind_id)
            ind_ip = {
                pn: pv[ind_ix]
                for pn, pv in ip_torch.items()
                if valid_keys is None or pn in valid_keys
            }

            # <!> always real patient ages here (to compute)
            real_ages = ind_df["TIME"]
            min_t, max_t = real_ages.min(), real_ages.max()
            total_t = max_t - min_t
            timepoints = np.linspace(
                min_t - factor_past * total_t,
                max_t + factor_future * total_t,
                n_tpts,
            )
            t = torch.tensor(timepoints, dtype=torch.float32).unsqueeze(0)

            trajectory = (
                model.compute_individual_trajectory(t, ind_ip)
                .squeeze(0)
                .cpu()
                .detach()
                .numpy()
            )

            # times to plot if reparametrized ages are wanted
            if reparametrized_ages:
                tau_mean = model.parameters["tau_mean"]
                if tau_mean.numel() == 1:
                    t0 = tau_mean.item()
                else:
                    # several reference times (one per cluster): use the one
                    # attached to this individual, like for observations
                    if "tau_mean_cluster" not in ind_df.columns:
                        raise LeaspyInputError(
                            "Column 'tau_mean_cluster' is needed to plot reparametrized "
                            "ages with a model having several `tau_mean` values."
                        )
                    t0 = float(ind_df["tau_mean_cluster"].iloc[0])

                timepoints = (
                    (
                        model.time_reparametrization(
                            t=t, alpha=ind_ip["xi"].exp(), tau=ind_ip["tau"]
                        )
                        + t0
                    )
                    .squeeze(0)
                    .cpu()
                    .detach()
                    .numpy()
                )

            for ft_ix, ft_color in zip(features_ix, colors):
                ax.plot(
                    timepoints,
                    trajectory[:, ft_ix],
                    c=ft_color,
                    **plot_kws,
                )

    def patient_observations(
        self, data, patients_idx="all", individual_parameters=None, **kwargs
    ):
        """
        Plot patient observations

        Parameters
        ----------
        data : :class:`~leaspy.io.data.data.Data`
        patients_idx : 'all' (default), :obj:`str` or list[:obj:`str`]
            Patients to display (by their ID).
        individual_parameters : :class:`~leaspy.io.outputs.individual_parameters.IndividualParameters` or :class:`pandas.DataFrame` (as may be output by ip.to_dataframe()) or tuple (Pytorch ip format), optional
            If not None, observations are plotted with respect to reparametrized ages.
        **kwargs
            Plot options (marker, markersize, alpha, linestyle, linewidth,
            features, colors, labels, ax, figsize, title, save_as).

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        return self._plot_patients_generic(
            "obs",
            data,
            patients_idx=patients_idx,
            individual_parameters=individual_parameters,
            reparametrized_ages=individual_parameters is not None,
            **kwargs,
        )

    def patient_observations_reparametrized(
        self, data, individual_parameters, patients_idx="all", **kwargs
    ):
        """
        Plot patient observations (reparametrized ages)

        Same as :meth:`.patient_observations` but individual parameters are mandatory.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        return self._plot_patients_generic(
            "obs",
            data,
            patients_idx=patients_idx,
            individual_parameters=individual_parameters,
            reparametrized_ages=True,
            **kwargs,
        )

    def patient_trajectories(
        self,
        data,
        individual_parameters,
        patients_idx="all",
        reparametrized_ages=False,
        **kwargs,
    ):
        """
        Plot patient observations together with model individual reconstruction

        Parameters
        ----------
        data : :class:`~leaspy.io.data.data.Data`
        individual_parameters : :class:`~leaspy.io.outputs.individual_parameters.IndividualParameters` or :class:`pandas.DataFrame` (as may be output by ip.to_dataframe()) or tuple (Pytorch ip format)
        patients_idx : 'all' (default), :obj:`str` or list[:obj:`str`]
            Patients to display (by their ID).
        reparametrized_ages : :obj:`bool` (default False)
            Should we plot trajectories in reparam age or not? to study source impact essentially
        **kwargs
            cf. :meth:`._plot_model_trajectories`
            In particular, pass marker=None if you don't want observations besides model

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        return self._plot_patients_generic(
            "recons",
            data,
            patients_idx=patients_idx,
            individual_parameters=individual_parameters,
            reparametrized_ages=reparametrized_ages,
            **kwargs,
        )

    def average_trajectory_cluster(
        self, colors=None, n_features_per_plot=3, clusters=None, **kwargs
    ):
        """
        Plot the population average trajectories for each cluster.

        They are parametrized by the population parameters derived from the fit.
        Each cluster is plotted in a different linestyle, and each feature in a different color.
        Default is to plot 3 features per figure, so if there are more features
        they will be plotted in several figures.
        We can choose which clusters to plot, default is all.

        Parameters
        ----------
        colors : list of str, optional
            List of matplotlib-compatible colors, one per feature
            (clusters are distinguished by linestyle). Cycles if fewer than the
            number of features. Default: matplotlib's current color cycle.
        n_features_per_plot : int, default 3
            Number of features to plot in each figure.
        clusters : int or list of int, default None
            Cluster indices to plot. If None, all clusters are plotted.
        **kwargs
            * alpha: :obj:`float`, default 0.5
                Matplotlib's transparency option. Must be in [0, 1].
            * linewidth: :obj:`float`
                Matplotlib's linewidth option.
            * n_tpts: :obj:`int`
                Number of timepoints in plot (default: 100)
            * n_std_left, n_std_right: :obj:`float` (default: 3 and 6 resp.)
                Time window around the mean of `tau_mean`,
                expressed as times of max(mean of `tau_std`, 4)

        Returns
        -------
        :mod:`matplotlib.pyplot`
            The pyplot module with all generated figures, allowing further modification or saving.

        Raises
        ------
        :exc:`.LeaspyInputError`
            If the model is not initialized, if `n_features_per_plot` is not positive
            or if a requested cluster does not exist.
        """
        self._raise_if_model_not_init()

        if n_features_per_plot < 1:
            raise LeaspyInputError(
                f"`n_features_per_plot` should be >= 1, not {n_features_per_plot}."
            )

        if colors is None:
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

        plot_kws = self._plot_kwargs("cluster", kwargs)

        parameters = self.model.parameters
        n_clusters = self.model.n_clusters
        n_features = self.model.dimension

        # ---- Clusters to plot
        if clusters is None:
            clusters_to_plot = list(range(n_clusters))  # all clusters
        elif isinstance(clusters, (int, np.integer)):
            clusters_to_plot = [int(clusters)]  # single cluster
        else:
            clusters_to_plot = list(clusters)  # list of clusters

        unknown_clusters = [c for c in clusters_to_plot if not 0 <= c < n_clusters]
        if unknown_clusters:
            raise LeaspyInputError(
                f"Unknown cluster(s) {unknown_clusters}: "
                f"the model has {n_clusters} clusters (indexed from 0)."
            )

        # ---- Get timepoints (common to all clusters)
        mean_time = parameters["tau_mean"].mean().item()
        std_time = max(parameters["tau_std"].mean().item(), 4)
        timepoints = mean_time + std_time * np.linspace(
            -kwargs.get("n_std_left", 3),
            kwargs.get("n_std_right", 6),
            kwargs.get("n_tpts", 100),
        )
        timepoints = torch.tensor(timepoints, dtype=torch.float32)
        time_axis = timepoints.cpu().numpy()  # what is given to matplotlib

        # ---- Average trajectory of each requested cluster
        xi_mean = parameters["xi_mean"].cpu().detach().numpy()
        tau_mean = parameters["tau_mean"].cpu().detach().numpy()
        sources_mean = parameters["sources_mean"].cpu().detach().numpy()

        cluster_estimates = {}
        for c in clusters_to_plot:
            if c in cluster_estimates:
                continue  # cluster requested several times
            ip = IndividualParameters()
            ip.add_individual_parameters(
                "average",
                {
                    "xi": xi_mean[c],
                    "tau": tau_mean[c],
                    "sources": sources_mean[:, c].tolist(),  # all sources for this cluster
                },
            )
            cluster_estimates[c] = self.model.estimate({"average": timepoints}, ip)

        lines = [
            "-",
            "--",
            ":",
            "-.",
            (0, (1, 1)),  # densely dotted
            (0, (5, 1)),  # long dash, short gap
            (0, (3, 1, 1, 1)),  # dash-dot-dotted
            (0, (5, 5)),  # evenly spaced dashes
            (0, (5, 2, 1, 2)),  # long dash, dot, gap
            (0, (2, 2, 8, 2)),  # dot + long dash
            (0, (10, 3)),  # very long dash
        ]
        n_lines = len(lines)
        # the cycle is shared by all figures so that colors keep changing across them
        colors_cycle = itertools.cycle(colors)

        # ---- Loop over feature chunks (one figure per chunk)
        for start in range(0, n_features, n_features_per_plot):
            end = min(start + n_features_per_plot, n_features)
            feature_names = self.model.features[start:end]

            plt.figure(figsize=(8, 6))
            plt.ylim(0, 1)

            feature_colors = {name: next(colors_cycle) for name in feature_names}

            # Plot each cluster
            for c in clusters_to_plot:
                ls = lines[c % n_lines]  # linestyle per cluster
                values = cluster_estimates[c]["average"][:, start:end].T
                for name, val in zip(feature_names, values):
                    plt.plot(
                        time_axis,
                        val,
                        label=f"cluster_{c}_{name}",
                        c=feature_colors[name],
                        ls=ls,
                        **plot_kws["model"],
                    )

            # Cluster legend (linestyles)
            cluster_legend = [
                Line2D(
                    [0],
                    [0],
                    linestyle=lines[c % n_lines],
                    linewidth=3,
                    color="black",
                    label=f"cluster_{c}",
                )
                for c in clusters_to_plot
            ]
            legend1 = plt.legend(
                handles=cluster_legend, loc="upper left", prop={"size": 12}
            )

            # Feature legend (colors)
            feature_legend = [
                Line2D(
                    [0],
                    [0],
                    linestyle="-",
                    color=feature_colors[name],
                    linewidth=3,
                    label=f"{name}",
                )
                for name in feature_names
            ]
            plt.legend(
                handles=feature_legend,
                loc="lower right",
                title="Feature",
                prop={"size": 12},
            )
            # a second call to `legend` replaces the first one: add it back
            plt.gca().add_artist(legend1)

            plt.xlim(time_axis.min(), time_axis.max())
            plt.xlabel("Reparametrized age", fontsize=14)
            plt.ylabel("Normalized feature value", fontsize=14)
            plt.title(f"scores {start+1}-{end}", fontsize=14)
            plt.suptitle("Population progression", fontsize=16)
            plt.tight_layout()

        return plt