leaspy.algo.fit.abstract_fit_algo.AbstractFitAlgo

class AbstractFitAlgo(settings)

Bases: AlgoWithDeviceMixin, AbstractAlgo

Abstract class containing common methods for all fit algorithm classes.

Parameters
settings : AlgorithmSettings

The specifications of the algorithm as an AlgorithmSettings instance.

See also

Leaspy.fit()
Attributes
algorithm_device : str

A valid torch device (e.g. 'cpu' or 'cuda').

current_iteration : int, default 0

The number of the current iteration. The first iteration is 1 and the last is n_iter.

sufficient_statistics : dict[str, torch.FloatTensor] or None

The sufficient statistics from the previous step. It is None throughout the burn-in phase.
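The sketch below shows why sufficient_statistics can be None. It applies the generic SAEM stochastic-approximation step S_k = S_{k-1} + ε_k (S(z_k) - S_{k-1}) to plain Python dictionaries. This is a minimal sketch under stated assumptions:

* The names saem_update and step_size are hypothetical.
* A missing previous value (None, as during burn-in) is assumed to mean the new statistics are kept unchanged.
* The post-burn-in step-size schedule is an illustrative choice, not necessarily the one leaspy uses.

```python
from typing import Dict, Optional


def saem_update(
    previous: Optional[Dict[str, float]],
    new: Dict[str, float],
    step: float,
) -> Dict[str, float]:
    """Hypothetical SAEM stochastic-approximation step on sufficient statistics."""
    if previous is None:
        # No previous statistics (e.g. during burn-in): keep the new ones as-is.
        return dict(new)
    # S_k = S_{k-1} + step * (S(z_k) - S_{k-1}), applied key by key.
    return {key: previous[key] + step * (new[key] - previous[key]) for key in new}


def step_size(iteration: int, n_burn_in_iter: int) -> float:
    """Assumed schedule: 1 during burn-in, then 1 / (iteration - n_burn_in_iter)."""
    if iteration <= n_burn_in_iter:
        return 1.0
    return 1.0 / (iteration - n_burn_in_iter)


# Usage: after burn-in, new statistics are blended with the previous ones.
stats = saem_update(None, {"s1": 2.0}, step=1.0)                     # {'s1': 2.0}
stats = saem_update(stats, {"s1": 4.0}, step=step_size(9002, 9000))  # {'s1': 3.0}
```

A decreasing step size after burn-in is what lets the stochastic approximation average out the sampling noise. A constant step would keep the statistics jittering around their target.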

Inherited attributes

From AbstractAlgo

Methods

iteration(dataset, model, realizations)

Update the parameters (abstract method).

load_parameters(parameters)

Update the algorithm's parameters with those in the given dictionary.

run(model, *args[, return_noise])

Main method, run the algorithm.

run_impl(model, dataset)

Main method, run the algorithm.

set_output_manager(output_settings)

Set a FitOutputManager object for the algorithm run.

abstract iteration(dataset: Dataset, model: AbstractModel, realizations: CollectionRealization)

Update the parameters (abstract method).

Parameters
dataset : Dataset

Contains the subjects’ observations in torch format to speed up computation.

model : AbstractModel

The model being fitted.

realizations : CollectionRealization

The current realizations of the model's random variables.

load_parameters(parameters: dict)

Update the algorithm’s parameters with those in the given dictionary. Keys that do not belong to the algorithm’s parameters are ignored.

Parameters
parameters : dict

Contains the (key, value) pairs of the desired parameters.

Examples

>>> from leaspy import AlgorithmSettings
>>> from leaspy.algo.fit.tensor_mcmcsaem import TensorMCMCSAEM
>>> settings = AlgorithmSettings("mcmc_saem")
>>> my_algo = TensorMCMCSAEM(settings)
>>> my_algo.algo_parameters
{'n_iter': 10000,
 'n_burn_in_iter': 9000,
 'eps': 0.001,
 'L': 10,
 'sampler_ind': 'Gibbs',
 'sampler_pop': 'Gibbs',
 'annealing': {'do_annealing': False,
  'initial_temperature': 10,
  'n_plateau': 10,
  'n_iter': 200}}
>>> parameters = {'n_iter': 5000, 'n_burn_in_iter': 4000}
>>> my_algo.load_parameters(parameters)
>>> my_algo.algo_parameters
{'n_iter': 5000,
 'n_burn_in_iter': 4000,
 'eps': 0.001,
 'L': 10,
 'sampler_ind': 'Gibbs',
 'sampler_pop': 'Gibbs',
 'annealing': {'do_annealing': False,
  'initial_temperature': 10,
  'n_plateau': 10,
  'n_iter': 200}}
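The docstring states that unknown keys are ignored, but the example above does not exercise that case. The hypothetical helper below sketches the merge rule on plain dictionaries. Unlike the real method, which updates my_algo.algo_parameters in place, this sketch returns a new dictionary. It also makes no claim about how leaspy handles nested entries such as 'annealing'.

```python
def load_parameters(algo_parameters: dict, parameters: dict) -> dict:
    """Hypothetical sketch: copy `algo_parameters`, overriding only known keys."""
    updated = dict(algo_parameters)
    for key, value in parameters.items():
        if key in updated:  # keys the algorithm does not define are ignored
            updated[key] = value
    return updated


defaults = {"n_iter": 10000, "n_burn_in_iter": 9000, "eps": 0.001}
new_params = load_parameters(defaults, {"n_iter": 5000, "not_a_parameter": 1})
# new_params == {'n_iter': 5000, 'n_burn_in_iter': 9000, 'eps': 0.001}
```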
property log_noise_fmt

Getter

Returns
format : str

The format used when printing the loss.

run(model: AbstractModel, *args, return_noise: bool = False, **extra_kwargs) → Any

Main method, run the algorithm.

TODO: fix proper abstract class method; the input depends on the algorithm… (esp. simulate differs from the others…)

Parameters
model : AbstractModel

The model being fitted.

dataset : Dataset

Contains all the subjects’ observations with corresponding timepoints, in torch format to speed up computations.

return_noise : bool (default False), keyword-only

Whether the algorithm should return both the main output and the optional noise output as a 2-tuple.

Returns
Depends on the algorithm class. TODO: change?
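The return_noise switch can be sketched as a thin wrapper around an implementation that always produces an (output, noise) pair. This is a hypothetical illustration: the run function and the fake_impl stand-in below are assumptions made for the example, not leaspy's actual code.

```python
from typing import Any, Callable, Tuple


def run(run_impl: Callable[..., Tuple[Any, Any]], *args: Any,
        return_noise: bool = False, **kwargs: Any) -> Any:
    """Hypothetical wrapper: `run_impl` is assumed to return (output, noise)."""
    output, noise = run_impl(*args, **kwargs)
    if return_noise:
        return output, noise  # 2-tuple requested explicitly
    return output  # default: main output only


def fake_impl(x: float) -> Tuple[float, float]:
    """Stand-in implementation returning a result and a noise estimate."""
    return 2 * x, 0.1
```

With this sketch, run(fake_impl, 3) returns 6, while run(fake_impl, 3, return_noise=True) returns (6, 0.1). Making return_noise keyword-only keeps it from being swallowed by *args.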
run_impl(model: AbstractModel, dataset: Dataset)

Main method, run the algorithm.

It initializes the CollectionRealization object, updates it using the iteration method, then returns it.

TODO fix proper abstract class
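The loop described above can be sketched with a self-contained abstract base class. Everything here is hypothetical:

* FitAlgoSketch, CountingAlgo and initialize_realizations are illustrative names.
* A plain dict stands in for CollectionRealization.
* Returning (realizations, noise) is an assumption about the 2-tuple mentioned under Returns.

```python
from abc import ABC, abstractmethod


class FitAlgoSketch(ABC):
    """Hypothetical, simplified stand-in for a fit-algorithm base class."""

    def __init__(self, n_iter: int):
        self.n_iter = n_iter
        self.current_iteration = 0  # no iteration has run yet

    @abstractmethod
    def iteration(self, dataset, model, realizations):
        """Update the parameters; each concrete algorithm implements this."""

    def initialize_realizations(self, model):
        # Assumption: a plain dict replaces CollectionRealization here.
        return {}

    def run_impl(self, model, dataset):
        realizations = self.initialize_realizations(model)
        # Iterations are numbered 1..n_iter, like `current_iteration`.
        for k in range(1, self.n_iter + 1):
            self.current_iteration = k
            self.iteration(dataset, model, realizations)
        noise = None  # placeholder for an algorithm-specific noise output
        return realizations, noise


class CountingAlgo(FitAlgoSketch):
    """Toy subclass whose 'update' just records the iteration number."""

    def iteration(self, dataset, model, realizations):
        realizations.setdefault("seen", []).append(self.current_iteration)


algo = CountingAlgo(n_iter=3)
realizations, noise = algo.run_impl(model=None, dataset=None)
# realizations == {'seen': [1, 2, 3]}
```

Because iteration is declared with abstractmethod, FitAlgoSketch itself cannot be instantiated. Only subclasses that supply an update rule can run.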

Parameters
model : AbstractModel

The model being fitted.

dataset : Dataset

Contains the subjects’ observations in torch format to speed up computation.

Returns
2-tuple:
set_output_manager(output_settings)

Set a FitOutputManager object for the algorithm run.

Parameters
output_settings : OutputsSettings

Contains the logging settings for the run (console print periodicity, plot periodicity, …).

Examples

>>> from leaspy import AlgorithmSettings
>>> from leaspy.io.settings.outputs_settings import OutputsSettings
>>> from leaspy.algo.fit.tensor_mcmcsaem import TensorMCMCSAEM
>>> algo_settings = AlgorithmSettings("mcmc_saem")
>>> my_algo = TensorMCMCSAEM(algo_settings)
>>> settings = {'path': 'brouillons',
...             'console_print_periodicity': 50,
...             'plot_periodicity': 100,
...             'save_periodicity': 50}
>>> my_algo.set_output_manager(OutputsSettings(settings))
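The helper below makes the periodicity settings concrete by deciding whether an action is due at a given iteration. It rests on two assumptions that this page does not confirm:

* An action with periodicity p fires at every iteration that is a multiple of p.
* A missing periodicity disables the action.

is_due is a hypothetical name, not part of FitOutputManager.

```python
from typing import Optional


def is_due(iteration: int, periodicity: Optional[int]) -> bool:
    """Hypothetical helper: should an action with this periodicity run now?

    `iteration` is 1-based; a periodicity of None or 0 disables the action.
    """
    if not periodicity:
        return False
    return iteration % periodicity == 0


periodicities = {"console_print_periodicity": 50,
                 "plot_periodicity": 100,
                 "save_periodicity": 50}
due_at_150 = [name for name, p in periodicities.items() if is_due(150, p)]
# due_at_150 == ['console_print_periodicity', 'save_periodicity']
```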