# Source code for calisim.sbi.implementation

"""Contains the implementations for the simulation-based inference methods.

Implements the supported simulation-based inference methods.

"""

import importlib
import importlib.util
from collections.abc import Callable

from pydantic import Field

from ..base import CalibrationMethodBase, CalibrationWorkflowBase
from ..data_model import CalibrationModel

TASK = "simulation_based_inference"

# Registry of available simulation-based inference backends, keyed by engine
# name. A backend is registered only when its third-party package can be
# located, so this module still imports when optional dependencies are absent.
# NOTE: ``importlib.util`` is a submodule and must be imported explicitly;
# ``import importlib`` alone does not guarantee the ``util`` attribute exists.
IMPLEMENTATIONS: dict[str, type[CalibrationWorkflowBase]] = {}

if importlib.util.find_spec("lampe") is not None:
	from .lampe_wrapper import LAMPESimulationBasedInference

	IMPLEMENTATIONS["lampe"] = LAMPESimulationBasedInference

if importlib.util.find_spec("sbi") is not None:
	from .sbi_wrapper import SBISimulationBasedInference

	IMPLEMENTATIONS["sbi"] = SBISimulationBasedInference


def get_implementations() -> dict[str, type[CalibrationWorkflowBase]]:
	"""Return the registry of simulation-based inference implementations.

	The registry holds one entry per optional backend ("lampe" and/or "sbi")
	whose package was found when this module was imported. The module-level
	dictionary itself is returned, not a copy.

	Returns:
		dict[str, type[CalibrationWorkflowBase]]: Mapping from engine name
		to calibration workflow class.
	"""
	registry = IMPLEMENTATIONS
	return registry
class SimulationBasedInferenceMethodModel(CalibrationModel):
	"""Data model for a simulation-based inference calibration.

	Extends ``CalibrationModel`` with settings specific to
	simulation-based inference.

	Attributes:
		num_simulations (int): The number of simulations to run.
			Defaults to 25.
	"""

	num_simulations: int = Field(
		default=25,
		description="The number of simulations to run",
	)
class SimulationBasedInferenceMethod(CalibrationMethodBase):
	"""Calibration method for simulation-based inference.

	Front-end over ``CalibrationMethodBase`` that supplies the
	simulation-based inference task name and this module's registry of
	backend implementations.
	"""

	def __init__(
		self,
		calibration_func: Callable,
		specification: SimulationBasedInferenceMethodModel,
		engine: str = "sbi",
		implementation: CalibrationWorkflowBase | None = None,
	) -> None:
		"""Create a simulation-based inference method.

		Args:
			calibration_func (Callable): The calibration function, for
				example a simulation function or an objective function.
			specification (SimulationBasedInferenceMethodModel): The
				calibration specification.
			engine (str, optional): Name of the simulation-based inference
				backend. Defaults to "sbi".
			implementation (CalibrationWorkflowBase | None, optional): The
				calibration workflow implementation. Defaults to None.
		"""
		# Positional order must match CalibrationMethodBase.__init__.
		base_args = (
			calibration_func,
			specification,
			TASK,
			engine,
			IMPLEMENTATIONS,
			implementation,
		)
		super().__init__(*base_args)