import os
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import itertools
from leaspy.exceptions import (
LeaspyIndividualParamsInputError,
LeaspyInputError,
LeaspyTypeError,
)
from ...outputs import IndividualParameters
from leaspy.models import LogisticMultivariateMixtureModel
__all__ = ["Plotting"]
# TODO: outdated
class Plotting:
    """
    .. deprecated:: 1.2

    Class defining some plotting tools.

    Parameters
    ----------
    model : :class:`~leaspy.models.BaseModel`
        The used model.
    output_path : :obj:`str`, optional (default ``"."``)
        Folder where plots will be saved (joined with the `save_as` keyword
        of the plotting methods).
    palette : :obj:`str` (palette name) or :class:`matplotlib.colors.Colormap` (`ListedColormap` or `LinearSegmentedColormap`)
        The palette to use.
    max_colors : :obj:`int` > 0 or None, optional (default 10)
        Only used if palette is a string.
        If None, the model dimension is used instead.
    """

    def __init__(self, model, output_path=".", palette="tab10", max_colors=10):
        self.model = model

        # ---- Graphical options
        self.color_palette = None
        self.standard_size = (8, 4)
        self.linestyle = {
            "average_model": "-",
            "individual_model": "-",
            "individual_data": "-",
        }
        self.linewidth = {
            "average_model": 5,
            "individual_model": 2,
            "individual_data": 2,
        }
        self.alpha = {"average_model": 0.5, "individual_model": 1, "individual_data": 1}
        self.output_path = output_path

        # must come last: it may read `self.model`
        self.set_palette(palette, max_colors)

    def set_palette(self, palette, max_colors=None):
        """
        Set palette of plots

        Parameters
        ----------
        palette : :obj:`str` (palette name) or :class:`matplotlib.colors.Colormap` (`ListedColormap` or `LinearSegmentedColormap`)
            The palette to use.
        max_colors : :obj:`int` > 0, optional (default, corresponding to model nb of features)
            Only used if palette is a string

        Raises
        ------
        LeaspyInputError
            If `palette` is a name, `max_colors` is None and the model has no dimension yet.
        """
        if isinstance(palette, mpl.colors.Colormap):
            self.color_palette = palette
            return

        if max_colors is None:
            # we can only fall back on the model dimension if it is defined
            if self.model.dimension is None:
                raise LeaspyInputError(
                    "Initialize model first please, with a not None dimension"
                )
            max_colors = self.model.dimension

        self.color_palette = mpl.colormaps[palette].resampled(max_colors)

    def colors(self, at=None):
        """
        Wrapper over color_palette iterator to get colors

        Parameters
        ----------
        at : any legit color_palette arg (int, float or iterable of any of these) or None (default)
            if None returns all colors of palette upto model dimension

        Returns
        -------
        colors : single color tuple (RGBA) or np.array of RGBA colors (number of colors x 4)
        """
        if at is None:
            # cycle over the palette when the model has more features than colors
            at = [i % self.color_palette.N for i in range(self.model.dimension)]
        return self.color_palette(at)

    def _raise_if_model_not_init(self):
        """Raise a LeaspyInputError if the model is not initialized."""
        if not self.model.is_initialized:
            raise LeaspyInputError("Please initialize the model before plotting")

    def _handle_kwargs_begin(self, kwargs, all_features_list=None):
        """
        Extract kwargs corresponding to plot information and remove associated keys (in-place).

        Consumed keys: `features`, `color` / `colors`, `labels`, `ax` and
        (only when no `ax` is given) `figsize`.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments of the calling plot method (modified in-place).
        all_features_list : list[:obj:`str`], optional
            All available features; defaults to the features of the (initialized) model.

        Returns
        -------
        tuple
            ``(ax, features, features_ix, labels, colors)``

        Raises
        ------
        LeaspyInputError
            If the model is needed but not initialized, if there are fewer
            colors than features, or if labels and features differ in length.
        """
        # get features from initialized model if not set
        if all_features_list is None:
            self._raise_if_model_not_init()
            all_features_list = self.model.features

        # ---- Get requested features (may be a subset)
        features = kwargs.pop("features", all_features_list)
        features_ix = list(map(all_features_list.index, features))

        # ---- Colors
        # Both keys are always popped so that none of them leaks to downstream
        # calls; the legacy `color` key keeps precedence over `colors`.
        colors_plural = kwargs.pop("colors", None)
        colors = kwargs.pop("color", None)
        if colors is None:
            colors = colors_plural
        if colors is None:
            colors = self.colors(features_ix)
        if len(colors) < len(features):
            raise LeaspyInputError(
                f"Please choose a palette with at least {len(features)} colors."
            )
        # TODO: reindex default colors if subset of features?

        # ---- Labels
        labels = kwargs.pop("labels", features)
        if len(labels) != len(features):
            raise LeaspyInputError(
                f"Dimensions mismatch between features ({len(features)}) and labels ({len(labels)})."
            )

        # ---- Ax
        ax = kwargs.pop("ax", None)
        if ax is None:
            _, ax = plt.subplots(
                1, 1, figsize=kwargs.pop("figsize", self.standard_size)
            )

        # ---- Handle ylim
        if "logistic" in self.model.name:
            ax.set_ylim(0, 1)

        return ax, features, features_ix, labels, colors

    def _handle_kwargs_end(self, ax, kwargs, colors, labels):
        """
        Finalize a plot: add the features legend, the y-label and optionally save it.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
        kwargs : dict
            Only the `save_as` key is read (file name, relative to `output_path`).
        colors : list
            Colors of the plotted features, in order.
        labels : list[:obj:`str`]
            Legend labels of the plotted features, in order.
        """
        # ---- Legend
        custom_lines = [
            mpl.lines.Line2D([0], [0], color=colors[i], lw=4)
            for i in range(len(labels))
        ]
        ax.legend(custom_lines, labels, title="Features")
        ax.set_ylabel("Normalized score")

        # ---- Save (an explicit `save_as=None` means "do not save")
        save_as = kwargs.get("save_as")
        if save_as is not None:
            plt.savefig(os.path.join(self.output_path, save_as))

    def average_trajectory(self, **kwargs):
        """
        Plot the population average trajectories. They are parametrized by the population parameters derived
        during the calibration.

        Parameters
        ----------
        **kwargs
            * alpha: :obj:`float`, default 0.5
                Matplotlib's transparency option. Must be in [0, 1].
            * linestyle: {'-', '--', '-.', ':', '', (offset, on-off-seq), ...}
                Matplotlib's linestyle option.
            * linewidth: :obj:`float`
                Matplotlib's linewidth option.
            * features: list[:obj:`str`]
                Name of features (if set it must be a subset of model features)
                Default: all model features.
            * color (or colors): list[:obj:`str`]
                Contains matplotlib compatible colors.
                At least as many as number of features.
            * labels: list[:obj:`str`]
                Used to rename features in the plot.
                Exactly as many as number of features.
                Default: raw variable name of each feature
            * ax: matplotlib.axes.Axes
                Axes object to modify, instead of creating a new one.
            * figsize: tuple of int
                The figure's size.
            * save_as: :obj:`str`, default None
                Path to save the figure.
            * title: :obj:`str`, default "Average trajectories"
            * n_tpts: :obj:`int`
                Number of timepoints in plot (default: 100)
            * n_std_left, n_std_right: :obj:`float` (default: 3 and 6 resp.)
                Time window around `tau_mean`, expressed as times of max(`tau_std`, 4)

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        # ---- Input manager
        plot_kws = self._plot_kwargs("average", kwargs)
        ax, _, features_ix, labels, colors = self._handle_kwargs_begin(kwargs)

        # ---- Get timepoints
        mean_time = self.model.parameters["tau_mean"].item()
        std_time = max(self.model.parameters["tau_std"].item(), 4)
        timepoints = mean_time + std_time * np.linspace(
            -kwargs.get("n_std_left", 3),
            kwargs.get("n_std_right", 6),
            kwargs.get("n_tpts", 100),
        )
        timepoints = torch.tensor(timepoints, dtype=torch.float32).unsqueeze(0)

        # ---- Compute average trajectory
        mean_trajectory = (
            self.model.compute_mean_traj(timepoints).cpu().detach().numpy()
        )

        # ---- plot it for each dimension (x-values are shared by all features)
        x_values = timepoints[0, :].cpu().detach().numpy()
        for ft_ix, ft_color in zip(features_ix, colors):
            ax.plot(
                x_values,
                mean_trajectory[0, :, ft_ix],
                c=ft_color,
                **plot_kws["model"],
            )

        # ---- Title & labels
        ax.set_title(kwargs.get("title", "Average trajectories"))
        ax.set_xlabel("Age")
        self._handle_kwargs_end(ax, kwargs, colors, labels)

        return ax

    def _plot_kwargs(self, case, kwargs):
        """
        Build the matplotlib keyword arguments for a given kind of plot.

        Parameters
        ----------
        case : {'average', 'obs', 'recons', 'cluster'}
        kwargs : dict
            User keyword arguments (only read, never modified).

        Returns
        -------
        dict
            With key "model" and/or "obs", each mapping to a dict of plot kwargs.

        Raises
        ------
        LeaspyInputError
            If `case` is unknown.
        """
        if case == "average":
            return {
                "model": dict(
                    alpha=kwargs.get("alpha", self.alpha["average_model"]),
                    linestyle=kwargs.get("linestyle", self.linestyle["average_model"]),
                    linewidth=kwargs.get("linewidth", self.linewidth["average_model"]),
                )
            }

        if case == "obs":
            return {
                "obs": dict(
                    alpha=kwargs.get("alpha", self.alpha["individual_data"]),
                    linestyle=kwargs.get(
                        "linestyle", self.linestyle["individual_data"]
                    ),
                    linewidth=kwargs.get(
                        "linewidth", self.linewidth["individual_data"]
                    ),
                    marker=kwargs.get("marker", "o"),
                    markersize=kwargs.get("markersize", "3"),
                )
            }

        if case == "recons":
            # both observations & model will be displayed
            p_obs = dict(
                marker=kwargs.get("marker", "o"),  # None not to display obs
                markersize=kwargs.get("markersize", "4"),
                alpha=kwargs.get("obs_alpha", self.alpha["individual_data"]),
                linestyle=kwargs.get("obs_ls", ""),
                linewidth=kwargs.get("obs_lw", self.linewidth["individual_data"]),
            )
            p_model = dict(
                alpha=kwargs.get("alpha", self.alpha["individual_model"]),
                linestyle=kwargs.get("linestyle", self.linestyle["individual_model"]),
                linewidth=kwargs.get("linewidth", self.linewidth["individual_model"]),
            )
            return {"obs": p_obs, "model": p_model}

        if case == "cluster":
            # no linestyle here: it is chosen per cluster by the caller
            return {
                "model": dict(
                    alpha=kwargs.get("alpha", self.alpha["average_model"]),
                    linewidth=kwargs.get("linewidth", self.linewidth["average_model"]),
                )
            }

        raise LeaspyInputError("case must be in {'average', 'obs', 'recons', 'cluster'}")

    @staticmethod
    def _get_ip_df_torch(individual_parameters):
        """
        Convert individual parameters to both their DataFrame and Pytorch forms.

        Parameters
        ----------
        individual_parameters : IndividualParameters, :class:`pandas.DataFrame` or tuple
            A tuple is unpacked into `IndividualParameters.from_pytorch`.

        Returns
        -------
        tuple
            ``(ip_df, ip_torch)``

        Raises
        ------
        LeaspyTypeError
            If `individual_parameters` has an unsupported type.
        LeaspyIndividualParamsInputError
            If the index of the DataFrame form is not ``['ID']``.
        """
        if isinstance(individual_parameters, IndividualParameters):
            ip_df = individual_parameters.to_dataframe()
            ip_torch = individual_parameters.to_pytorch()
        elif isinstance(individual_parameters, pd.DataFrame):
            ip_df = individual_parameters
            ip_torch = IndividualParameters.from_dataframe(
                individual_parameters
            ).to_pytorch()
        elif isinstance(individual_parameters, tuple):
            ip_df = IndividualParameters.from_pytorch(
                *individual_parameters
            ).to_dataframe()
            ip_torch = individual_parameters
        else:
            raise LeaspyTypeError(
                "`individual_parameters` should be an IndividualParameters object, "
                "a pandas.DataFrame or a tuple (Pytorch format)."
            )

        if ip_df.index.names != ["ID"]:
            raise LeaspyIndividualParamsInputError(
                "Individual parameters index is not ['ID'] "
                f"as expected but {list(ip_df.index.names)}"
            )

        return ip_df, ip_torch

    def _plot_patients_generic(
        self,
        case,
        data,
        patients_idx="all",
        individual_parameters=None,
        reparametrized_ages=False,
        **kwargs,
    ):
        """
        Internal routine shared by all the individual plots.

        Parameters
        ----------
        case : {'obs', 'recons'}
            Passed to :meth:`._plot_kwargs`: observations only, or observations and model.
        data : :class:`~leaspy.io.data.data.Data`
        patients_idx : 'all' (default), :obj:`str` or list[:obj:`str`]
            Patients to display (by their ID).
        individual_parameters : optional
            Any input supported by :meth:`._get_ip_df_torch`.
        reparametrized_ages : :obj:`bool` (default False)
            Plot against reparametrized ages (needs `individual_parameters`).
        **kwargs
            Plot options (cf. :meth:`._plot_kwargs`, :meth:`._handle_kwargs_begin`,
            :meth:`._handle_kwargs_end` and :meth:`._plot_model_trajectories`).

        Returns
        -------
        :class:`matplotlib.axes.Axes`

        Raises
        ------
        LeaspyInputError
            On inconsistent inputs (nothing to plot, features mismatch,
            missing individual parameters).
        """
        ip_df, ip_torch = None, None
        if individual_parameters is not None:
            self._raise_if_model_not_init()
            ip_df, ip_torch = self._get_ip_df_torch(individual_parameters)

        # ---- Input manager
        plot_kws = self._plot_kwargs(case, kwargs)
        with_model = "model" in plot_kws  # plot reconstruction of model as well
        with_obs = "obs" in plot_kws and plot_kws["obs"].get("marker") is not None
        if not (with_model or with_obs):  # (or both !)
            raise LeaspyInputError(
                "Nothing to plot... nor model values nor observations."
            )

        # ---- Patients sublist
        if "patient_IDs" in kwargs:
            warnings.warn(
                "Keyword argument <patient_IDs> is deprecated! "
                "Use <patients_idx> instead.",
                DeprecationWarning,
            )
            patients_idx = kwargs.pop("patient_IDs")

        if isinstance(patients_idx, str):
            if patients_idx == "all":
                patients_idx = list(data.iter_to_idx.values())
            else:
                patients_idx = [patients_idx]

        # features check
        if self.model.is_initialized and data.headers != self.model.features:
            raise LeaspyInputError(
                "Features provided mismatch between data and model: "
                f"{data.headers} != {self.model.features}"
            )

        ax, features, features_ix, labels, colors = self._handle_kwargs_begin(
            kwargs, data.headers
        )

        # Data to dataframe (only selected patients)
        df = data.to_dataframe()
        # needed because of IndividualParameters converting ID int -> str
        df["ID"] = df["ID"].astype(str)
        df = df.set_index("ID").loc[patients_idx]

        if reparametrized_ages:
            if ip_df is None:
                raise LeaspyInputError(
                    "You want to plot reparametrized ages (`reparametrized_ages=True`) but you did not provide any individual parameters "
                    "to do so (please use `individual_parameters` argument)."
                )
            df = df.join(ip_df)
            tau_mean = self.model.parameters["tau_mean"]
            # `numel` (unlike `size(0)`) also supports a 0-dim `tau_mean`
            if tau_mean.numel() == 1:
                t0 = tau_mean.item()
                # reparametrized ages
                df["TIME_reparam"] = np.exp(df["xi"]) * (df["TIME"] - df["tau"]) + t0
            else:
                # one reference time per cluster
                tau_means_per_cluster = {
                    c: tau_mean[c].item() for c in range(self.model.n_clusters)
                }
                df["tau_mean_cluster"] = df["cluster_label"].map(tau_means_per_cluster)
                df["TIME_reparam"] = (
                    np.exp(df["xi"]) * (df["TIME"] - df["tau"]) + df["tau_mean_cluster"]
                )

        # ---- Plot
        # plot observations (with reparametrized times or not)
        if with_obs:
            self._plot_observations(
                ax, df, features, colors, reparametrized_ages, plot_kws["obs"]
            )

        # plot reconstruction as well (model values)
        if with_model:
            if ip_torch is None:
                raise LeaspyInputError(
                    "Individual reconstruction need valid individual parameters."
                )
            self._plot_model_trajectories(
                ax,
                df,
                self.model,
                ip_torch,
                features_ix,
                colors,
                reparametrized_ages,
                plot_kws["model"],
                **kwargs,
            )

        # ---- Title & labels
        if with_obs:
            title = "Observations"
            if with_model:
                title += " and individual trajectories"
        else:  # only with_model
            title = "Individual trajectories"
        ax.set_title(title)

        ax.set_xlabel("Reparametrized age" if reparametrized_ages else "Age")
        self._handle_kwargs_end(ax, kwargs, colors, labels)

        return ax

    @staticmethod
    def _plot_observations(ax, df, features, colors, reparametrized_ages, plot_kws):
        """
        Internal routine: plot individual observations

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
        df : :class:`pandas.DataFrame`
            Data to plot
        features : list[:obj:`str`]
            Which features to plot (subset of model features / data features)
        colors : list
            List of colors (associated to features selected), in order
        reparametrized_ages : bool
            Should we plot trajectories in reparam age or not?
        plot_kws : dict
            Plot kwargs
        """
        time_col = "TIME_reparam" if reparametrized_ages else "TIME"

        # add the time as an extra index level named "T", and sort by (ID, T)
        df_with_time = df.set_index(df[time_col].rename("T"), append=True).sort_index()
        # selected features only
        df_with_time = df_with_time[features].dropna(how="all")

        for _, ind_df in df_with_time.groupby("ID"):
            for (_, s_ind_ft), ft_color in zip(ind_df.items(), colors):
                s_ind_ft = s_ind_ft.dropna()
                # TODO? use a cycle of markers to better distinguish individuals?
                ax.plot(
                    s_ind_ft.reset_index("T")["T"],
                    s_ind_ft,
                    c=ft_color,
                    **plot_kws,
                )

    @staticmethod
    def _plot_model_trajectories(
        ax,
        df,
        model,
        individual_parameters,
        features_ix,
        colors,
        reparametrized_ages,
        plot_kws,
        **kwargs,
    ):
        """
        Internal routine: plot individual trajectories estimated by model

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`
        df : :class:`pandas.DataFrame`
            Data (TODO: could be the MultiIndex [ID,TIME] instead...)
            If it has a `tau_mean_cluster` column, it is used as the individual
            reference time for reparametrized ages.
        model
            The model used to compute the trajectories.
        individual_parameters : tuple[list, dict]
            <!> in pytorch dict format: tuple(indices:list, dict{ip_name: vals})
        features_ix : list[int]
            Which features to plot (order of features from model)
        colors : list
            List of colors (associated to features selected), in order
        reparametrized_ages : bool
            Should we plot trajectories in reparam age or not?
        plot_kws : dict
            Plot kwargs
        **kwargs
            * "factor_past", "factor_future": float (default 0.5)
                past/future padding to plot (as fraction of total follow-up duration of subjects)
            * "n_tpts": int (default 100)
                nb of tpts in trajectory
        """
        ip_indices, ip_torch = individual_parameters

        # for the mixture model make sure we remove the probabilities and
        # the cluster labels from the individual parameters
        valid_keys = None
        if isinstance(model, LogisticMultivariateMixtureModel):
            valid_keys = set(model.individual_variables_names)

        factor_past = kwargs.get("factor_past", 0.5)
        factor_future = kwargs.get("factor_future", 0.5)
        n_tpts = kwargs.get("n_tpts", 100)

        for ind_id, ind_df in df.groupby("ID"):
            ind_ix = ip_indices.index(ind_id)
            ind_ip = {
                pn: pv[ind_ix]
                for pn, pv in ip_torch.items()
                if valid_keys is None or pn in valid_keys
            }  # torch compatible

            # <!> always real patient ages here (to compute)
            real_ages = ind_df["TIME"]
            min_t, max_t = min(real_ages), max(real_ages)
            total_t = max_t - min_t
            timepoints = np.linspace(
                min_t - factor_past * total_t,
                max_t + factor_future * total_t,
                n_tpts,
            )
            t = torch.tensor(timepoints, dtype=torch.float32).unsqueeze(0)

            trajectory = (
                model.compute_individual_trajectory(t, ind_ip)
                .squeeze(0)
                .cpu()
                .detach()
                .numpy()
            )

            # times to plot if reparametrized ages are wanted
            if reparametrized_ages:
                if "tau_mean_cluster" in ind_df.columns:
                    # per-cluster reference time (same as for the observations)
                    t0 = float(ind_df["tau_mean_cluster"].iloc[0])
                else:
                    t0 = model.parameters["tau_mean"].item()
                timepoints = (
                    (
                        model.time_reparametrization(
                            t=t, alpha=ind_ip["xi"].exp(), tau=ind_ip["tau"]
                        )
                        + t0
                    )
                    .squeeze(0)
                    .cpu()
                    .detach()
                    .numpy()
                )

            for ft_ix, ft_color in zip(features_ix, colors):
                ax.plot(
                    timepoints,
                    trajectory[:, ft_ix],
                    c=ft_color,
                    **plot_kws,
                )

    def patient_observations(
        self, data, patients_idx="all", individual_parameters=None, **kwargs
    ):
        """
        Plot patient observations

        Parameters
        ----------
        data : :class:`~leaspy.io.data.data.Data`
        patients_idx : 'all' (default), :obj:`str` or list[:obj:`str`]
            Patients to display (by their ID).
        individual_parameters : :class:`~leaspy.io.outputs.individual_parameters.IndividualParameters` or :class:`pandas.DataFrame` (as may be output by ip.to_dataframe()) or tuple (Pytorch ip format), optional
            If not None, observations are plotted with respect to reparametrized ages.
        **kwargs
            Plot options forwarded to :meth:`._plot_patients_generic`.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        return self._plot_patients_generic(
            "obs",
            data,
            patients_idx=patients_idx,
            individual_parameters=individual_parameters,
            reparametrized_ages=individual_parameters is not None,
            **kwargs,
        )

    def patient_observations_reparametrized(
        self, data, individual_parameters, patients_idx="all", **kwargs
    ):
        """
        Plot patient observations (reparametrized ages)

        Same as :meth:`.patient_observations` but `individual_parameters` is
        mandatory and reparametrized ages are always used.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        return self._plot_patients_generic(
            "obs",
            data,
            patients_idx=patients_idx,
            individual_parameters=individual_parameters,
            reparametrized_ages=True,
            **kwargs,
        )

    def patient_trajectories(
        self,
        data,
        individual_parameters,
        patients_idx="all",
        reparametrized_ages=False,
        **kwargs,
    ):
        """
        Plot patient observations together with model individual reconstruction

        Parameters
        ----------
        data : :class:`~leaspy.io.data.data.Data`
        individual_parameters : :class:`~leaspy.io.outputs.individual_parameters.IndividualParameters` or :class:`pandas.DataFrame` (as may be output by ip.to_dataframe()) or tuple (Pytorch ip format)
        patients_idx : 'all' (default), :obj:`str` or list[:obj:`str`]
            Patients to display (by their ID).
        reparametrized_ages : :obj:`bool` (default False)
            Should we plot trajectories in reparam age or not? to study source impact essentially
        **kwargs
            cf. :meth:`._plot_model_trajectories`
            In particular, pass marker=None if you don't want observations besides model

        Returns
        -------
        :class:`matplotlib.axes.Axes`
        """
        return self._plot_patients_generic(
            "recons",
            data,
            patients_idx=patients_idx,
            individual_parameters=individual_parameters,
            reparametrized_ages=reparametrized_ages,
            **kwargs,
        )

    def average_trajectory_cluster(self, colors=None, n_features_per_plot=3, clusters=None, **kwargs):
        """
        Plot the population average trajectories for each cluster. They are parametrized by the population parameters derived
        from the fit. Each cluster is plotted in a different linestyle, and each feature in a different color.
        Default is to plot 3 features per figure, so if there are more features they will be plotted in several plots.
        We can choose which clusters to plot, default is all.

        Parameters
        ----------
        colors : list of str
            List of matplotlib-compatible colors, one per feature (cycled over
            when there are fewer colors than features).
            Default: colors of matplotlib's `axes.prop_cycle`.
        n_features_per_plot : int, default 3
            Number of features to plot in each figure.
        clusters : int or list of int, default None
            Cluster index (or indices) to plot. If None, all clusters are plotted.
        **kwargs
            * alpha: :obj:`float`
                Matplotlib's transparency option. Must be in [0, 1].
                Default: the "average_model" alpha of this object.
            * linewidth: :obj:`float`
                Matplotlib's linewidth option.
            * n_tpts: :obj:`int`
                Number of timepoints in plot (default: 100)
            * n_std_left, n_std_right: :obj:`float` (default: 3 and 6 resp.)
                Time window around the mean of `tau_mean`, expressed as times of
                max(mean of `tau_std`, 4)

            Note: the linestyle cannot be set, it identifies the cluster.

        Returns
        -------
        :class:`matplotlib.pyplot`
            The pyplot module with all generated figures, allowing further modification or saving.
        """
        if colors is None:
            colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

        plot_kws = self._plot_kwargs("cluster", kwargs)

        # ---- Get timepoints
        mean_time = self.model.parameters['tau_mean'].mean().item()
        std_time = max(self.model.parameters["tau_std"].mean().item(), 4)
        timepoints = mean_time + std_time * np.linspace(
            -kwargs.get("n_std_left", 3),
            kwargs.get("n_std_right", 6),
            kwargs.get("n_tpts", 100),
        )
        timepoints = torch.tensor(timepoints, dtype=torch.float32)

        parameters = self.model.parameters
        n_clusters = self.model.n_clusters
        n_features = self.model.dimension

        if clusters is None:
            clusters_to_plot = list(range(n_clusters))  # all clusters
        elif isinstance(clusters, int):
            clusters_to_plot = [clusters]  # single cluster
        else:
            clusters_to_plot = list(clusters)  # list of clusters

        # ---- Average trajectory of each cluster, from its mean parameters
        cluster_estimates = {}
        for c in range(n_clusters):
            cluster_ip = {
                'xi': parameters['xi_mean'].numpy()[c],
                'tau': parameters['tau_mean'].numpy()[c],
                # all sources for this cluster
                'sources': parameters['sources_mean'].numpy()[:, c].tolist(),
            }
            ip = IndividualParameters()
            ip.add_individual_parameters("average", cluster_ip)
            cluster_estimates[c] = self.model.estimate({"average": timepoints}, ip)

        # one linestyle per cluster (cycled over if there are more clusters)
        lines = [
            "-", "--", ":", "-.",
            (0, (1, 1)),  # densely dotted
            (0, (5, 1)),  # long dash, short gap
            (0, (3, 1, 1, 1)),  # dash-dot-dotted
            (0, (5, 5)),  # evenly spaced dashes
            (0, (5, 2, 1, 2)),  # long dash, dot, gap
            (0, (2, 2, 8, 2)),  # dot + long dash
            (0, (10, 3)),  # very long dash
        ]
        # the cycle is shared by all figures: colors keep advancing between chunks
        colors_cycle = itertools.cycle(colors)

        # Loop over feature chunks (one figure per chunk)
        for start in range(0, n_features, n_features_per_plot):
            end = min(start + n_features_per_plot, n_features)
            feature_names = self.model.features[start:end]

            plt.figure(figsize=(8, 6))
            plt.ylim(0, 1)

            feature_colors = {name: next(colors_cycle) for name in feature_names}

            # Plot each cluster
            for c in clusters_to_plot:
                ls = lines[c % len(lines)]  # linestyle per cluster
                values = cluster_estimates[c]["average"][:, start:end].T
                for name, val in zip(feature_names, values):
                    plt.plot(
                        timepoints,
                        val,
                        label=f"cluster_{c}_{name}",
                        c=feature_colors[name],
                        ls=ls,
                        **plot_kws["model"]
                    )

            # Cluster legend
            cluster_legend = [
                Line2D([0], [0], linestyle=lines[c % len(lines)], linewidth=3, color="black", label=f"cluster_{c}")
                for c in clusters_to_plot
            ]
            legend1 = plt.legend(handles=cluster_legend, loc="upper left", prop={"size": 12})

            # Feature legend; creating it replaces the cluster legend on the
            # axes, hence the cluster legend is re-added as an artist afterwards
            feature_legend = [
                Line2D([0], [0], linestyle="-", color=feature_colors[name], linewidth=3, label=f"{name}")
                for name in feature_names
            ]
            plt.legend(handles=feature_legend, loc="lower right", title="Feature", prop={"size": 12})
            plt.gca().add_artist(legend1)

            plt.xlim(min(timepoints), max(timepoints))
            plt.xlabel("Reparametrized age", fontsize=14)
            plt.ylabel("Normalized feature value", fontsize=14)
            plt.title(f"scores {start+1}-{end}", fontsize=14)
            plt.suptitle("Population progression", fontsize=16)
            plt.tight_layout()

        return plt