utils
jaxns.utils
Module Contents
- resample(key, samples, log_weights, S=None, replace=False)[source]
Resample the weighted samples into uniformly weighted samples.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
samples (Union[jaxns.internals.types.XType, jaxns.internals.types.UType]) – samples from nested sampled results
log_weights (jax.numpy.ndarray) – log-posterior weight
S (int) – number of samples to generate. Will use Kish’s estimate of ESS if None.
replace (bool) – whether to sample with replacement
- Returns:
equally weighted samples
- Return type:
jaxns.internals.types.XType
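The idea behind resampling and Kish's ESS estimate can be illustrated with a minimal plain-Python sketch. This is not the jaxns implementation: `kish_ess` and `resample_sketch` are hypothetical names, a `random.Random` instance stands in for the PRNG key, and the sketch always draws with replacement (unlike the documented default `replace=False`).

```python
import math
import random

def kish_ess(log_weights):
    """Kish's effective sample size: (sum w)^2 / sum(w^2)."""
    m = max(log_weights)
    # Subtracting the max avoids overflow; the scale cancels in the ratio.
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / sum(x * x for x in w)

def resample_sketch(rng, samples, log_weights, S=None):
    """Draw S equally weighted samples (with replacement) from weighted ones."""
    if S is None:
        S = max(1, int(kish_ess(log_weights)))
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    idx = rng.choices(range(len(w)), weights=w, k=S)
    # Apply the same indices to every named variable so samples stay aligned.
    return {name: [values[i] for i in idx] for name, values in samples.items()}

rng = random.Random(0)
samples = {"x": [0.0, 1.0, 2.0, 3.0]}
log_weights = [math.log(0.25)] * 4  # uniform weights, so the ESS equals 4
equal = resample_sketch(rng, samples, log_weights)
```

With uniform weights the ESS equals the number of samples; as the weights concentrate on fewer points, the ESS falls towards 1.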
- marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS is static.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
U_samples (jaxns.internals.types.UType) – array of U samples
model (jaxns.framework.bases.BaseAbstractModel) – model
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (int) – static effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation over resampled samples
- Return type:
_V
- marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS can be dynamic.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
U_samples (jaxns.internals.types.UType) – array of U samples
model (jaxns.framework.bases.BaseAbstractModel) – model
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (jax.numpy.ndarray) – dynamic effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation of fun over resampled samples
- Return type:
_V
- marginalise_static(key, samples, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS is static.
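- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
samples (dict) – dict of batched array of nested sampling samples
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (int) – static effective sample size
fun (callable(**kwargs)) – function to marginalise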
- Returns:
expectation over resampled samples
- Return type:
_V
- marginalise_dynamic(key, samples, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS can be dynamic.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
samples (dict) – dict of batched array of nested sampling samples
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (jax.numpy.ndarray) – dynamic effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation of fun over resampled samples
- Return type:
_V
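Marginalising a function over weighted samples amounts to resampling according to the weights and averaging the function over the resampled points. The following is a minimal plain-Python sketch of that idea, not the jaxns implementation: `marginalise_sketch` is a hypothetical name, and a `random.Random` instance stands in for the PRNG key.

```python
import math
import random

def marginalise_sketch(rng, samples, log_weights, ess, fun):
    """Estimate E[fun] by averaging fun over `ess` resampled points."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    idx = rng.choices(range(len(w)), weights=w, k=ess)
    total = 0.0
    for i in idx:
        # Each variable is passed to fun as a keyword argument.
        kwargs = {name: values[i] for name, values in samples.items()}
        total += fun(**kwargs)
    return total / ess

samples = {"x": [1.0, 2.0, 3.0]}
log_weights = [0.0, float("-inf"), float("-inf")]  # all mass on x = 1.0
estimate = marginalise_sketch(
    random.Random(1), samples, log_weights, 10, lambda x: x ** 2
)
```

Because all the weight sits on `x = 1.0` in this toy input, every resampled point is that sample and the estimate is exactly `1.0`.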
- maximum_a_posteriori_point(results)[source]
Get the maximum a posteriori (MAP) point of a nested sampling result, i.e. the sample with the largest L(x) p(x) (likelihood times prior).
- Parameters:
results (NestedSamplerResults) – Nested sampler result
- Returns:
dict of samples at MAP-point.
- Return type:
jaxns.internals.types.XType
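Selecting the point with the largest L(x) p(x) is equivalent to maximising log L(x) + log p(x) over the samples. Below is a minimal sketch under the assumption that per-sample log-likelihood and log-prior values are available as plain lists. `map_point_sketch`, `log_likelihood` and `log_prior` are hypothetical names, not fields of the jaxns results object.

```python
def map_point_sketch(samples, log_likelihood, log_prior):
    """Return the sample that maximises log L(x) + log p(x)."""
    scores = [ll + lp for ll, lp in zip(log_likelihood, log_prior)]
    best = max(range(len(scores)), key=scores.__getitem__)
    return {name: values[best] for name, values in samples.items()}

samples = {"mu": [0.1, 0.5, 0.9], "sigma": [1.0, 2.0, 3.0]}
log_likelihood = [-3.0, -1.0, -2.0]
log_prior = [-1.0, -1.5, -0.2]
map_point = map_point_sketch(samples, log_likelihood, log_prior)
```

In this toy input the second sample has the highest likelihood, but the third has the highest likelihood-times-prior, so the third is returned.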
- evaluate_map_estimate(results, fun)[source]
Evaluates a function at the MAP sample point of a nested sampling result.
- Parameters:
results (jaxns.internals.types.NestedSamplerResults) – results from run
fun (callable(**kwargs)) – function to evaluate at the MAP point
- Returns:
estimate at MAP sample point
- Return type:
_V
- summary(results, with_parametrised=False, f_obj=None)[source]
Gives a summary of the results of a nested sampling run.
- Parameters:
results (NestedSamplerResults) – Nested sampler result
with_parametrised (bool) – whether to include parametrised samples
f_obj (Optional[Union[str, TextIO]]) – file-like object to write summary to. If None, prints to stdout.
- sample_evidence(key, num_live_points_per_sample, log_L_samples, S=100)[source]
Sample the evidence distribution by stochastically simulating the shrinkage distribution.
Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
num_live_points_per_sample (jaxns.internals.types.IntArray) – the number of live points for each sample
log_L_samples (jaxns.internals.types.FloatArray) – the log-L of samples
S (int) – The number of samples to produce
- Returns:
samples of log(Z)
- Return type:
jaxns.internals.types.FloatArray
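One common way to simulate shrinkage is to draw, for each sample with n live points, a factor t ~ Beta(n, 1) by which the enclosed prior volume X contracts, and to accumulate Z ≈ Σ L_i (X_{i-1} − X_i). The sketch below follows that textbook scheme in plain Python. It is an assumption that this resembles the library's approach; `sample_log_evidence_sketch` is a hypothetical name, and a `random.Random` instance stands in for the PRNG key.

```python
import math
import random

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sample_log_evidence_sketch(rng, num_live_points_per_sample, log_L_samples, S=100):
    """Draw S approximate samples of log(Z) by simulating volume shrinkage."""
    draws = []
    for _ in range(S):
        log_X = 0.0  # enclosed prior volume starts at X = 1
        terms = []
        for n, log_L in zip(num_live_points_per_sample, log_L_samples):
            u = rng.random()
            while u == 0.0:  # avoid log(0)
                u = rng.random()
            # t ~ Beta(n, 1) can be sampled as u ** (1 / n).
            log_t = math.log(u) / n
            # Shell volume: X_prev - X_new = X_prev * (1 - t).
            log_dX = log_X + math.log(-math.expm1(log_t))
            terms.append(log_L + log_dX)
            log_X += log_t
        draws.append(logsumexp(terms))
    return draws

rng = random.Random(0)
# Constant likelihood L = 1: each draw of Z is 1 minus the final volume.
log_Z = sample_log_evidence_sketch(rng, [5] * 50, [0.0] * 50, S=20)
```

The spread of the returned values reflects only the randomness of the shrinkage factors, which is why the samples are approximate, as the note above explains.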
- bruteforce_posterior_samples(model, S=60)[source]
Compute the posterior with brute-force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
samples, and log-weight
- Return type:
Tuple[jaxns.internals.types.XType, jax.numpy.ndarray]
- bruteforce_evidence(model, S=60)[source]
Compute the evidence with brute-force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
log(Z)
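To illustrate brute-force evidence computation over a regular grid, here is a minimal one-dimensional sketch assuming a uniform prior on [0, 1], so that Z = ∫ L(x) dx. It is not the jaxns implementation: `bruteforce_log_evidence_sketch` is a hypothetical name, and it takes a log-likelihood function rather than a model object.

```python
import math

def bruteforce_log_evidence_sketch(log_likelihood, S=60):
    """Approximate log(Z) for a uniform prior on [0, 1] using S midpoint cells."""
    width = 1.0 / S
    # Each cell contributes L(x_mid) * width; work in log space for stability.
    log_terms = [
        log_likelihood((i + 0.5) * width) + math.log(width) for i in range(S)
    ]
    m = max(log_terms)
    return m + math.log(sum(math.exp(t - m) for t in log_terms))

# L(x) = 2x integrates to 1 over [0, 1], so log(Z) should be close to 0.
log_Z = bruteforce_log_evidence_sketch(lambda x: math.log(2.0 * x), S=60)
```

The cost of a regular grid grows as S to the power of the number of dimensions, so this kind of check is practical only for low-dimensional models.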
- analytic_posterior_samples(model, S=60)[source]
Compute the evidence with brute-force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
log(Z)
- save_pytree(pytree, save_file)[source]
Saves results of nested sampler in an npz file.
- Parameters:
pytree (NamedTuple) – Nested sampler result
save_file (str) – filename
- save_results(results, save_file)[source]
Saves results of nested sampler in an npz file.
- Parameters:
results (NestedSamplerResults) – Nested sampler result
save_file (str) – filename
- load_pytree(save_file)[source]
Loads saved nested sampler results from an npz file.
- Parameters:
save_file (str) – filename
- Returns:
NestedSamplerResults