"""Model inspection objects for programmatic access to model metadata.
This module provides :class:`Summary` and :class:`Info`, returned by
``model.summary()`` and ``model.info()`` respectively, along with the
:class:`~typing.TypedDict` schemas for training and dataset metadata.
Both classes auto-print when their return value is discarded (e.g.
``model.summary()``) and stay silent when stored in a variable
(e.g. ``s = model.summary()``). See :class:`AutoPrintMixin` for details.
"""
import textwrap
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import numpy as np
import pandas as pd
import torch
if TYPE_CHECKING:
from leaspy.models.base import BaseModel
# Public names of this module: the two inspection classes, the mixin they
# share, the metadata schemas, and the standalone metric/label helpers.
__all__ = [
    "AutoPrintMixin",
    "DatasetInfo",
    "Info",
    "Summary",
    "TrainingInfo",
    "VisitsPerSubject",
    "compute_bic",
    "compute_aic",
    "compute_icl",
    "get_axis_labels",
    "get_number_of_parameters",
]
# ---------------------------------------------------------------------------
# TypedDict schemas for metadata
# ---------------------------------------------------------------------------
[docs]
class VisitsPerSubject(TypedDict, total=False):
    """Per-subject visit distribution statistics.

    Declared with ``total=False``: every key is optional at the type level.
    ``Info.__str__`` reads all four keys when this dict is present.
    """

    median: float  # shown as "Median" with one decimal by Info.__str__
    min: int  # shown as "Min"
    max: int  # shown as "Max"
    iqr: float  # shown as "IQR" with one decimal (presumably the inter-quartile range)
[docs]
class DatasetInfo(TypedDict, total=False):
    """Statistics of the training dataset, computed during ``fit()``.

    Declared with ``total=False``: every key is optional, and readers in this
    module access them through ``.get(...)``.
    """

    n_subjects: int  # number of subjects in the training dataset
    n_scores: int  # number of scored features
    n_visits: int  # total number of visits
    n_observations: int  # total number of observed data points
    visits_per_subject: VisitsPerSubject  # per-subject visit distribution statistics
    n_events: int  # number of observed events (joint models only)
[docs]
class TrainingInfo(TypedDict, total=False):
    """Metadata about the training process, captured during ``fit()``.

    Declared with ``total=False``: every key is optional, and readers in this
    module access them through ``.get(...)``.
    """

    algorithm: str  # algorithm name used for training
    seed: int  # random seed used for training
    n_iter: int  # number of iterations
    n_burn_in_iter: int  # number of burn-in (memory-less) iterations
    converged: bool  # whether training converged
    duration: str  # training duration, already formatted as text
    # Cached Integrated Completed Likelihood. ``_persist_icl`` stores it under
    # this key at the end of fit so that it survives save/load; declaring it
    # keeps the schema in line with what is actually written to the dict.
    icl: float
# ---------------------------------------------------------------------------
# Shared utilities
# ---------------------------------------------------------------------------
# Character width shared by the text reports: separator rules are built as
# ``"=" * _WIDTH`` / ``"-" * _WIDTH`` and ``_wrap_text`` wraps to this width.
_WIDTH: int = 80
[docs]
def get_axis_labels(
    axis_name: Optional[str],
    size: int,
    feature_names: Optional[list[str]] = None,
) -> Optional[list[str]]:
    """Build short display labels for the entries of one parameter axis.

    Parameters
    ----------
    axis_name : str or None
        Semantic axis name. ``"feature"``, ``"source"``, ``"cluster"`` and
        ``"basis"`` produce labels; ``"event"``, ``None`` and any other value
        produce no labels.
    size : int
        Number of elements along the axis.
    feature_names : list[str], optional
        Feature names, only consulted when *axis_name* is ``"feature"``.
        At most *size* names are used; names longer than 8 characters are
        cut to 7 characters followed by ``"."``.

    Returns
    -------
    list[str] or None
        One label per element, or ``None`` when the axis has no labels.
    """
    if axis_name == "feature":
        if feature_names is None:
            return [f"f{idx}" for idx in range(size)]
        return [
            label if len(label) <= 8 else label[:7] + "."
            for label in feature_names[:size]
        ]

    # These axes are labelled with their initial letter plus the position.
    if axis_name in ("source", "cluster", "basis"):
        letter = axis_name[0]
        return [f"{letter}{idx}" for idx in range(size)]

    # None, "event", and unknown axis names carry no labels.
    return None
def _wrap_text(label: str, text: str, indent: int = 0) -> list[str]:
    """Wrap *text* to ``_WIDTH`` columns, prefixing the first line with *label*.

    The first line starts with *indent* spaces followed by ``"<label>: "``
    (the prefix is omitted when *label* is empty). Continuation lines are
    indented by ``indent + 4`` spaces. Long words and hyphenated words are
    never split. Returns the wrapped lines without trailing newlines.
    """
    first_line_prefix = " " * indent
    if label:
        first_line_prefix += f"{label}: "
    return textwrap.wrap(
        text,
        width=_WIDTH,
        initial_indent=first_line_prefix,
        subsequent_indent=" " * (indent + 4),
        break_long_words=False,
        break_on_hyphens=False,
    )
# ---------------------------------------------------------------------------
# Auto-print mixin
# ---------------------------------------------------------------------------
[docs]
class AutoPrintMixin:
    """Mixin that auto-prints when the object is discarded.

    Relies on CPython reference counting: when the return value of e.g.
    ``model.summary()`` is not assigned, the object is immediately
    garbage-collected, triggering ``__del__`` which prints it.

    When stored (``s = model.summary()``), any public attribute access
    (except ``help``) or a call to ``repr()`` sets ``_printed = True``,
    suppressing the ``__del__`` output.

    Subclasses must define a ``_printed: bool`` field (via dataclass)
    and a ``__str__`` method.
    """

    def __del__(self):
        # A finalizer also runs for instances whose construction failed
        # half-way. In that state the flag may be missing, so treat
        # "no flag" as "nothing to print" instead of raising.
        try:
            already_printed = object.__getattribute__(self, "_printed")
        except AttributeError:
            return
        if already_printed:
            return
        try:
            print(str(self))
        except Exception:
            # Exceptions cannot propagate out of ``__del__``; Python would only
            # report them as "Exception ignored in ..." on stderr. Auto-print
            # is a best-effort convenience, so a failing ``__str__`` (e.g. on a
            # partially initialised instance) or a closed stdout is dropped.
            pass

    def __repr__(self) -> str:
        # Displaying the object (e.g. at a REPL prompt) counts as printing it.
        object.__setattr__(self, "_printed", True)
        return str(self)

    def __getattribute__(self, name: str):
        value = object.__getattribute__(self, name)
        # Suppress auto-print once any public attribute is accessed.
        # ``help`` is excluded here; subclasses' ``help()`` set the flag
        # themselves after printing.
        if not name.startswith("_") and name != "help":
            object.__setattr__(self, "_printed", True)
        return value
# ---------------------------------------------------------------------------
# Metric utilities
# ---------------------------------------------------------------------------
[docs]
def get_number_of_parameters(model: "BaseModel") -> int:
    """Return the theoretical count of free parameters of *model*.

    The count follows ``P = 3F + (F-1)*S + S*K + 4K`` where *F* is the number
    of features, *S* the number of sources and *K* the number of clusters.

    Parameters
    ----------
    model
        A Leaspy model instance. ``dimension``, ``source_dimension`` and
        ``n_clusters`` are read from it; a missing or falsy value counts as
        0 features, 0 sources and 1 cluster respectively.

    Returns
    -------
    int
        Number of free parameters.
    """
    features = getattr(model, "dimension", 0) or 0
    sources = getattr(model, "source_dimension", 0) or 0
    clusters = getattr(model, "n_clusters", 1) or 1

    per_feature_term = features * 3
    mixing_term = (features - 1) * sources
    source_cluster_term = sources * clusters
    per_cluster_term = clusters * 4
    return per_feature_term + mixing_term + source_cluster_term + per_cluster_term
[docs]
def compute_bic(
    nll: float,
    num_params: int,
    n_subjects: int,
) -> Optional[float]:
    """Return the Bayesian Information Criterion ``2 * nll + P * log(N)``.

    Parameters
    ----------
    nll : float
        Negative log-likelihood (``nll_attach``).
    num_params : int
        Number of free parameters, *P*.
    n_subjects : int
        Number of subjects used for model fitting, *N*.

    Returns
    -------
    float or None
        The BIC, or ``None`` when *n_subjects* is zero or negative.
    """
    if n_subjects <= 0:
        return None
    complexity_penalty = num_params * np.log(n_subjects)
    return 2 * nll + complexity_penalty
[docs]
def compute_aic(
    nll: float,
    num_params: int,
    n_subjects: int,
) -> Optional[float]:
    """Return the Akaike Information Criterion ``2 * nll + 2 * P``.

    Parameters
    ----------
    nll : float
        Negative log-likelihood (``nll_attach``).
    num_params : int
        Number of free parameters, *P*.
    n_subjects : int
        Number of subjects used for model fitting. It does not enter the
        formula; it is only used to reject an empty dataset, mirroring
        the signature of the BIC helper.

    Returns
    -------
    float or None
        The AIC, or ``None`` when *n_subjects* is zero or negative.
    """
    return None if n_subjects <= 0 else 2 * nll + 2 * num_params
[docs]
def compute_icl(
    bic: Optional[float],
    model: "BaseModel",
) -> Optional[float]:
    """Return the Integrated Completed Likelihood (ICL) of a mixture model.

    ``ICL = BIC - sum_i sum_k pi_ik * log(pi_ik)``

    ``pi_ik`` is the probability of individual *i* belonging to cluster *k*,
    read from the ``prob_cluster_*`` columns returned by
    ``model.get_individual_probabilities``. The double sum is non-positive, so
    subtracting it adds an entropy penalty on top of the BIC: poorly separated
    clusters get a larger ICL than crisp ones.

    Parameters
    ----------
    bic : float or None
        Pre-computed BIC value for this model.
    model : BaseModel
        A fitted Leaspy model. ``model.state`` must provide the per-individual
        ``tau``, ``xi`` and (when ``source_dimension`` is non-zero) ``sources``
        entries.

    Returns
    -------
    float or None
        The ICL. ``None`` when *bic* is ``None``, when ``model.n_clusters`` is
        missing or falsy (non-mixture model), when the individual variables
        cannot be read from ``model.state``, or when the result is not finite.
    """
    if bic is None:
        return None
    if not getattr(model, "n_clusters", None):
        return None

    # Individual variables (tau, xi, sources) are only in state after fitting,
    # not after loading from file. Any failure to read them means "ICL not
    # available" rather than an error.
    try:
        latent_state = model.state
        source_count = getattr(model, "source_dimension", 0) or 0
        columns = {
            "tau": latent_state["tau"][:, 0].cpu().numpy(),
            "xi": latent_state["xi"][:, 0].cpu().numpy(),
        }
        for src in range(source_count):
            columns[f"sources_{src}"] = latent_state["sources"][:, src].cpu().numpy()
        individual_frame = pd.DataFrame(columns)
    except Exception:
        return None

    individual_frame = model.get_individual_probabilities(individual_frame)
    responsibility_columns = [
        col for col in individual_frame.columns if col.startswith("prob_cluster_")
    ]
    responsibilities = individual_frame[responsibility_columns]

    # Exact-zero responsibilities are turned into NaN before the log so that
    # their 0*log(0) cells become NaN, which np.nansum then skips. This is the
    # entropy convention 0*log0 := 0; a plain sum would propagate the NaN.
    log_responsibilities = np.log(responsibilities.replace(0, np.nan))
    entropy_value = np.nansum((responsibilities * log_responsibilities).to_numpy())

    icl = bic - entropy_value
    if not np.isfinite(icl):
        return None
    return icl
def _persist_icl(model: "BaseModel") -> None:
    """Compute ICL and stash it in ``model.training_info`` so it survives save/load.

    Individual variables (``tau``, ``xi``, ``sources``) are only on ``model.state``
    right after fit and are not serialized; without this cache, ICL would be
    unrecoverable after :meth:`BaseModel.load`. Called at the end of fit while
    those variables are still in scope.

    No-op for non-mixture models, when required inputs are missing, or when the
    cache is already populated.

    Parameters
    ----------
    model : BaseModel
        Model to update in place. ``n_clusters``, ``fit_metrics``,
        ``training_info`` and ``dataset_info`` are read from it; on success
        ``training_info["icl"]`` is set to a plain ``float``.
    """
    if not getattr(model, "n_clusters", None):
        return
    if (model.training_info or {}).get("icl") is not None:
        return

    fm = getattr(model, "fit_metrics", None) or {}
    # Prefer the attachment NLL; fall back to the total NLL when absent.
    nll_bic = fm.get("nll_attach", fm.get("nll_tot"))
    n_subjects = (model.dataset_info or {}).get("n_subjects")
    if nll_bic is None or n_subjects is None:
        return

    n_total_params = get_number_of_parameters(model)
    bic = compute_bic(float(nll_bic), n_total_params, n_subjects)
    icl = compute_icl(bic, model)

    # Guard against non-finite values: a NaN/inf icl would be cached here and then
    # serialized by to_dict() as the bare token `NaN`, which is invalid JSON.
    if icl is None or not np.isfinite(icl):
        return

    # The lookup above tolerates ``training_info`` being None, so the write must
    # tolerate it too: create the dict instead of failing on ``None["icl"]``.
    if model.training_info is None:
        model.training_info = {}
    model.training_info["icl"] = float(icl)
# ---------------------------------------------------------------------------
# Parameter display registry
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class _ModelDisplayMeta:
    """Display metadata for a model class (used by Summary).

    Registered in ``_DISPLAY_REGISTRY`` keyed by class name.
    ``_get_display_meta`` walks the model's MRO to find the closest entry.

    The dataclass is frozen, so fields cannot be rebound; the dict values
    themselves remain ordinary mutable dicts.
    """

    # Parameter names shown in the "individual priors" display group.
    individual_prior_params: tuple[str, ...]
    # Parameter names shown in the "noise" display group.
    noise_params: tuple[str, ...]
    # Parameter name -> semantic axis names (e.g. "feature", "source"),
    # one per tensor dimension; consumed by ``get_axis_labels``.
    param_axes: dict[str, tuple[str, ...]]
    # Same mapping for derived parameters; empty when not provided.
    derived_param_axes: dict[str, tuple[str, ...]] = field(default_factory=dict)
# Axis names of the parameters shared by all registered model families.
# Each tuple names the tensor dimensions in order.
_BASE_PARAM_AXES: dict[str, tuple[str, ...]] = {
    "log_g_mean": ("feature",),
    "log_g_std": ("feature",),
    "log_v0_mean": ("feature",),
    "betas_mean": ("basis", "source"),
    "mixing_matrix": ("source", "feature"),
    "noise_std": ("feature",),
}

# Axis names of the derived parameters shared by all registered model families.
_BASE_DERIVED_PARAM_AXES: dict[str, tuple[str, ...]] = {
    "v0": ("feature",),
    "p0": ("feature",),
}

# Display metadata keyed by model class *name*. Lookup is done by
# ``_get_display_meta``, which walks the MRO of the model's class and returns
# the first entry whose key matches a class name.
_DISPLAY_REGISTRY: dict[str, _ModelDisplayMeta] = {
    "McmcSaemCompatibleModel": _ModelDisplayMeta(
        individual_prior_params=(
            "tau_mean", "tau_std", "xi_mean", "xi_std",
            "sources_mean", "sources_std", "zeta_mean",
        ),
        noise_params=("noise_std",),
        param_axes=_BASE_PARAM_AXES,
        derived_param_axes=_BASE_DERIVED_PARAM_AXES,
    ),
    # Mixture model: the individual-prior parameters carry a "cluster" axis.
    "TimeReparametrizedMixtureModel": _ModelDisplayMeta(
        individual_prior_params=(
            "tau_std", "xi_std", "sources_std", "tau_mean", "xi_mean", "sources_mean",
        ),
        noise_params=("noise_std",),
        param_axes={
            **_BASE_PARAM_AXES,
            "tau_mean": ("cluster",),
            "tau_std": ("cluster",),
            "xi_mean": ("cluster",),
            "xi_std": ("cluster",),
            "sources_mean": ("source", "cluster"),
            "sources_std": ("source", "cluster"),
            "probs": ("cluster",),
        },
        derived_param_axes=_BASE_DERIVED_PARAM_AXES,
    ),
    # Joint model: adds parameters indexed by an "event" axis
    # (``get_axis_labels`` returns no labels for that axis).
    "JointModel": _ModelDisplayMeta(
        individual_prior_params=(
            "tau_mean", "tau_std", "xi_mean", "xi_std",
            "sources_mean", "sources_std", "zeta_mean",
        ),
        noise_params=("noise_std",),
        param_axes={
            **_BASE_PARAM_AXES,
            "n_log_nu_mean": ("event",),
            "log_rho_mean": ("event",),
            "zeta_mean": ("source", "event"),
        },
        derived_param_axes=_BASE_DERIVED_PARAM_AXES,
    ),
}
def _get_display_meta(model) -> Optional[_ModelDisplayMeta]:
    """Look up the display metadata that applies to *model*.

    The classes in the MRO of ``type(model)`` are checked from most to least
    specific; the first one whose ``__name__`` is a key of
    ``_DISPLAY_REGISTRY`` determines the result. Returns ``None`` when no
    class in the hierarchy is registered.
    """
    registered_names = (
        klass.__name__
        for klass in type(model).__mro__
        if klass.__name__ in _DISPLAY_REGISTRY
    )
    closest = next(registered_names, None)
    if closest is None:
        return None
    return _DISPLAY_REGISTRY[closest]
def _build_param_categories(
    model, meta: _ModelDisplayMeta
) -> dict[str, list[str]]:
    """Split the names in ``model.parameters`` into the three display groups.

    Returns a dict with keys ``"population"``, ``"individual_priors"`` and
    ``"noise"``. Names listed in *meta* but absent from the model are dropped;
    every model parameter not listed as individual prior or noise goes to
    ``"population"``. Each list is ordered by displayed column count, then by
    the name of the first axis, then by parameter name.
    """
    prior_names = set(meta.individual_prior_params)
    noise_names = set(meta.noise_params)
    if model.parameters:
        known_names = set(model.parameters.keys())
    else:
        known_names = set()
    population_names = known_names - prior_names - noise_names

    def sort_key(name: str) -> tuple[int, str, str]:
        value = model.parameters[name]
        axes = meta.param_axes.get(name, ())
        leading_axis = axes[0] if axes else ""
        column_count = 1
        if value.ndim == 2:
            column_count = value.shape[1]
        elif value.ndim == 1 and axes:
            # A 1-D value only spreads over several columns when its axis
            # actually has labels to head them.
            labels = get_axis_labels(leading_axis, len(value), model.features)
            if labels is not None:
                column_count = len(value)
        return (column_count, leading_axis, name)

    return {
        "population": sorted(population_names, key=sort_key),
        "individual_priors": sorted(prior_names & known_names, key=sort_key),
        "noise": sorted(noise_names & known_names, key=sort_key),
    }
# ---------------------------------------------------------------------------
# Info
# ---------------------------------------------------------------------------
[docs]
@dataclass(repr=False)
class Info(AutoPrintMixin):
    """Model configuration and training context (no parameter values).

    Returned by ``model.info()``. Auto-prints when discarded; provides
    programmatic access when stored in a variable.

    Examples
    --------
    >>> model.info()        # prints info
    >>> i = model.info()    # store for programmatic access
    >>> i.n_subjects        # 150
    >>> i.help()            # list available attributes
    """

    name: str
    model_type: str
    dimension: Optional[int] = None
    features: Optional[list[str]] = None
    source_dimension: Optional[int] = None
    n_clusters: Optional[int] = None
    obs_models: Optional[list[str]] = None
    n_total_params: Optional[int] = None
    latent_variables: dict = field(default_factory=dict)
    training_info: TrainingInfo = field(default_factory=dict)
    hyperparameters: dict = field(default_factory=dict)
    dataset_info: DatasetInfo = field(default_factory=dict)
    leaspy_version: Optional[str] = None
    # Auto-print bookkeeping flag required by AutoPrintMixin.
    _printed: bool = field(default=False, repr=False)

    # -- Factory -------------------------------------------------------------

    @classmethod
    def from_model(cls, model: "BaseModel") -> "Info":
        """Build an :class:`Info` from a model instance.

        Works on un-fitted models too: parameter counts and DAG-derived fields
        are simply omitted when the model state isn't available yet, and
        missing (``None``) training/dataset metadata is treated as empty.
        """
        is_initialized = getattr(model, "is_initialized", False)

        # Observation model names
        obs_model_names = None
        if hasattr(model, "obs_models"):
            obs_model_names = [om.to_string() for om in model.obs_models]

        # Parameter count (requires initialized state)
        n_total_params = None
        if is_initialized:
            try:
                if getattr(model, "parameters", None):
                    n_total_params = get_number_of_parameters(model)
            except Exception:
                n_total_params = None

        # Leaspy version
        try:
            from leaspy import __version__ as version
        except ImportError:
            version = None

        # Latent variable distributions (requires DAG, which needs initialization)
        from leaspy.variables.specs import PopulationLatentVariable, IndividualLatentVariable

        latent_variables = {}
        if is_initialized:
            try:
                dag = getattr(model, "dag", None)
            except Exception:
                dag = None
            if dag is not None:
                for kind, lv_type in [
                    ("population", PopulationLatentVariable),
                    ("individual", IndividualLatentVariable),
                ]:
                    group = {}
                    for var_name, var in dag.sorted_variables_by_type[lv_type].items():
                        # "NormalFamily" -> "Normal", "MixtureNormalFamily" -> "MixtureNormal"
                        dist_name = var.prior.dist_family.__name__.replace("Family", "")
                        group[var_name] = {
                            "distribution": dist_name,
                            "parameters": list(var.prior.parameters_names),
                        }
                    if group:
                        latent_variables[kind] = group

        # Hyperparameters (requires initialized state)
        hyperparameters = {}
        if is_initialized:
            try:
                hyperparameters = dict(getattr(model, "hyperparameters", {}))
            except Exception:
                hyperparameters = {}

        return cls(
            name=model.name,
            model_type=model.__class__.__name__,
            dimension=model.dimension,
            features=model.features,
            source_dimension=getattr(model, "source_dimension", None),
            n_clusters=getattr(model, "n_clusters", None),
            obs_models=obs_model_names,
            n_total_params=n_total_params,
            # ``or {}``: metadata may be None before any fit; dict(None) would raise.
            training_info=dict(model.training_info or {}),
            dataset_info=dict(model.dataset_info or {}),
            hyperparameters=hyperparameters,
            leaspy_version=version,
            latent_variables=latent_variables,
        )

    # -- Convenience properties: training ------------------------------------

    @property
    def algorithm(self) -> Optional[str]:
        """Algorithm name used for training."""
        val = self.training_info.get("algorithm")
        # Enum-like values are unwrapped to their ``.value``.
        return val.value if hasattr(val, "value") else val

    @property
    def seed(self) -> Optional[int]:
        """Random seed used for training."""
        return self.training_info.get("seed")

    @property
    def n_iter(self) -> Optional[int]:
        """Number of iterations."""
        return self.training_info.get("n_iter")

    @property
    def n_burn_in_iter(self) -> Optional[int]:
        """Number of burn-in (memory-less) iterations."""
        return self.training_info.get("n_burn_in_iter")

    @property
    def converged(self) -> Optional[bool]:
        """Whether training converged."""
        return self.training_info.get("converged")

    @property
    def duration(self) -> Optional[str]:
        """Training duration."""
        return self.training_info.get("duration")

    @property
    def hyperparameter(self) -> dict:
        """Model hyperparameters (e.g. source_dimension, obs_model)."""
        return self.hyperparameters

    @property
    def latent_variable_distributions(self) -> dict:
        """Latent variable prior distributions, grouped by population/individual."""
        return self.latent_variables

    # -- Convenience properties: dataset -------------------------------------

    @property
    def n_subjects(self) -> Optional[int]:
        """Number of subjects in the training dataset."""
        return self.dataset_info.get("n_subjects")

    @property
    def n_visits(self) -> Optional[int]:
        """Total number of visits."""
        return self.dataset_info.get("n_visits")

    @property
    def n_scores(self) -> Optional[int]:
        """Number of scored features."""
        return self.dataset_info.get("n_scores")

    @property
    def n_observations(self) -> Optional[int]:
        """Total number of observed data points."""
        return self.dataset_info.get("n_observations")

    @property
    def visits_per_subject(self) -> Optional[VisitsPerSubject]:
        """Per-subject visit distribution statistics."""
        return self.dataset_info.get("visits_per_subject")

    @property
    def n_events(self) -> Optional[int]:
        """Number of observed events (joint models only)."""
        return self.dataset_info.get("n_events")

    # -- Display -------------------------------------------------------------

    def __str__(self) -> str:
        """Render the information as a fixed-width, multi-section text report."""
        lines = []
        sep = "=" * _WIDTH
        lines.append(sep)
        lines.append(f"{'Model Information':^{_WIDTH}}")
        lines.append(sep)

        # Statistical Model
        lines.append("Statistical Model")
        lines.append(f"Type: {self.model_type}")
        lines.append(f"Name: {self.name}")
        lines.append(f"Dimension: {self.dimension}")
        if self.source_dimension is not None:
            lines.append(f"Source Dimension: {self.source_dimension}")
        if self.obs_models:
            lines.append(f"Observation Models: {', '.join(self.obs_models)}")
        if self.n_total_params is not None:
            lines.append(f"Parameters: {self.n_total_params}")
        if self.n_clusters is not None:
            lines.append(f"Clusters: {self.n_clusters}")

        if self.latent_variables:
            lines.append("")
            lines.append("Latent Variables")
            lines.append("-" * _WIDTH)
            for kind, group in self.latent_variables.items():
                lines.append(f" {kind.capitalize()}:")
                for var_name, info in group.items():
                    params = ", ".join(info["parameters"])
                    lines.append(f" {var_name:<20} {info['distribution']}({params})")
            lines.append("-" * _WIDTH)

        # Training Dataset
        if self.dataset_info:
            lines.append("")
            lines.append("Training Dataset")
            lines.append("-" * _WIDTH)
            di = self.dataset_info
            lines.append(f"Subjects: {di.get('n_subjects', 'N/A')}")
            lines.append(f"Visits: {di.get('n_visits', 'N/A')}")
            lines.append(f"Scores (Features): {di.get('n_scores', 'N/A')}")
            lines.append(f"Total Observations: {di.get('n_observations', 'N/A')}")
            if "visits_per_subject" in di:
                vps = di["visits_per_subject"]
                lines.append(
                    f"Visits per Subject: Median {vps['median']:.1f} "
                    f"[Min {vps['min']}, Max {vps['max']}, IQR {vps['iqr']:.1f}]"
                )
            if "n_events" in di:
                lines.append(f"Events Observed: {di['n_events']}")

        # Training Details
        if self.training_info:
            lines.append("")
            lines.append("Training Details")
            lines.append("-" * _WIDTH)
            ti = self.training_info
            lines.append(f"Algorithm: {self.algorithm or 'N/A'}")
            if "seed" in ti:
                lines.append(f"Seed: {ti['seed']}")
            lines.append(f"Iterations: {ti.get('n_iter', 'N/A')}")
            n_b = ti.get("n_burn_in_iter")
            if n_b is not None:
                n_t = ti.get("n_iter")
                if n_t:
                    lines.append(f" Burn-in: {n_b}/{n_t} ({100*n_b/n_t:.0f}%)")
                    lines.append(f" Burn-out: {n_t - n_b}")
                else:
                    # Without a known, non-zero iteration count a ratio and a
                    # remaining-iterations figure would be made-up numbers.
                    lines.append(f" Burn-in: {n_b}")
            if ti.get("converged") is not None:
                lines.append(f"Converged: {ti['converged']}")
            if "duration" in ti:
                lines.append(f"Duration: {ti['duration']}")

        # Hyperparameters
        if self.hyperparameters:
            lines.append("")
            lines.append("Hyperparameters (fixed values from the source code)")
            lines.append("-" * _WIDTH)
            for k, v in self.hyperparameters.items():
                # Tensors are converted to plain Python: scalar -> number,
                # anything else -> nested list.
                if isinstance(v, torch.Tensor):
                    val = v.item() if v.ndim == 0 else v.tolist()
                else:
                    val = v
                shown = round(val, 4) if isinstance(val, float) else val
                lines.append(f" {k}: {shown}")

        # Leaspy Version
        if self.leaspy_version:
            lines.append("")
            lines.append(f"Leaspy Version: {self.leaspy_version}")
        lines.append(sep)
        return "\n".join(lines)

    def help(self) -> None:
        """Print available attributes and their meanings."""
        help_text = f"""
Info Help
{'=' * 60}
The Info object provides access to model configuration and training context.
Usage:
model.info() # Print model information
i = model.info() # Store to access individual attributes
Available Attributes:
Model:
name Model name (str)
model_type Model class name (str)
dimension Number of features (int)
features Feature names (list[str])
source_dimension Number of sources (int or None)
n_clusters Number of clusters (int or None)
obs_models Observation model names (list[str] or None)
n_total_params Number of free parameters (int)
hyperparameters Model hyperparameters dict (e.g. source_dimension)
Training:
algorithm Algorithm name (str)
seed Random seed (int)
n_iter Number of iterations (int)
n_burn_in_iter Number of burn-in iterations (int)
converged Whether training converged (bool or None)
duration Training duration (str)
Dataset:
n_subjects Number of subjects (int)
n_visits Total visits (int)
n_scores Number of scored features (int)
n_observations Total observations (int)
visits_per_subject Visit distribution stats (dict)
n_events Observed events, joint models only (int or None)
Other:
training_info Full training metadata (TrainingInfo)
dataset_info Full dataset statistics (DatasetInfo)
leaspy_version Leaspy version (str)
Examples:
>>> i = model.info()
>>> i.algorithm # 'mcmc_saem'
>>> i.n_subjects # 150
"""
        print(help_text)
        # ``help`` is exempt from the attribute-access hook, so mark the object
        # as handled here to avoid an extra auto-print on garbage collection.
        object.__setattr__(self, "_printed", True)
# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
[docs]
@dataclass(repr=False)
class Summary(AutoPrintMixin):
"""Structured summary of a Leaspy model including parameter values.
Returned by ``model.summary()``. Auto-prints when discarded; provides
programmatic access when stored in a variable.
Examples
--------
>>> model.summary() # prints the formatted summary
>>> s = model.summary() # store for programmatic access
>>> s.algorithm # 'mcmc_saem'
>>> s.get_param('tau_std') # tensor([10.5])
>>> s.help() # list available attributes
"""
# Identity and structure of the model.
name: str
model_type: str
dimension: Optional[int] = None
features: Optional[list[str]] = None
source_dimension: Optional[int] = None
n_clusters: Optional[int] = None
obs_models: Optional[list[str]] = None
# Fit metrics and information criteria; None when not available.
n_total_params: Optional[int] = None
nll: Optional[float] = None
bic: Optional[float] = None
aic: Optional[float] = None
icl: Optional[float] = None
# Copies of the model's training and dataset metadata.
training_info: TrainingInfo = field(default_factory=dict)
dataset_info: DatasetInfo = field(default_factory=dict)
# Display category name -> {parameter name -> value}.
parameters: dict[str, dict[str, Any]] = field(default_factory=dict)
# Derived parameter name -> value (e.g. v0, p0).
derived_parameters: dict[str, Any] = field(default_factory=dict)
leaspy_version: Optional[str] = None
# Private display helpers: axis names per parameter and the feature names.
_param_axes: dict = field(default_factory=dict, repr=False)
_derived_param_axes: dict = field(default_factory=dict, repr=False)
_feature_names: Optional[list[str]] = field(default=None, repr=False)
# Auto-print bookkeeping flag required by AutoPrintMixin.
_printed: bool = field(default=False, repr=False)
# -- Factory -------------------------------------------------------------
[docs]
@classmethod
def from_model(cls, model: "BaseModel") -> "Summary":
    """Build a :class:`Summary` from a model instance.

    Parameters
    ----------
    model : BaseModel
        An initialized model with a non-empty ``parameters`` mapping.

    Returns
    -------
    Summary
        Snapshot of the model's metadata, fit metrics, grouped parameters and
        derived parameters.

    Raises
    ------
    LeaspyModelInputError
        If the model is not initialized or has no parameters.
    """
    from leaspy.exceptions import LeaspyModelInputError

    if not model.is_initialized:
        raise LeaspyModelInputError(
            "Model is not initialized. Call fit() first."
        )
    if model.parameters is None or len(model.parameters) == 0:
        raise LeaspyModelInputError(
            "Model has no parameters. Call fit() first."
        )

    fm = getattr(model, "fit_metrics", None) or {}
    # Metadata may be None on some models; treat that as "nothing recorded".
    training_info = dict(model.training_info or {})
    dataset_info = dict(model.dataset_info or {})

    # NLL: compare against None rather than relying on truthiness, otherwise
    # a legitimate value of exactly 0.0 would be reported as missing.
    nll_val = fm.get("nll_tot")
    nll = float(nll_val) if nll_val is not None else None

    # Parameter count, BIC and AIC. Both criteria share the same inputs:
    # the attachment NLL (falling back to the total NLL) and the subject count.
    n_total_params = get_number_of_parameters(model)
    bic = None
    aic = None
    nll_bic = fm.get("nll_attach", fm.get("nll_tot"))
    n_subjects = dataset_info.get("n_subjects")
    if nll_bic is not None and n_subjects is not None:
        bic = compute_bic(float(nll_bic), n_total_params, n_subjects)
        aic = compute_aic(float(nll_bic), n_total_params, n_subjects)

    # ICL: prefer the value cached in training_info (it survives save/load),
    # otherwise try to compute it from the live model state.
    icl = training_info.get("icl")
    if icl is None:
        icl = compute_icl(bic, model)

    # Observation model names
    obs_model_names = None
    if hasattr(model, "obs_models"):
        obs_model_names = [om.to_string() for om in model.obs_models]

    # Leaspy version
    try:
        from leaspy import __version__ as version
    except ImportError:
        version = None

    # Group parameters by category
    _meta = _get_display_meta(model)
    params_by_category = {}
    if _meta is not None:
        cats = _build_param_categories(model, _meta)
        cat_names = {
            "population": "Population Parameters",
            "individual_priors": "Individual Parameters",
            "noise": "Noise Model",
        }
        for cat_key, display_name in cat_names.items():
            param_names = cats.get(cat_key, [])
            if param_names:
                params_by_category[display_name] = {
                    name: model.parameters[name]
                    for name in param_names
                    if name in model.parameters
                }
    else:
        # No display metadata registered for this model family:
        # show everything in one flat group.
        params_by_category["Parameters"] = dict(model.parameters)

    # Derived parameters (v0, p0, ...) from model-side transform
    derived = model.compute_derived_parameters()

    return cls(
        name=model.name,
        model_type=model.__class__.__name__,
        dimension=model.dimension,
        features=model.features,
        source_dimension=getattr(model, "source_dimension", None),
        n_clusters=getattr(model, "n_clusters", None),
        obs_models=obs_model_names,
        n_total_params=n_total_params,
        nll=nll,
        bic=bic,
        aic=aic,
        icl=icl,
        training_info=training_info,
        dataset_info=dataset_info,
        parameters=params_by_category,
        derived_parameters=derived,
        leaspy_version=version,
        _param_axes=_meta.param_axes if _meta is not None else {},
        _derived_param_axes=_meta.derived_param_axes if _meta is not None else {},
        _feature_names=model.features,
    )
# -- Convenience properties ----------------------------------------------
@property
def sources(self) -> Optional[list[str]]:
"""Source names (e.g. ``['s0', 's1']``) or ``None``."""
if self.source_dimension is None:
return None
return [f"s{i}" for i in range(self.source_dimension)]
@property
def clusters(self) -> Optional[list[str]]:
"""Cluster names (e.g. ``['c0', 'c1']``) or ``None``."""
if self.n_clusters is None:
return None
return [f"c{i}" for i in range(self.n_clusters)]
@property
def algorithm(self) -> Optional[str]:
"""Algorithm name used for training."""
val = self.training_info.get("algorithm")
return val.value if hasattr(val, "value") else val
@property
def seed(self) -> Optional[int]:
"""Random seed used for training."""
return self.training_info.get("seed")
@property
def n_iter(self) -> Optional[int]:
"""Number of iterations."""
return self.training_info.get("n_iter")
@property
def converged(self) -> Optional[bool]:
"""Whether training converged."""
return self.training_info.get("converged")
@property
def n_subjects(self) -> Optional[int]:
"""Number of subjects in the training dataset."""
return self.dataset_info.get("n_subjects")
@property
def n_visits(self) -> Optional[int]:
"""Total number of visits."""
return self.dataset_info.get("n_visits")
@property
def n_observations(self) -> Optional[int]:
"""Total number of observations."""
return self.dataset_info.get("n_observations")
[docs]
def get_param(self, name: str) -> Optional[Any]:
"""Get a parameter value by name, searching fitted and derived parameters.
Parameters
----------
name : str
Parameter name (e.g. ``'betas_mean'``, ``'tau_std'``, ``'v0'``, ``'p0'``).
Returns
-------
value
The parameter value (typically a ``torch.Tensor``), or ``None``.
"""
for category_params in self.parameters.values():
if name in category_params:
return category_params[name]
derived = object.__getattribute__(self, "derived_parameters")
if name in derived:
return derived[name]
return None
# -- Display -------------------------------------------------------------
def __str__(self) -> str:
    """Render the summary as a fixed-width text report.

    Layout, top to bottom: banner, model identity and fit metrics, training
    metadata, data context, version, then one section per non-empty
    parameter category and a final section for derived parameters.
    """
    heavy_rule = "=" * _WIDTH
    light_rule = "-" * _WIDTH

    # Header
    report = [heavy_rule, f"{'Model Summary':^{_WIDTH}}", heavy_rule]
    report.append(f"Model Name: {self.name}")
    report.append(f"Model Type: {self.model_type}")

    if self.features is not None:
        report += _wrap_text(f"Features ({self.dimension})", ", ".join(self.features))
    if self.source_dimension is not None:
        source_text = ", ".join(
            f"Source {i} (s{i})" for i in range(self.source_dimension)
        )
        report += _wrap_text(f"Sources ({self.source_dimension})", source_text)
    if self.n_clusters is not None:
        cluster_text = ", ".join(
            f"Cluster {i} (c{i})" for i in range(self.n_clusters)
        )
        report += _wrap_text(f"Clusters ({self.n_clusters})", cluster_text)
    if self.obs_models:
        report += _wrap_text("Observation Models", ", ".join(self.obs_models))

    if self.nll is not None:
        report.append(f"Neg. Log-Likelihood: {self.nll:.4f}")
    if self.n_total_params is not None:
        report.append(f"Parameters: {self.n_total_params}")
    for tag, criterion in (("BIC", self.bic), ("AIC", self.aic), ("ICL", self.icl)):
        if criterion is not None:
            report.append(f"{tag}: {criterion:.2f}")

    # Training Metadata
    training_meta = self.training_info
    if training_meta:
        report += ["", "Training Metadata", light_rule]
        report.append(f"Algorithm: {self.algorithm or 'N/A'}")
        if "seed" in training_meta:
            report.append(f"Seed: {training_meta['seed']}")
        report.append(f"Iterations: {training_meta.get('n_iter', 'N/A')}")
        if training_meta.get("converged") is not None:
            report.append(f"Converged: {training_meta['converged']}")

    # Data Context
    dataset_meta = self.dataset_info
    if dataset_meta:
        report += ["", "Data Context", light_rule]
        report.append(f"Subjects: {dataset_meta.get('n_subjects', 'N/A')}")
        report.append(f"Visits: {dataset_meta.get('n_visits', 'N/A')}")
        report.append(f"Total Observations: {dataset_meta.get('n_observations', 'N/A')}")

    # Leaspy Version
    if self.leaspy_version:
        report.append(f"Leaspy Version: {self.leaspy_version}")
    report.append(heavy_rule)

    # Parameters by category. A blank line always follows the header when at
    # least one category exists; later non-empty categories are separated
    # from the previous one by a blank line.
    for position, (category, params) in enumerate(self.parameters.items()):
        if position == 0:
            report.append("")
        if not params:
            continue
        if position > 0:
            report.append("")
        report.append(category)
        report.append(light_rule)
        report.extend(self._format_parameter_group(params))

    # Derived parameters (interpretable scale)
    derived_values = object.__getattribute__(self, "derived_parameters")
    if derived_values:
        report += ["", "Derived Parameters (interpretable scale)", light_rule]
        report.extend(self._format_derived_group(derived_values))

    report.append(heavy_rule)
    return "\n".join(report)
[docs]
def help(self) -> None:
    """Print a plain-text reference of the attributes this object exposes.

    The reference is written to standard output with :func:`print`.
    Once it has been printed, the ``_printed`` flag is set to ``True``.
    """
    rule = "=" * 60
    reference = f"""
Summary Help
{rule}
The Summary object provides access to model metadata and parameters.
Usage:
model.summary() # Print the formatted summary
s = model.summary() # Store to access individual attributes
Available Attributes:
Model Information:
name Model name (str)
model_type Model class name, e.g., 'LogisticModel' (str)
dimension Number of features (int)
features List of feature names (list[str])
sources Source names, e.g., ['s0', 's1'] (list[str] or None)
clusters Cluster names, e.g., ['c0', 'c1'] (list[str] or None)
source_dimension Number of sources (int or None)
n_clusters Number of clusters (int or None)
obs_models Observation model names (list[str] or None)
Training:
algorithm Algorithm name, e.g., 'mcmc_saem' (str)
seed Random seed used (int)
n_iter Number of iterations (int)
converged Whether training converged (bool or None)
nll Negative log-likelihood (float or None)
n_total_params Number of free parameters (int)
bic Bayesian Information Criterion (float or None)
aic Akaike Information Criterion (float or None)
icl Integrated Completed Likelihood, mixture models only
(float or None)
Dataset:
n_subjects Number of subjects in training data (int)
n_visits Total number of visits (int)
n_observations Total number of observations (int)
Parameters:
parameters All parameters grouped by category (dict)
derived_parameters Derived parameters in interpretable scale (dict)
get_param(name) Get a parameter by name (searches fitted + derived)
Derived Parameters (interpretable scale):
v0 Velocities: exp(log_v0_mean), per feature
p0 Positions: sigmoid(-log_g_mean) or g_mean, per feature
Other:
training_info Full training metadata (TrainingInfo)
dataset_info Full dataset statistics (DatasetInfo)
leaspy_version Leaspy version used (str)
Examples:
>>> s = model.summary()
>>> s.algorithm # 'mcmc_saem'
>>> s.seed # 42
>>> s.n_subjects # 150
>>> s.get_param('tau_std') # tensor([10.5])
"""
    print(reference)
    # Written through ``object.__setattr__`` so the flag is stored without
    # going through any ``__setattr__`` defined on the class itself.
    # NOTE(review): presumably read by the auto-print machinery - confirm.
    object.__setattr__(self, "_printed", True)
# -- Private formatting helpers ------------------------------------------
def _format_parameter_group(self, params: dict[str, Any]) -> list[str]:
    """Render one category of fitted parameters as display lines.

    Parameters
    ----------
    params : dict[str, Any]
        Mapping from parameter name to its value.

    Returns
    -------
    list[str]
        One entry per parameter, in the mapping's iteration order.
        Tensor values are delegated to ``_format_tensor``; any other
        value is rendered next to its left-aligned, padded name.
    """
    return [
        self._format_tensor(key, val)
        if isinstance(val, torch.Tensor)
        else f" {key:<18} {val}"
        for key, val in params.items()
    ]
def _format_derived_group(self, params: dict[str, Any]) -> list[str]:
    """Render the derived parameters as display lines.

    Parameters
    ----------
    params : dict[str, Any]
        Mapping from derived-parameter name to its value.

    Returns
    -------
    list[str]
        One entry per parameter, in the mapping's iteration order.
        Tensor values go through ``_format_tensor`` with ``derived=True``;
        any other value is rendered next to its left-aligned, padded name.
    """

    def render(key: str, val: Any) -> str:
        # Only tensors need axis-aware formatting.
        if not isinstance(val, torch.Tensor):
            return f" {key:<18} {val}"
        return self._format_tensor(key, val, derived=True)

    return [render(key, val) for key, val in params.items()]
def _format_tensor(
    self, name: str, value: torch.Tensor, *, derived: bool = False
) -> str:
    """Render a tensor-valued parameter as text, with axis labels if known.

    Parameters
    ----------
    name : str
        Parameter name; also the key used to look up its axes.
    value : torch.Tensor
        Tensor to render.
    derived : bool, optional
        If ``True`` the axes come from ``_derived_param_axes``, otherwise
        from ``_param_axes``.

    Returns
    -------
    str
        * 0-D: the name followed by the scalar with four decimals.
        * 1-D: a shape placeholder when longer than 10 entries; otherwise
          a header line of labels above the values, or a bracketed list
          when ``get_axis_labels`` yields no labels.
        * 2-D: a shape placeholder when either dimension exceeds 8;
          otherwise a small table, one line per row, with ``[i]`` as the
          row label when none are available.
        * Higher ranks: a shape placeholder.
    """
    axes_attr = "_derived_param_axes" if derived else "_param_axes"
    axes_by_param = object.__getattribute__(self, axes_attr)
    known_features = object.__getattribute__(self, "_feature_names")
    axes = axes_by_param.get(name, ())

    padded_name = f" {name:<18}"
    rank = value.ndim

    if rank == 0:
        return f"{padded_name} {value.item():.4f}"

    if rank == 1:
        size = len(value)
        if size > 10:
            # Too long to print inline; report the shape only.
            return f"{padded_name} Tensor of shape ({size},)"
        first_axis = axes[0] if len(axes) > 0 else None
        column_names = get_axis_labels(first_axis, size, known_features)
        entries = value.tolist()
        if not column_names:
            listed = ", ".join(f"{x:.4f}" for x in entries)
            return f"{padded_name} [{listed}]"
        header_row = " " * 20 + " ".join(f"{c:>8}" for c in column_names)
        value_row = padded_name + " ".join(f"{x:>8.4f}" for x in entries)
        return f"{header_row}\n{value_row}"

    if rank == 2:
        n_rows, n_cols = value.shape
        if max(n_rows, n_cols) > 8:
            # Too large for a readable table; report the shape only.
            return f"{padded_name} Tensor of shape {tuple(value.shape)}"
        first_axis = axes[0] if len(axes) > 0 else None
        second_axis = axes[1] if len(axes) > 1 else None
        row_names = get_axis_labels(first_axis, n_rows, known_features)
        column_names = get_axis_labels(second_axis, n_cols, known_features)
        table = [f" {name}:"]
        if column_names:
            table.append(
                " " * 20 + " ".join(f"{c:>8}" for c in column_names)
            )
        for idx, cells in enumerate(value.tolist()):
            # Fall back to a positional tag when rows have no labels.
            tag = row_names[idx] if row_names else f"[{idx}]"
            table.append(
                f" {tag:<8}" + " ".join(f"{x:>8.4f}" for x in cells)
            )
        return "\n".join(table)

    return f"{padded_name} Tensor of shape {tuple(value.shape)}"