Source code for calisim.sbi.sbi_wrapper

"""Contains the implementations for simulation-based inference methods using
SBI

Implements the supported simulation-based inference methods using
the SBI library.

"""

import numpy as np
import pandas as pd
import torch.nn as nn
from matplotlib import pyplot as plt
from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import (
	SNPE,
	prepare_for_sbi,
	simulate_for_sbi,
)

from ..base import SimulationBasedInferenceBase
from ..data_model import ParameterEstimateModel


[docs] class SBISimulationBasedInference(SimulationBasedInferenceBase): """The SBI simulation-based inference method class."""
[docs] def execute(self) -> None: """Execute the simulation calibration procedure.""" sbi_kwargs = self.get_calibration_func_kwargs() def simulator_func(X: np.ndarray) -> np.ndarray: X = X.detach().cpu().numpy() X = [X] results = self.calibration_func_wrapper( X, self, self.specification.observed_data, self.names, self.data_types, sbi_kwargs, ) return results[0] simulator, prior = prepare_for_sbi(simulator_func, self.parameters) method_kwargs = self.specification.method_kwargs if method_kwargs is None: method_kwargs = {} embedding_net = nn.Identity() neural_posterior = utils.posterior_nn( model=self.specification.method, embedding_net=embedding_net, **method_kwargs, ) inference = SNPE(prior=prior, density_estimator=neural_posterior) theta = self.specification.X x = self.specification.Y if theta is None or x is None: theta, x = simulate_for_sbi( simulator, proposal=prior, num_simulations=self.specification.num_simulations, ) inference = inference.append_simulations(theta, x) density_estimator = inference.train( max_num_epochs=self.specification.n_iterations ) posterior = inference.build_posterior(density_estimator) posterior.set_default_x(self.specification.observed_data) self.prior = prior self.simulator = simulator self.inference = inference self.posterior = posterior
[docs] def analyze(self) -> None: """Analyze the results of the simulation calibration procedure.""" task, time_now, experiment_name, outdir = self.prepare_analyze() n_draws = self.specification.n_samples posterior_samples = self.posterior.sample( (n_draws,), x=self.specification.observed_data ) limits = [] lower_limits, _ = posterior_samples.min(axis=0) upper_limits, _ = posterior_samples.max(axis=0) for i in range(len(self.names)): limits.append((lower_limits[i], upper_limits[i])) for plot_func in [analysis.pairplot, analysis.marginal_plot]: plt.rcParams.update({"font.size": 8}) fig, _ = plot_func( posterior_samples, figsize=self.specification.figsize, labels=self.names, limits=limits, ) self.present_fig( fig, outdir, time_now, task, experiment_name, plot_func.__name__.replace("_", "-"), ) for plot_func in [ analysis.conditional_pairplot, analysis.conditional_marginal_plot, ]: plt.rcParams.update({"font.size": 8}) fig, _ = plot_func( density=self.posterior, condition=self.posterior.sample((1,)), figsize=self.specification.figsize, labels=self.names, limits=limits, ) self.present_fig( fig, outdir, time_now, task, experiment_name, plot_func.__name__.replace("_", "-"), ) thetas = self.prior.sample((n_draws,)) xs = self.simulator(thetas) ranks, dap_samples = analysis.run_sbc( thetas, xs, self.posterior, num_posterior_samples=n_draws ) num_bins = None if n_draws <= 20: num_bins = n_draws for plot_type in ["hist", "cdf"]: plt.rcParams.update({"font.size": 8}) fig, _ = analysis.sbc_rank_plot( ranks=ranks, num_bins=num_bins, num_posterior_samples=n_draws, plot_type=plot_type, parameter_labels=self.names, ) fig_suffix = ( f"{analysis.sbc_rank_plot.__name__.replace('_', '-')}_{plot_type}" ) self.present_fig(fig, outdir, time_now, task, experiment_name, fig_suffix) trace_df = pd.DataFrame( posterior_samples.cpu().detach().numpy(), columns=self.names ) for name in trace_df: estimate = trace_df[name].mean() uncertainty = trace_df[name].std() parameter_estimate = ParameterEstimateModel( name=name, estimate=estimate, uncertainty=uncertainty ) self.add_parameter_estimate(parameter_estimate) if outdir is None: return self.to_csv(trace_df, "trace") check_stats = analysis.check_sbc( ranks, thetas, dap_samples, num_posterior_samples=n_draws ) check_stats_list = [] for metric in check_stats: metric_dict = {"metric": metric} check_stats_list.append(metric_dict) scores = check_stats[metric].detach().cpu().numpy() for i, score in enumerate(scores): col_name = self.names[i] metric_dict[col_name] = score check_stats_df = pd.DataFrame(check_stats_list) self.to_csv(check_stats_df, "diagnostics")