Skip to content

Fix: importance_sampling=None produces error #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions pymc_extras/inference/pathfinder/importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ class ImportanceSamplingResult:
samples: NDArray
pareto_k: float | None = None
warnings: list[str] = field(default_factory=list)
method: str = "none"
method: str = "psis"


def importance_sampling(
samples: NDArray,
logP: NDArray,
logQ: NDArray,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"] | None,
method: Literal["psis", "psir", "identity"] | None,
random_seed: int | None = None,
) -> ImportanceSamplingResult:
"""Pareto Smoothed Importance Resampling (PSIR)
Expand All @@ -44,8 +44,15 @@ def importance_sampling(
log probability values of proposal distribution, shape (L, M)
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
method : str, optional
importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
method : str, None, optional
Method to apply sampling based on log importance weights (logP - logQ).
Options are:
"psis" : Pareto Smoothed Importance Sampling (default)
Recommended for more stable results.
"psir" : Pareto Smoothed Importance Resampling
Less stable than PSIS.
"identity" : Applies log importance weights directly without resampling.
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
random_seed : int | None

Returns
Expand All @@ -71,11 +78,11 @@ def importance_sampling(
warnings = []
num_paths, _, N = samples.shape

if method == "none":
if method is None:
warnings.append(
"Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
)
return ImportanceSamplingResult(samples=samples, warnings=warnings)
return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
else:
samples = samples.reshape(-1, N)
logP = logP.ravel()
Expand All @@ -91,17 +98,16 @@ def importance_sampling(
_warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="overflow encountered in exp"
)
if method == "psis":
replace = False
logiw, pareto_k = az.psislw(logiw)
elif method == "psir":
replace = True
logiw, pareto_k = az.psislw(logiw)
elif method == "identity":
replace = False
pareto_k = None
else:
raise ValueError(f"Invalid importance sampling method: {method}")
match method:
case "psis":
replace = False
logiw, pareto_k = az.psislw(logiw)
case "psir":
replace = True
logiw, pareto_k = az.psislw(logiw)
case "identity":
replace = False
pareto_k = None

# NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
# Pareto k may not be a good diagnostic for Pathfinder.
Expand Down
50 changes: 33 additions & 17 deletions pymc_extras/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def convert_flat_trace_to_idata(
postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
inference_backend: Literal["pymc", "blackjax"] = "pymc",
model: Model | None = None,
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
) -> az.InferenceData:
"""convert flattened samples to arviz InferenceData format.

Expand All @@ -181,7 +181,7 @@ def convert_flat_trace_to_idata(
arviz inference data object
"""

if importance_sampling == "none":
if importance_sampling is None:
# samples.ndim == 3 in this case, otherwise ndim == 2
num_paths, num_pdraws, N = samples.shape
samples = samples.reshape(-1, N)
Expand Down Expand Up @@ -220,7 +220,7 @@ def convert_flat_trace_to_idata(
fn.trust_input = True
result = fn(*list(trace.values()))

if importance_sampling == "none":
if importance_sampling is None:
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

elif inference_backend == "blackjax":
Expand Down Expand Up @@ -1189,7 +1189,7 @@ class MultiPathfinderResult:
elbo_argmax: NDArray | None = None
lbfgs_status: Counter = field(default_factory=Counter)
path_status: Counter = field(default_factory=Counter)
importance_sampling: str = "none"
importance_sampling: str | None = "psis"
warnings: list[str] = field(default_factory=list)
pareto_k: float | None = None

Expand Down Expand Up @@ -1258,7 +1258,7 @@ def with_warnings(self, warnings: list[str]) -> Self:
def with_importance_sampling(
self,
num_draws: int,
method: Literal["psis", "psir", "identity", "none"] | None,
method: Literal["psis", "psir", "identity"] | None,
random_seed: int | None = None,
) -> Self:
"""perform importance sampling"""
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def multipath_pathfinder(
num_elbo_draws: int,
jitter: float,
epsilon: float,
importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
importance_sampling: Literal["psis", "psir", "identity"] | None,
progressbar: bool,
concurrent: Literal["thread", "process"] | None,
random_seed: RandomSeed,
Expand Down Expand Up @@ -1460,8 +1460,14 @@ def multipath_pathfinder(
Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
epsilon: float
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
importance_sampling : str, optional
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
importance_sampling : str, None, optional
Method to apply sampling based on log importance weights (logP - logQ).
"psis" : Pareto Smoothed Importance Sampling (default)
Recommended for more stable results.
"psir" : Pareto Smoothed Importance Resampling
Less stable than PSIS.
"identity" : Applies log importance weights directly without resampling.
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
progressbar : bool, optional
Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
random_seed : RandomSeed, optional
Expand All @@ -1483,12 +1489,6 @@ def multipath_pathfinder(
The result containing samples and other information from the Multi-Path Pathfinder algorithm.
"""

valid_importance_sampling = ["psis", "psir", "identity", "none", None]
if importance_sampling is None:
importance_sampling = "none"
if importance_sampling.lower() not in valid_importance_sampling:
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")

*path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)

pathfinder_config = PathfinderConfig(
Expand Down Expand Up @@ -1622,7 +1622,7 @@ def fit_pathfinder(
num_elbo_draws: int = 10, # K
jitter: float = 2.0,
epsilon: float = 1e-8,
importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
progressbar: bool = True,
concurrent: Literal["thread", "process"] | None = None,
random_seed: RandomSeed | None = None,
Expand Down Expand Up @@ -1662,8 +1662,15 @@ def fit_pathfinder(
Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
epsilon: float
value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
importance_sampling : str, optional
importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
importance_sampling : str, None, optional
Method to apply sampling based on log importance weights (logP - logQ).
Options are:
"psis" : Pareto Smoothed Importance Sampling (default)
Recommended for more stable results.
"psir" : Pareto Smoothed Importance Resampling
Less stable than PSIS.
"identity" : Applies log importance weights directly without resampling.
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
progressbar : bool, optional
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
random_seed : RandomSeed, optional
Expand All @@ -1690,6 +1697,15 @@ def fit_pathfinder(
"""

model = modelcontext(model)

valid_importance_sampling = {"psis", "psir", "identity", None}

if importance_sampling is not None:
importance_sampling = importance_sampling.lower()

if importance_sampling not in valid_importance_sampling:
raise ValueError(f"Invalid importance sampling method: {importance_sampling}")

N = DictToArrayBijection.map(model.initial_point()).data.shape[0]

if maxcor is None:
Expand Down
50 changes: 40 additions & 10 deletions tests/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def reference_idata():
with model:
idata = pmx.fit(
method="pathfinder",
num_paths=50,
jitter=10.0,
num_paths=10,
jitter=12.0,
random_seed=41,
inference_backend="pymc",
)
Expand All @@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata):
with model:
idata = pmx.fit(
method="pathfinder",
num_paths=50,
jitter=10.0,
num_paths=10,
jitter=12.0,
random_seed=41,
inference_backend=inference_backend,
)
else:
idata = reference_idata
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)

assert idata.posterior["mu"].shape == (1, 1000)
assert idata.posterior["tau"].shape == (1, 1000)
Expand All @@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent):
with model:
idata_conc = pmx.fit(
method="pathfinder",
num_paths=50,
jitter=10.0,
num_paths=10,
jitter=12.0,
random_seed=41,
inference_backend="pymc",
concurrent=concurrent,
Expand All @@ -108,15 +108,15 @@ def test_seed(reference_idata):
with model:
idata_41 = pmx.fit(
method="pathfinder",
num_paths=50,
num_paths=4,
jitter=10.0,
random_seed=41,
inference_backend="pymc",
)

idata_123 = pmx.fit(
method="pathfinder",
num_paths=50,
num_paths=4,
jitter=10.0,
random_seed=123,
inference_backend="pymc",
Expand Down Expand Up @@ -171,3 +171,33 @@ def test_bfgs_sample():
assert gamma.eval().shape == (L, 2 * J, 2 * J)
assert phi.eval().shape == (L, num_samples, N)
assert logq.eval().shape == (L, num_samples)


@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
def test_pathfinder_importance_sampling(importance_sampling):
    """Fit Pathfinder on the eight-schools model once per importance-sampling
    option and verify the posterior sample shapes.

    With ``importance_sampling=None`` the raw per-path draws are kept, so the
    posterior keeps shape (num_paths, num_draws_per_path); every other method
    resamples down to a single chain of ``num_draws``.
    """
    model = eight_schools_model()

    num_paths, num_draws_per_path, num_draws = 4, 300, 750

    with model:
        idata = pmx.fit(
            method="pathfinder",
            num_paths=num_paths,
            num_draws_per_path=num_draws_per_path,
            num_draws=num_draws,
            maxiter=5,
            random_seed=41,
            inference_backend="pymc",
            importance_sampling=importance_sampling,
        )

    # None preserves the per-path structure; the other methods collapse the
    # paths into one chain of resampled draws.
    if importance_sampling is None:
        leading = (num_paths, num_draws_per_path)
    else:
        leading = (1, num_draws)

    assert idata.posterior["mu"].shape == leading
    assert idata.posterior["tau"].shape == leading
    assert idata.posterior["theta"].shape == (*leading, 8)