diff --git a/pymc_extras/inference/pathfinder/importance_sampling.py b/pymc_extras/inference/pathfinder/importance_sampling.py index 3b4a0ee7..8d04c077 100644 --- a/pymc_extras/inference/pathfinder/importance_sampling.py +++ b/pymc_extras/inference/pathfinder/importance_sampling.py @@ -20,7 +20,7 @@ class ImportanceSamplingResult: samples: NDArray pareto_k: float | None = None warnings: list[str] = field(default_factory=list) - method: str = "none" + method: str | None = "psis" def importance_sampling( @@ -28,7 +28,7 @@ def importance_sampling( logP: NDArray, logQ: NDArray, num_draws: int, - method: Literal["psis", "psir", "identity", "none"] | None, + method: Literal["psis", "psir", "identity"] | None, random_seed: int | None = None, ) -> ImportanceSamplingResult: """Pareto Smoothed Importance Resampling (PSIR) @@ -44,8 +44,15 @@ def importance_sampling( log probability values of proposal distribution, shape (L, M) num_draws : int number of draws to return where num_draws <= samples.shape[0] - method : str, optional - importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths. + method : str, None, optional + Method to apply sampling based on log importance weights (logP - logQ). + Options are: + "psis" : Pareto Smoothed Importance Sampling (default) + Recommended for more stable results. + "psir" : Pareto Smoothed Importance Resampling + Less stable than PSIS. + "identity" : Applies log importance weights directly without resampling. + None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. 
Other methods return samples of size (num_draws, N). random_seed : int | None Returns @@ -71,11 +78,11 @@ def importance_sampling( warnings = [] num_paths, _, N = samples.shape - if method == "none": + if method is None: warnings.append( "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability." ) - return ImportanceSamplingResult(samples=samples, warnings=warnings) + return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method) else: samples = samples.reshape(-1, N) logP = logP.ravel() @@ -91,17 +98,16 @@ def importance_sampling( _warnings.filterwarnings( "ignore", category=RuntimeWarning, message="overflow encountered in exp" ) - if method == "psis": - replace = False - logiw, pareto_k = az.psislw(logiw) - elif method == "psir": - replace = True - logiw, pareto_k = az.psislw(logiw) - elif method == "identity": - replace = False - pareto_k = None - else: - raise ValueError(f"Invalid importance sampling method: {method}") + match method: + case "psis": + replace = False + logiw, pareto_k = az.psislw(logiw) + case "psir": + replace = True + logiw, pareto_k = az.psislw(logiw) + case "identity": + replace = False + pareto_k = None # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI. # Pareto k may not be a good diagnostic for Pathfinder. 
diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 8fc6f799..dfe5fc6a 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -156,7 +156,7 @@ def convert_flat_trace_to_idata( postprocessing_backend: Literal["cpu", "gpu"] = "cpu", inference_backend: Literal["pymc", "blackjax"] = "pymc", model: Model | None = None, - importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis", + importance_sampling: Literal["psis", "psir", "identity"] | None = "psis", ) -> az.InferenceData: """convert flattened samples to arviz InferenceData format. @@ -181,7 +181,7 @@ def convert_flat_trace_to_idata( arviz inference data object """ - if importance_sampling == "none": + if importance_sampling is None: # samples.ndim == 3 in this case, otherwise ndim == 2 num_paths, num_pdraws, N = samples.shape samples = samples.reshape(-1, N) @@ -220,7 +220,7 @@ def convert_flat_trace_to_idata( fn.trust_input = True result = fn(*list(trace.values())) - if importance_sampling == "none": + if importance_sampling is None: result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result] elif inference_backend == "blackjax": @@ -1189,7 +1189,7 @@ class MultiPathfinderResult: elbo_argmax: NDArray | None = None lbfgs_status: Counter = field(default_factory=Counter) path_status: Counter = field(default_factory=Counter) - importance_sampling: str = "none" + importance_sampling: str | None = "psis" warnings: list[str] = field(default_factory=list) pareto_k: float | None = None @@ -1258,7 +1258,7 @@ def with_warnings(self, warnings: list[str]) -> Self: def with_importance_sampling( self, num_draws: int, - method: Literal["psis", "psir", "identity", "none"] | None, + method: Literal["psis", "psir", "identity"] | None, random_seed: int | None = None, ) -> Self: """perform importance sampling""" @@ -1424,7 +1424,7 @@ def multipath_pathfinder( num_elbo_draws: 
int, jitter: float, epsilon: float, - importance_sampling: Literal["psis", "psir", "identity", "none"] | None, + importance_sampling: Literal["psis", "psir", "identity"] | None, progressbar: bool, concurrent: Literal["thread", "process"] | None, random_seed: RandomSeed, @@ -1460,8 +1460,14 @@ def multipath_pathfinder( Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value. epsilon: float value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8). - importance_sampling : str, optional - importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N). + importance_sampling : str, None, optional + Method to apply sampling based on log importance weights (logP - logQ). + "psis" : Pareto Smoothed Importance Sampling (default) + Recommended for more stable results. + "psir" : Pareto Smoothed Importance Resampling + Less stable than PSIS. + "identity" : Applies log importance weights directly without resampling. + None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N). 
progressbar : bool, optional Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time. random_seed : RandomSeed, optional @@ -1483,12 +1489,6 @@ def multipath_pathfinder( The result containing samples and other information from the Multi-Path Pathfinder algorithm. """ - valid_importance_sampling = ["psis", "psir", "identity", "none", None] - if importance_sampling is None: - importance_sampling = "none" - if importance_sampling.lower() not in valid_importance_sampling: - raise ValueError(f"Invalid importance sampling method: {importance_sampling}") - *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1) pathfinder_config = PathfinderConfig( @@ -1622,7 +1622,7 @@ def fit_pathfinder( num_elbo_draws: int = 10, # K jitter: float = 2.0, epsilon: float = 1e-8, - importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis", + importance_sampling: Literal["psis", "psir", "identity"] | None = "psis", progressbar: bool = True, concurrent: Literal["thread", "process"] | None = None, random_seed: RandomSeed | None = None, @@ -1662,8 +1662,15 @@ def fit_pathfinder( Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value. epsilon: float value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8). - importance_sampling : str, optional - importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). 
identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N). + importance_sampling : str, None, optional + Method to apply sampling based on log importance weights (logP - logQ). + Options are: + "psis" : Pareto Smoothed Importance Sampling (default) + Recommended for more stable results. + "psir" : Pareto Smoothed Importance Resampling + Less stable than PSIS. + "identity" : Applies log importance weights directly without resampling. + None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N). progressbar : bool, optional Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time. random_seed : RandomSeed, optional @@ -1690,6 +1697,15 @@ def fit_pathfinder( """ model = modelcontext(model) + + valid_importance_sampling = {"psis", "psir", "identity", None} + + if importance_sampling is not None: + importance_sampling = importance_sampling.lower() + + if importance_sampling not in valid_importance_sampling: + raise ValueError(f"Invalid importance sampling method: {importance_sampling}") + N = DictToArrayBijection.map(model.initial_point()).data.shape[0] if maxcor is None: diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index 1d5b2a9e..af9213ff 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -44,8 +44,8 @@ def reference_idata(): with model: idata = pmx.fit( method="pathfinder", - num_paths=50, - jitter=10.0, + num_paths=10, + jitter=12.0, random_seed=41, inference_backend="pymc", ) @@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata): with model: idata = pmx.fit( method="pathfinder", - num_paths=50, - 
jitter=10.0, + num_paths=10, + jitter=12.0, random_seed=41, inference_backend=inference_backend, ) else: idata = reference_idata - np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6) - np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5) + np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95) + np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35) assert idata.posterior["mu"].shape == (1, 1000) assert idata.posterior["tau"].shape == (1, 1000) @@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent): with model: idata_conc = pmx.fit( method="pathfinder", - num_paths=50, - jitter=10.0, + num_paths=10, + jitter=12.0, random_seed=41, inference_backend="pymc", concurrent=concurrent, @@ -108,7 +108,7 @@ def test_seed(reference_idata): with model: idata_41 = pmx.fit( method="pathfinder", - num_paths=50, + num_paths=4, jitter=10.0, random_seed=41, inference_backend="pymc", @@ -116,7 +116,7 @@ def test_seed(reference_idata): idata_123 = pmx.fit( method="pathfinder", - num_paths=50, + num_paths=4, jitter=10.0, random_seed=123, inference_backend="pymc", @@ -171,3 +171,33 @@ def test_bfgs_sample(): assert gamma.eval().shape == (L, 2 * J, 2 * J) assert phi.eval().shape == (L, num_samples, N) assert logq.eval().shape == (L, num_samples) + + +@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None]) +def test_pathfinder_importance_sampling(importance_sampling): + model = eight_schools_model() + + num_paths = 4 + num_draws_per_path = 300 + num_draws = 750 + + with model: + idata = pmx.fit( + method="pathfinder", + num_paths=num_paths, + num_draws_per_path=num_draws_per_path, + num_draws=num_draws, + maxiter=5, + random_seed=41, + inference_backend="pymc", + importance_sampling=importance_sampling, + ) + + if importance_sampling is None: + assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path) + assert idata.posterior["tau"].shape == 
(num_paths, num_draws_per_path) + assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8) + else: + assert idata.posterior["mu"].shape == (1, num_draws) + assert idata.posterior["tau"].shape == (1, num_draws) + assert idata.posterior["theta"].shape == (1, num_draws, 8)