# Source code for leaspy.models.summary

"""Model inspection objects for programmatic access to model metadata.

This module provides :class:`Summary` and :class:`Info`, returned by
``model.summary()`` and ``model.info()`` respectively, along with the
:class:`~typing.TypedDict` schemas for training and dataset metadata.

Both classes auto-print when their return value is discarded (e.g.
``model.summary()``) and stay silent when stored in a variable
(e.g. ``s = model.summary()``).  See :class:`AutoPrintMixin` for details.
"""

import textwrap
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import numpy as np
import pandas as pd
import torch

if TYPE_CHECKING:
    from leaspy.models.base import BaseModel

# Names exported by ``from leaspy.models.summary import *``.  The
# underscore-prefixed helpers defined further down are not listed here.
__all__ = [
    "AutoPrintMixin",
    "DatasetInfo",
    "Info",
    "Summary",
    "TrainingInfo",
    "VisitsPerSubject",
    "compute_bic",
    "compute_aic",
    "compute_icl",
    "get_axis_labels",
    "get_number_of_parameters",
]


# ---------------------------------------------------------------------------
# TypedDict schemas for metadata
# ---------------------------------------------------------------------------


class VisitsPerSubject(TypedDict, total=False):
    """Per-subject visit distribution statistics.

    Every key is optional (``total=False``).
    """

    median: float  # rendered as "Median" with one decimal by ``Info.__str__``
    min: int  # rendered as "Min"
    max: int  # rendered as "Max"
    iqr: float  # rendered as "IQR"; presumably the interquartile range -- confirm with the producer
class DatasetInfo(TypedDict, total=False):
    """Statistics of the training dataset, computed during ``fit()``.

    Every key is optional (``total=False``); the readers in this module use
    ``.get()`` or membership tests instead of assuming a key is present.
    """

    n_subjects: int  # also the sample size handed to ``compute_bic`` / ``compute_aic``
    n_scores: int  # labelled "Scores (Features)" by ``Info.__str__``
    n_visits: int
    n_observations: int
    visits_per_subject: VisitsPerSubject
    n_events: int  # documented on ``Info.n_events`` as "joint models only"
class TrainingInfo(TypedDict, total=False):
    """Metadata about the training process, captured during ``fit()``.

    Every key is optional (``total=False``).

    Keys
    ----
    algorithm
        Name of the training algorithm.  The ``algorithm`` convenience
        properties in this module also accept an object exposing ``.value``.
    seed
        Random seed used for training.
    n_iter
        Number of iterations.
    n_burn_in_iter
        Number of burn-in (memory-less) iterations.
    converged
        Whether training converged.
    duration
        Training duration, already formatted as text.
    icl
        Cached Integrated Completed Likelihood.  ``_persist_icl`` stores it
        under this key at the end of fit and ``Summary.from_model`` reads it
        back, so the key is declared here to keep the schema in line with
        what is actually written into the dict.
    """

    algorithm: str
    seed: int
    n_iter: int
    n_burn_in_iter: int
    converged: bool
    duration: str
    icl: float
# --------------------------------------------------------------------------- # Shared utilities # --------------------------------------------------------------------------- _WIDTH = 80
def get_axis_labels(
    axis_name: Optional[str],
    size: int,
    feature_names: Optional[list[str]] = None,
) -> Optional[list[str]]:
    """Return display labels for one axis of a parameter.

    Parameters
    ----------
    axis_name : str or None
        Semantic axis name.  ``"feature"``, ``"source"``, ``"cluster"`` and
        ``"basis"`` produce labels; ``"event"``, ``None`` and any other value
        produce ``None``.
    size : int
        Number of elements along the axis.
    feature_names : list[str], optional
        Real feature names, only consulted when *axis_name* is ``"feature"``.
        At most *size* names are used; a name longer than 8 characters is cut
        to its first 7 characters followed by ``"."``.

    Returns
    -------
    list[str] or None
        One label per element (``f0``, ``s0``, ``c0``, ``b0``, ... or the
        shortened feature names), or ``None`` when the axis has no labels.
    """
    # Real feature names take precedence over the generic "f<i>" labels.
    if axis_name == "feature" and feature_names is not None:
        return [
            name if len(name) <= 8 else name[:7] + "."
            for name in feature_names[:size]
        ]

    # Axes missing from this table ("event", None, unknown names) are unlabelled.
    prefix_by_axis = {"feature": "f", "source": "s", "cluster": "c", "basis": "b"}
    prefix = prefix_by_axis.get(axis_name)
    if prefix is None:
        return None
    return [f"{prefix}{idx}" for idx in range(size)]
def _wrap_text(label: str, text: str, indent: int = 0) -> list[str]:
    """Wrap *text* to ``_WIDTH`` columns, prefixing the first line with *label*.

    Parameters
    ----------
    label : str
        Prefix written as ``"<label>: "`` on the first line; an empty label
        adds no prefix.
    text : str
        Text to wrap.
    indent : int
        Number of leading spaces on the first line.  Continuation lines are
        indented by ``indent + 4`` spaces.

    Returns
    -------
    list[str]
        The wrapped lines.  Long words and hyphenated words are never split.
    """
    first_line_lead = " " * indent
    if label:
        first_line_lead += f"{label}: "
    return textwrap.wrap(
        text,
        width=_WIDTH,
        initial_indent=first_line_lead,
        subsequent_indent=" " * (indent + 4),
        break_long_words=False,
        break_on_hyphens=False,
    )


# ---------------------------------------------------------------------------
# Auto-print mixin
# ---------------------------------------------------------------------------
class AutoPrintMixin:
    """Mixin that prints the object when it is thrown away unread.

    The mechanism relies on CPython reference counting: a result that is not
    bound to a name (``model.summary()`` as a bare statement) is collected
    immediately, which runs ``__del__`` and prints ``str(self)``.  When the
    result is kept (``s = model.summary()``), reading any public attribute
    other than ``help`` -- or taking its ``repr`` -- flips ``_printed`` to
    ``True`` so nothing is printed on collection.

    Subclasses must provide a ``_printed`` boolean (a dataclass field with a
    ``False`` default) and a ``__str__`` method.
    """

    def __getattribute__(self, name: str):
        attr = object.__getattribute__(self, name)
        # Private names and ``help`` do not count as "the caller read it".
        if name.startswith("_") or name == "help":
            return attr
        object.__setattr__(self, "_printed", True)
        return attr

    def __repr__(self) -> str:
        # Showing the repr is itself a display, so silence the later auto-print.
        object.__setattr__(self, "_printed", True)
        return str(self)

    def __del__(self):
        already_shown = object.__getattribute__(self, "_printed")
        if already_shown:
            return
        print(str(self))
# --------------------------------------------------------------------------- # Metric utilities # ---------------------------------------------------------------------------
def get_number_of_parameters(model: "BaseModel") -> int:
    """Count the free parameters of *model* from its sizes.

    The count follows the closed-form expression
    ``P = 3F + (F-1)*S + S*K + 4K`` with *F* = ``model.dimension``,
    *S* = ``model.source_dimension`` and *K* = ``model.n_clusters``.

    Parameters
    ----------
    model
        A Leaspy model instance.  A missing or falsy ``dimension`` or
        ``source_dimension`` counts as 0; a missing or falsy ``n_clusters``
        counts as 1.

    Returns
    -------
    int
        Number of free parameters.
    """
    features = getattr(model, "dimension", 0) or 0
    sources = getattr(model, "source_dimension", 0) or 0
    clusters = getattr(model, "n_clusters", 1) or 1

    # The four addends of the formula, in the order they appear in it.
    addends = (
        features * 3,
        (features - 1) * sources,
        sources * clusters,
        clusters * 4,
    )
    return addends[0] + addends[1] + addends[2] + addends[3]
def compute_bic(
    nll: float,
    num_params: int,
    n_subjects: int,
) -> Optional[float]:
    """Return the Bayesian Information Criterion, ``2 * nll + P * log(N)``.

    Parameters
    ----------
    nll : float
        Negative log-likelihood (``nll_attach``).
    num_params : int
        Number of free parameters (*P*).
    n_subjects : int
        Number of subjects used for model fitting (*N*).

    Returns
    -------
    float or None
        The BIC, or ``None`` when *n_subjects* is zero or negative.
    """
    # log(N) is undefined for N <= 0, so report "not available" instead.
    if n_subjects <= 0:
        return None
    log_sample_size = np.log(n_subjects)
    return 2 * nll + num_params * log_sample_size
def compute_aic(
    nll: float,
    num_params: int,
    n_subjects: int,
) -> Optional[float]:
    """Return the Akaike Information Criterion, ``2 * nll + 2 * P``.

    Parameters
    ----------
    nll : float
        Negative log-likelihood (``nll_attach``).
    num_params : int
        Number of free parameters (*P*).
    n_subjects : int
        Number of subjects used for model fitting.  It does not enter the
        formula; it only decides whether a value is returned at all, which
        mirrors :func:`compute_bic`.

    Returns
    -------
    float or None
        The AIC, or ``None`` when *n_subjects* is zero or negative.
    """
    if n_subjects <= 0:
        return None
    fit_term = 2 * nll
    complexity_term = 2 * num_params
    return fit_term + complexity_term
def compute_icl(
    bic: Optional[float],
    model: "BaseModel",
) -> Optional[float]:
    """Return the Integrated Completed Likelihood (ICL) of a mixture model.

    ``ICL = BIC - sum_i sum_k pi_ik * log(pi_ik)``

    where ``pi_ik`` is the posterior probability that individual *i* belongs
    to cluster *k*.  The double sum is never positive, so the ICL adds an
    entropy penalty on top of the BIC: models with crisp cluster assignments
    are preferred over poorly separated ones.

    Parameters
    ----------
    bic : float or None
        Pre-computed BIC value for this model.
    model : BaseModel
        A fitted Leaspy model.  ``model.state`` must hold the per-individual
        ``tau``, ``xi`` and (when ``source_dimension`` is positive)
        ``sources``; ``model.get_individual_probabilities`` must add one
        ``prob_cluster_*`` column per cluster to the data frame it is given.

    Returns
    -------
    float or None
        The ICL, or ``None`` when *bic* is ``None``, when ``model.n_clusters``
        is missing or falsy (non-mixture model), when the individual variables
        cannot be read from ``model.state``, or when the result is not finite.
    """
    if bic is None:
        return None
    if not getattr(model, "n_clusters", None):
        return None

    # Individual variables (tau, xi, sources) are only in state after fitting,
    # not after loading from file, so any failure here means "not available".
    try:
        state = model.state
        n_sources = getattr(model, "source_dimension", 0) or 0
        columns = {
            "tau": state["tau"][:, 0].cpu().numpy(),
            "xi": state["xi"][:, 0].cpu().numpy(),
        }
        for src in range(n_sources):
            columns[f"sources_{src}"] = state["sources"][:, src].cpu().numpy()
        individuals = pd.DataFrame(columns)
    except Exception:
        return None

    with_probs = model.get_individual_probabilities(individuals)
    responsibility_cols = [
        col for col in with_probs.columns if col.startswith("prob_cluster_")
    ]
    responsibilities = with_probs[responsibility_cols]

    # Exact-zero responsibilities are turned into NaN before the log and then
    # skipped by np.nansum, which implements the entropy convention
    # 0 * log(0) := 0.  A plain .sum() would let one NaN poison the total.
    log_responsibilities = np.log(responsibilities.replace(0, np.nan))
    entropy_value = np.nansum((responsibilities * log_responsibilities).to_numpy())

    icl = bic - entropy_value
    if not np.isfinite(icl):
        return None
    return icl
def _persist_icl(model: "BaseModel") -> None:
    """Compute the ICL once and cache it in ``model.training_info["icl"]``.

    The individual variables (``tau``, ``xi``, ``sources``) needed by
    :func:`compute_icl` live on ``model.state`` only right after fit and are
    not serialized, so the value could not be recomputed after
    :meth:`BaseModel.load`.  Caching it in ``training_info`` lets it survive
    save/load.  Intended to be called at the end of fit.

    Does nothing for non-mixture models (``n_clusters`` missing or falsy),
    when a cached value already exists, when the negative log-likelihood or
    the subject count is unavailable, or when the computed value is ``None``
    or not finite.
    """
    if not getattr(model, "n_clusters", None):
        return
    if (model.training_info or {}).get("icl") is not None:
        return

    metrics = getattr(model, "fit_metrics", None) or {}
    total_nll = metrics.get("nll_tot")
    nll_for_bic = metrics.get("nll_attach", total_nll)
    subject_count = (model.dataset_info or {}).get("n_subjects")
    if nll_for_bic is None or subject_count is None:
        return

    param_count = get_number_of_parameters(model)
    bic_value = compute_bic(float(nll_for_bic), param_count, subject_count)
    icl_value = compute_icl(bic_value, model)

    # Guard against non-finite values: a NaN/inf icl would be cached here and
    # then serialized by to_dict() as the bare token `NaN`, which is invalid
    # JSON.
    if icl_value is None or not np.isfinite(icl_value):
        return
    model.training_info["icl"] = float(icl_value)


# ---------------------------------------------------------------------------
# Parameter display registry
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class _ModelDisplayMeta:
    """Display metadata for one model class (consumed by ``Summary``).

    Instances are stored in ``_DISPLAY_REGISTRY`` under a class *name*;
    ``_get_display_meta`` walks a model's MRO to find the closest entry.
    """

    # Parameter names shown in the "individual priors" group.
    individual_prior_params: tuple[str, ...]
    # Parameter names shown in the "noise" group.
    noise_params: tuple[str, ...]
    # Parameter name -> semantic axis names, one per tensor dimension.
    param_axes: dict[str, tuple[str, ...]]
    # Same mapping for derived parameters.
    derived_param_axes: dict[str, tuple[str, ...]] = field(default_factory=dict)


_BASE_PARAM_AXES: dict[str, tuple[str, ...]] = {
    "log_g_mean": ("feature",),
    "log_g_std": ("feature",),
    "log_v0_mean": ("feature",),
    "betas_mean": ("basis", "source"),
    "mixing_matrix": ("source", "feature"),
    "noise_std": ("feature",),
}

_BASE_DERIVED_PARAM_AXES: dict[str, tuple[str, ...]] = {
    "v0": ("feature",),
    "p0": ("feature",),
}

_DISPLAY_REGISTRY: dict[str, _ModelDisplayMeta] = {
    "McmcSaemCompatibleModel": _ModelDisplayMeta(
        individual_prior_params=(
            "tau_mean",
            "tau_std",
            "xi_mean",
            "xi_std",
            "sources_mean",
            "sources_std",
            "zeta_mean",
        ),
        noise_params=("noise_std",),
        param_axes=_BASE_PARAM_AXES,
        derived_param_axes=_BASE_DERIVED_PARAM_AXES,
    ),
    "TimeReparametrizedMixtureModel": _ModelDisplayMeta(
        individual_prior_params=(
            "tau_std",
            "xi_std",
            "sources_std",
            "tau_mean",
            "xi_mean",
            "sources_mean",
        ),
        noise_params=("noise_std",),
        param_axes={
            **_BASE_PARAM_AXES,
            "tau_mean": ("cluster",),
            "tau_std": ("cluster",),
            "xi_mean": ("cluster",),
            "xi_std": ("cluster",),
            "sources_mean": ("source", "cluster"),
            "sources_std": ("source", "cluster"),
            "probs": ("cluster",),
        },
        derived_param_axes=_BASE_DERIVED_PARAM_AXES,
    ),
    "JointModel": _ModelDisplayMeta(
        individual_prior_params=(
            "tau_mean",
            "tau_std",
            "xi_mean",
            "xi_std",
            "sources_mean",
            "sources_std",
            "zeta_mean",
        ),
        noise_params=("noise_std",),
        param_axes={
            **_BASE_PARAM_AXES,
            "n_log_nu_mean": ("event",),
            "log_rho_mean": ("event",),
            "zeta_mean": ("source", "event"),
        },
        derived_param_axes=_BASE_DERIVED_PARAM_AXES,
    ),
}


def _get_display_meta(model) -> Optional[_ModelDisplayMeta]:
    """Look up the display metadata that applies to *model*.

    The model's MRO is scanned from most to least specific and the first
    class whose ``__name__`` is a key of ``_DISPLAY_REGISTRY`` wins.

    Returns
    -------
    _ModelDisplayMeta or None
        The matching entry, or ``None`` when no class in the hierarchy is
        registered.
    """
    return next(
        (
            _DISPLAY_REGISTRY[klass.__name__]
            for klass in type(model).__mro__
            if klass.__name__ in _DISPLAY_REGISTRY
        ),
        None,
    )


def _build_param_categories(
    model, meta: _ModelDisplayMeta
) -> dict[str, list[str]]:
    """Split the names in ``model.parameters`` into display groups.

    Returns
    -------
    dict[str, list[str]]
        Keys ``"population"``, ``"individual_priors"`` and ``"noise"`` (in
        that order).  ``"population"`` holds every parameter that *meta* lists
        in neither of the other two groups.  Names absent from
        ``model.parameters`` are dropped.  Each list is ordered by display
        width, then leading axis name, then parameter name.
    """
    present = set(model.parameters.keys()) if model.parameters else set()
    prior_names = set(meta.individual_prior_params)
    noise_names = set(meta.noise_params)

    def display_order(name: str) -> tuple[int, str, str]:
        value = model.parameters[name]
        axes = meta.param_axes.get(name, ())
        lead_axis = axes[0] if axes else ""

        # Width is the number of columns the value is expected to occupy.
        width = 1
        if value.ndim == 2:
            width = value.shape[1]
        elif value.ndim == 1 and axes:
            # A vector only spreads over columns when its axis has labels.
            labels = get_axis_labels(lead_axis, len(value), model.features)
            if labels is not None:
                width = len(value)
        return (width, lead_axis, name)

    members_by_group = {
        "population": present - prior_names - noise_names,
        "individual_priors": prior_names & present,
        "noise": noise_names & present,
    }
    return {
        group: sorted(members, key=display_order)
        for group, members in members_by_group.items()
    }


# ---------------------------------------------------------------------------
# Info
# ---------------------------------------------------------------------------
@dataclass(repr=False)
class Info(AutoPrintMixin):
    """Model configuration and training context (no parameter values).

    Returned by ``model.info()``.  Auto-prints when discarded; provides
    programmatic access when stored in a variable.

    Besides the dataclass fields, read-only convenience properties expose the
    individual entries of ``training_info`` and ``dataset_info`` (each returns
    ``None`` when the entry is absent).

    Examples
    --------
    >>> model.info()          # prints info
    >>> i = model.info()      # store for programmatic access
    >>> i.n_subjects          # 150
    >>> i.help()              # list available attributes
    """

    name: str
    model_type: str
    dimension: Optional[int] = None
    features: Optional[list[str]] = None
    source_dimension: Optional[int] = None
    n_clusters: Optional[int] = None
    obs_models: Optional[list[str]] = None
    n_total_params: Optional[int] = None
    # {"population" | "individual": {variable name: {"distribution", "parameters"}}}
    latent_variables: dict = field(default_factory=dict)
    training_info: TrainingInfo = field(default_factory=dict)
    hyperparameters: dict = field(default_factory=dict)
    dataset_info: DatasetInfo = field(default_factory=dict)
    leaspy_version: Optional[str] = None
    # Auto-print bookkeeping flag consumed by AutoPrintMixin.
    _printed: bool = field(default=False, repr=False)

    # -- Factory -------------------------------------------------------------

    @classmethod
    def from_model(cls, model: "BaseModel") -> "Info":
        """Build an :class:`Info` from a model instance.

        Works on un-fitted models too: parameter counts and DAG-derived fields
        are simply omitted when the model state isn't available yet.

        Parameters
        ----------
        model : BaseModel
            The model to describe.

        Returns
        -------
        Info
            A snapshot; ``training_info`` and ``dataset_info`` are shallow
            copies of the model's dictionaries.
        """
        is_initialized = getattr(model, "is_initialized", False)

        # Observation model names
        obs_model_names = None
        if hasattr(model, "obs_models"):
            obs_model_names = [om.to_string() for om in model.obs_models]

        # Parameter count (requires initialized state)
        n_total_params = None
        if is_initialized:
            try:
                if getattr(model, "parameters", None):
                    n_total_params = get_number_of_parameters(model)
            except Exception:
                n_total_params = None

        # Leaspy version
        try:
            from leaspy import __version__ as version
        except ImportError:
            version = None

        # Latent variable distributions (requires DAG, which needs initialization)
        from leaspy.variables.specs import (
            IndividualLatentVariable,
            PopulationLatentVariable,
        )

        latent_variables = {}
        if is_initialized:
            try:
                dag = getattr(model, "dag", None)
            except Exception:
                dag = None
            if dag is not None:
                for kind, lv_type in [
                    ("population", PopulationLatentVariable),
                    ("individual", IndividualLatentVariable),
                ]:
                    group = {}
                    for var_name, var in dag.sorted_variables_by_type[lv_type].items():
                        # e.g. "NormalFamily" -> "Normal"
                        dist_name = var.prior.dist_family.__name__.replace("Family", "")
                        group[var_name] = {
                            "distribution": dist_name,
                            "parameters": list(var.prior.parameters_names),
                        }
                    if group:
                        latent_variables[kind] = group

        # Hyperparameters (requires initialized state)
        hyperparameters = {}
        if is_initialized:
            try:
                hyperparameters = dict(getattr(model, "hyperparameters", {}))
            except Exception:
                hyperparameters = {}

        return cls(
            name=model.name,
            model_type=model.__class__.__name__,
            dimension=model.dimension,
            features=model.features,
            source_dimension=getattr(model, "source_dimension", None),
            n_clusters=getattr(model, "n_clusters", None),
            obs_models=obs_model_names,
            n_total_params=n_total_params,
            training_info=dict(model.training_info),
            dataset_info=dict(model.dataset_info),
            hyperparameters=hyperparameters,
            leaspy_version=version,
            latent_variables=latent_variables,
        )

    # -- Convenience properties: training ------------------------------------

    @property
    def algorithm(self) -> Optional[str]:
        """Algorithm name used for training."""
        val = self.training_info.get("algorithm")
        # Unwrap objects that carry the name in a ``.value`` attribute.
        return val.value if hasattr(val, "value") else val

    @property
    def seed(self) -> Optional[int]:
        """Random seed used for training."""
        return self.training_info.get("seed")

    @property
    def n_iter(self) -> Optional[int]:
        """Number of iterations."""
        return self.training_info.get("n_iter")

    @property
    def n_burn_in_iter(self) -> Optional[int]:
        """Number of burn-in (memory-less) iterations."""
        return self.training_info.get("n_burn_in_iter")

    @property
    def converged(self) -> Optional[bool]:
        """Whether training converged."""
        return self.training_info.get("converged")

    @property
    def duration(self) -> Optional[str]:
        """Training duration."""
        return self.training_info.get("duration")

    @property
    def hyperparameter(self) -> dict:
        """Model hyperparameters (e.g. source_dimension, obs_model)."""
        return self.hyperparameters

    @property
    def latent_variable_distributions(self) -> dict:
        """Latent variable prior distributions, grouped by population/individual."""
        return self.latent_variables

    # -- Convenience properties: dataset -------------------------------------

    @property
    def n_subjects(self) -> Optional[int]:
        """Number of subjects in the training dataset."""
        return self.dataset_info.get("n_subjects")

    @property
    def n_visits(self) -> Optional[int]:
        """Total number of visits."""
        return self.dataset_info.get("n_visits")

    @property
    def n_scores(self) -> Optional[int]:
        """Number of scored features."""
        return self.dataset_info.get("n_scores")

    @property
    def n_observations(self) -> Optional[int]:
        """Total number of observed data points."""
        return self.dataset_info.get("n_observations")

    @property
    def visits_per_subject(self) -> Optional[VisitsPerSubject]:
        """Per-subject visit distribution statistics."""
        return self.dataset_info.get("visits_per_subject")

    @property
    def n_events(self) -> Optional[int]:
        """Number of observed events (joint models only)."""
        return self.dataset_info.get("n_events")

    # -- Display -------------------------------------------------------------

    def __str__(self) -> str:
        """Render the multi-section, ``_WIDTH``-column text report."""
        lines = []
        sep = "=" * _WIDTH

        lines.append(sep)
        lines.append(f"{'Model Information':^{_WIDTH}}")
        lines.append(sep)

        # Statistical Model
        lines.append("Statistical Model")
        lines.append(f"Type: {self.model_type}")
        lines.append(f"Name: {self.name}")
        lines.append(f"Dimension: {self.dimension}")
        if self.source_dimension is not None:
            lines.append(f"Source Dimension: {self.source_dimension}")
        if self.obs_models:
            lines.append(f"Observation Models: {', '.join(self.obs_models)}")
        if self.n_total_params is not None:
            lines.append(f"Parameters: {self.n_total_params}")
        if self.n_clusters is not None:
            lines.append(f"Clusters: {self.n_clusters}")

        if self.latent_variables:
            lines.append("")
            lines.append("Latent Variables")
            lines.append("-" * _WIDTH)
            for kind, group in self.latent_variables.items():
                lines.append(f" {kind.capitalize()}:")
                for var_name, info in group.items():
                    params = ", ".join(info["parameters"])
                    lines.append(f" {var_name:<20} {info['distribution']}({params})")
            # NOTE(review): nesting of this closing rule was reconstructed from
            # a flattened listing -- confirm it belongs to this section.
            lines.append("-" * _WIDTH)

        # Training Dataset
        if self.dataset_info:
            lines.append("")
            lines.append("Training Dataset")
            lines.append("-" * _WIDTH)
            di = self.dataset_info
            lines.append(f"Subjects: {di.get('n_subjects', 'N/A')}")
            lines.append(f"Visits: {di.get('n_visits', 'N/A')}")
            lines.append(f"Scores (Features): {di.get('n_scores', 'N/A')}")
            lines.append(f"Total Observations: {di.get('n_observations', 'N/A')}")
            if "visits_per_subject" in di:
                vps = di["visits_per_subject"]
                lines.append(
                    f"Visits per Subject: Median {vps['median']:.1f} "
                    f"[Min {vps['min']}, Max {vps['max']}, IQR {vps['iqr']:.1f}]"
                )
            if "n_events" in di:
                lines.append(f"Events Observed: {di['n_events']}")

        # Training Details
        if self.training_info:
            lines.append("")
            lines.append("Training Details")
            lines.append("-" * _WIDTH)
            ti = self.training_info
            lines.append(f"Algorithm: {self.algorithm or 'N/A'}")
            if "seed" in ti:
                lines.append(f"Seed: {ti['seed']}")
            lines.append(f"Iterations: {ti.get('n_iter', 'N/A')}")
            if "n_burn_in_iter" in ti:
                n_b = ti["n_burn_in_iter"]
                n_t = ti.get("n_iter")
                if n_t:
                    lines.append(f" Burn-in: {n_b}/{n_t} ({100*n_b/n_t:.0f}%)")
                    lines.append(f" Burn-out: {n_t - n_b}")
                else:
                    # Without a usable iteration total, a ratio or a burn-out
                    # count would be invented numbers; show the raw count only.
                    lines.append(f" Burn-in: {n_b}")
            if ti.get("converged") is not None:
                lines.append(f"Converged: {ti['converged']}")
            if "duration" in ti:
                lines.append(f"Duration: {ti['duration']}")

        # Hyperparameters
        if self.hyperparameters:
            lines.append("")
            lines.append("Hyperparameters (fixed values from the source code)")
            lines.append("-" * _WIDTH)
            for k, v in self.hyperparameters.items():
                if isinstance(v, torch.Tensor):
                    # Scalars become plain numbers, everything else a nested list.
                    val = v.item() if v.ndim == 0 else v.tolist()
                else:
                    val = v
                shown = round(val, 4) if isinstance(val, float) else val
                lines.append(f" {k}: {shown}")

        # Leaspy Version
        if self.leaspy_version:
            lines.append("")
            lines.append(f"Leaspy Version: {self.leaspy_version}")

        lines.append(sep)
        return "\n".join(lines)

    def help(self) -> None:
        """Print available attributes and their meanings."""
        help_text = f"""
Info Help
{'=' * 60}

The Info object provides access to model configuration and training context.

Usage:
    model.info()          # Print model information
    i = model.info()      # Store to access individual attributes

Available Attributes:

  Model:
    name                  Model name (str)
    model_type            Model class name (str)
    dimension             Number of features (int)
    features              Feature names (list[str])
    source_dimension      Number of sources (int or None)
    n_clusters            Number of clusters (int or None)
    obs_models            Observation model names (list[str] or None)
    n_total_params        Number of free parameters (int)
    hyperparameters       Model hyperparameters dict (e.g. source_dimension)

  Training:
    algorithm             Algorithm name (str)
    seed                  Random seed (int)
    n_iter                Number of iterations (int)
    n_burn_in_iter        Number of burn-in iterations (int)
    converged             Whether training converged (bool or None)
    duration              Training duration (str)

  Dataset:
    n_subjects            Number of subjects (int)
    n_visits              Total visits (int)
    n_scores              Number of scored features (int)
    n_observations        Total observations (int)
    visits_per_subject    Visit distribution stats (dict)
    n_events              Observed events, joint models only (int or None)

  Other:
    training_info         Full training metadata (TrainingInfo)
    dataset_info          Full dataset statistics (DatasetInfo)
    leaspy_version        Leaspy version (str)

Examples:
    >>> i = model.info()
    >>> i.algorithm       # 'mcmc_saem'
    >>> i.n_subjects      # 150
"""
        print(help_text)
        # The user has now seen output, so skip the auto-print on collection.
        object.__setattr__(self, "_printed", True)
# --------------------------------------------------------------------------- # Summary # ---------------------------------------------------------------------------
[docs] @dataclass(repr=False) class Summary(AutoPrintMixin): """Structured summary of a Leaspy model including parameter values. Returned by ``model.summary()``. Auto-prints when discarded; provides programmatic access when stored in a variable. Examples -------- >>> model.summary() # prints the formatted summary >>> s = model.summary() # store for programmatic access >>> s.algorithm # 'mcmc_saem' >>> s.get_param('tau_std') # tensor([10.5]) >>> s.help() # list available attributes """ name: str model_type: str dimension: Optional[int] = None features: Optional[list[str]] = None source_dimension: Optional[int] = None n_clusters: Optional[int] = None obs_models: Optional[list[str]] = None n_total_params: Optional[int] = None nll: Optional[float] = None bic: Optional[float] = None aic: Optional[float] = None icl: Optional[float] = None training_info: TrainingInfo = field(default_factory=dict) dataset_info: DatasetInfo = field(default_factory=dict) parameters: dict[str, dict[str, Any]] = field(default_factory=dict) derived_parameters: dict[str, Any] = field(default_factory=dict) leaspy_version: Optional[str] = None _param_axes: dict = field(default_factory=dict, repr=False) _derived_param_axes: dict = field(default_factory=dict, repr=False) _feature_names: Optional[list[str]] = field(default=None, repr=False) _printed: bool = field(default=False, repr=False) # -- Factory -------------------------------------------------------------
@classmethod
def from_model(cls, model: "BaseModel") -> "Summary":
    """Build a :class:`Summary` from a model instance.

    Parameters
    ----------
    model : BaseModel
        An initialized model with a non-empty ``parameters`` mapping.

    Returns
    -------
    Summary
        Snapshot holding the fit metrics (``nll``, ``bic``, ``aic``, ``icl``),
        the parameters grouped by display category, the derived parameters
        and shallow copies of ``training_info`` / ``dataset_info``.

    Raises
    ------
    LeaspyModelInputError
        If the model is not initialized or has no parameters.
    """
    from leaspy.exceptions import LeaspyModelInputError

    if not model.is_initialized:
        raise LeaspyModelInputError(
            "Model is not initialized. Call fit() first."
        )
    if model.parameters is None or len(model.parameters) == 0:
        raise LeaspyModelInputError(
            "Model has no parameters. Call fit() first."
        )

    fm = getattr(model, "fit_metrics", None) or {}

    # NLL -- compared against None so that a value of exactly 0.0 is kept
    # rather than being treated as "missing".
    nll_val = fm.get("nll_tot")
    nll = float(nll_val) if nll_val is not None else None

    # Parameter count, BIC & AIC (both criteria need the same two inputs)
    n_total_params = get_number_of_parameters(model)
    bic = None
    aic = None
    nll_bic = fm.get("nll_attach", fm.get("nll_tot"))
    n_subjects = model.dataset_info.get("n_subjects")
    if nll_bic is not None and n_subjects is not None:
        bic = compute_bic(float(nll_bic), n_total_params, n_subjects)
        aic = compute_aic(float(nll_bic), n_total_params, n_subjects)

    # Prefer the ICL cached in training_info; recompute only when absent.
    icl = (model.training_info or {}).get("icl")
    if icl is None:
        icl = compute_icl(bic, model)

    # Observation model names
    obs_model_names = None
    if hasattr(model, "obs_models"):
        obs_model_names = [om.to_string() for om in model.obs_models]

    # Leaspy version
    try:
        from leaspy import __version__ as version
    except ImportError:
        version = None

    # Group parameters by category
    _meta = _get_display_meta(model)
    params_by_category = {}
    if _meta is not None:
        cats = _build_param_categories(model, _meta)
        cat_names = {
            "population": "Population Parameters",
            "individual_priors": "Individual Parameters",
            "noise": "Noise Model",
        }
        for cat_key, display_name in cat_names.items():
            param_names = cats.get(cat_key, [])
            if param_names:
                params_by_category[display_name] = {
                    name: model.parameters[name]
                    for name in param_names
                    if name in model.parameters
                }
    else:
        # No display metadata registered for this model: show one flat group.
        params_by_category["Parameters"] = dict(model.parameters)

    # Derived parameters (v0, p0, ...) from model-side transform
    derived = model.compute_derived_parameters()

    return cls(
        name=model.name,
        model_type=model.__class__.__name__,
        dimension=model.dimension,
        features=model.features,
        source_dimension=getattr(model, "source_dimension", None),
        n_clusters=getattr(model, "n_clusters", None),
        obs_models=obs_model_names,
        n_total_params=n_total_params,
        nll=nll,
        bic=bic,
        aic=aic,
        icl=icl,
        training_info=dict(model.training_info),
        dataset_info=dict(model.dataset_info),
        parameters=params_by_category,
        derived_parameters=derived,
        leaspy_version=version,
        _param_axes=_meta.param_axes if _meta is not None else {},
        _derived_param_axes=_meta.derived_param_axes if _meta is not None else {},
        _feature_names=model.features,
    )
# -- Convenience properties ---------------------------------------------- @property def sources(self) -> Optional[list[str]]: """Source names (e.g. ``['s0', 's1']``) or ``None``.""" if self.source_dimension is None: return None return [f"s{i}" for i in range(self.source_dimension)] @property def clusters(self) -> Optional[list[str]]: """Cluster names (e.g. ``['c0', 'c1']``) or ``None``.""" if self.n_clusters is None: return None return [f"c{i}" for i in range(self.n_clusters)] @property def algorithm(self) -> Optional[str]: """Algorithm name used for training.""" val = self.training_info.get("algorithm") return val.value if hasattr(val, "value") else val @property def seed(self) -> Optional[int]: """Random seed used for training.""" return self.training_info.get("seed") @property def n_iter(self) -> Optional[int]: """Number of iterations.""" return self.training_info.get("n_iter") @property def converged(self) -> Optional[bool]: """Whether training converged.""" return self.training_info.get("converged") @property def n_subjects(self) -> Optional[int]: """Number of subjects in the training dataset.""" return self.dataset_info.get("n_subjects") @property def n_visits(self) -> Optional[int]: """Total number of visits.""" return self.dataset_info.get("n_visits") @property def n_observations(self) -> Optional[int]: """Total number of observations.""" return self.dataset_info.get("n_observations")
[docs] def get_param(self, name: str) -> Optional[Any]: """Get a parameter value by name, searching fitted and derived parameters. Parameters ---------- name : str Parameter name (e.g. ``'betas_mean'``, ``'tau_std'``, ``'v0'``, ``'p0'``). Returns ------- value The parameter value (typically a ``torch.Tensor``), or ``None``. """ for category_params in self.parameters.values(): if name in category_params: return category_params[name] derived = object.__getattribute__(self, "derived_parameters") if name in derived: return derived[name] return None
# -- Display ------------------------------------------------------------- def __str__(self) -> str: lines = [] sep = "=" * _WIDTH # Header lines.append(sep) lines.append(f"{'Model Summary':^{_WIDTH}}") lines.append(sep) lines.append(f"Model Name: {self.name}") lines.append(f"Model Type: {self.model_type}") if self.features is not None: feat_str = ", ".join(self.features) lines.extend( _wrap_text(f"Features ({self.dimension})", feat_str) ) if self.source_dimension is not None: sources = [f"Source {i} (s{i})" for i in range(self.source_dimension)] lines.extend( _wrap_text( f"Sources ({self.source_dimension})", ", ".join(sources), ) ) if self.n_clusters is not None: clusters = [f"Cluster {i} (c{i})" for i in range(self.n_clusters)] lines.extend( _wrap_text( f"Clusters ({self.n_clusters})", ", ".join(clusters), ) ) if self.obs_models: lines.extend( _wrap_text("Observation Models", ", ".join(self.obs_models)) ) if self.nll is not None: lines.append(f"Neg. Log-Likelihood: {self.nll:.4f}") if self.n_total_params is not None: lines.append(f"Parameters: {self.n_total_params}") if self.bic is not None: lines.append(f"BIC: {self.bic:.2f}") if self.aic is not None: lines.append(f"AIC: {self.aic:.2f}") if self.icl is not None: lines.append(f"ICL: {self.icl:.2f}") # Training Metadata if self.training_info: lines.append("") lines.append("Training Metadata") lines.append("-" * _WIDTH) ti = self.training_info lines.append(f"Algorithm: {self.algorithm or 'N/A'}") if "seed" in ti: lines.append(f"Seed: {ti['seed']}") lines.append(f"Iterations: {ti.get('n_iter', 'N/A')}") if ti.get("converged") is not None: lines.append(f"Converged: {ti['converged']}") # Data Context if self.dataset_info: lines.append("") lines.append("Data Context") lines.append("-" * _WIDTH) di = self.dataset_info lines.append(f"Subjects: {di.get('n_subjects', 'N/A')}") lines.append(f"Visits: {di.get('n_visits', 'N/A')}") lines.append(f"Total Observations: {di.get('n_observations', 'N/A')}") # Leaspy Version if 
self.leaspy_version: lines.append(f"Leaspy Version: {self.leaspy_version}") lines.append(sep) # Parameters by category for i, (category, params) in enumerate(self.parameters.items()): if i == 0: lines.append("") if params: if i > 0: lines.append("") lines.append(category) lines.append("-" * _WIDTH) lines.extend(self._format_parameter_group(params)) # Derived parameters (interpretable scale) derived = object.__getattribute__(self, "derived_parameters") if derived: lines.append("") lines.append("Derived Parameters (interpretable scale)") lines.append("-" * _WIDTH) lines.extend(self._format_derived_group(derived)) lines.append(sep) return "\n".join(lines)
[docs] def help(self) -> None: """Print available attributes and their meanings.""" help_text = f""" Summary Help {'=' * 60} The Summary object provides access to model metadata and parameters. Usage: model.summary() # Print the formatted summary s = model.summary() # Store to access individual attributes Available Attributes: Model Information: name Model name (str) model_type Model class name, e.g., 'LogisticModel' (str) dimension Number of features (int) features List of feature names (list[str]) sources Source names, e.g., ['s0', 's1'] (list[str] or None) clusters Cluster names, e.g., ['c0', 'c1'] (list[str] or None) source_dimension Number of sources (int or None) n_clusters Number of clusters (int or None) obs_models Observation model names (list[str] or None) Training: algorithm Algorithm name, e.g., 'mcmc_saem' (str) seed Random seed used (int) n_iter Number of iterations (int) converged Whether training converged (bool or None) nll Negative log-likelihood (float or None) n_total_params Number of free parameters (int) bic Bayesian Information Criterion (float or None) aic Akaike Information Criterion (float or None) icl Integrated Completed Likelihood, mixture models only (float or None) Dataset: n_subjects Number of subjects in training data (int) n_visits Total number of visits (int) n_observations Total number of observations (int) Parameters: parameters All parameters grouped by category (dict) derived_parameters Derived parameters in interpretable scale (dict) get_param(name) Get a parameter by name (searches fitted + derived) Derived Parameters (interpretable scale): v0 Velocities: exp(log_v0_mean), per feature p0 Positions: sigmoid(-log_g_mean) or g_mean, per feature Other: training_info Full training metadata (TrainingInfo) dataset_info Full dataset statistics (DatasetInfo) leaspy_version Leaspy version used (str) Examples: >>> s = model.summary() >>> s.algorithm # 'mcmc_saem' >>> s.seed # 42 >>> s.n_subjects # 150 >>> s.get_param('tau_std') # 
tensor([10.5]) """ print(help_text) object.__setattr__(self, "_printed", True)
# -- Private formatting helpers ------------------------------------------ def _format_parameter_group(self, params: dict[str, Any]) -> list[str]: """Format a group of parameters for display.""" lines = [] for name, value in params.items(): if isinstance(value, torch.Tensor): lines.append(self._format_tensor(name, value)) else: lines.append(f" {name:<18} {value}") return lines def _format_derived_group(self, params: dict[str, Any]) -> list[str]: """Format derived parameters using derived_param_axes.""" lines = [] for name, value in params.items(): if isinstance(value, torch.Tensor): lines.append(self._format_tensor(name, value, derived=True)) else: lines.append(f" {name:<18} {value}") return lines def _format_tensor( self, name: str, value: torch.Tensor, *, derived: bool = False ) -> str: """Format a tensor parameter with axis labels.""" if derived: param_axes = object.__getattribute__(self, "_derived_param_axes") else: param_axes = object.__getattribute__(self, "_param_axes") feature_names = object.__getattribute__(self, "_feature_names") axes = param_axes.get(name, ()) if value.ndim == 0: return f" {name:<18} {value.item():.4f}" elif value.ndim == 1: n = len(value) if n > 10: return f" {name:<18} Tensor of shape ({n},)" axis_name = axes[0] if len(axes) >= 1 else None col_labels = get_axis_labels(axis_name, n, feature_names) if col_labels: header = " " * 20 + " ".join(f"{lbl:>8}" for lbl in col_labels) values = f" {name:<18}" + " ".join( f"{v.item():>8.4f}" for v in value ) return header + "\n" + values else: val_str = "[" + ", ".join(f"{v.item():.4f}" for v in value) + "]" return f" {name:<18} {val_str}" elif value.ndim == 2: rows, cols = value.shape if rows > 8 or cols > 8: return f" {name:<18} Tensor of shape {tuple(value.shape)}" row_axis = axes[0] if len(axes) >= 1 else None col_axis = axes[1] if len(axes) >= 2 else None row_labels = get_axis_labels(row_axis, rows, feature_names) col_labels = get_axis_labels(col_axis, cols, feature_names) result = [f" 
{name}:"] if col_labels: header = " " * 20 + " ".join(f"{lbl:>8}" for lbl in col_labels) result.append(header) for i, row in enumerate(value): row_lbl = row_labels[i] if row_labels else f"[{i}]" row_str = ( f" {row_lbl:<8}" + " ".join(f"{v.item():>8.4f}" for v in row) ) result.append(row_str) return "\n".join(result) else: return f" {name:<18} Tensor of shape {tuple(value.shape)}"