utils

jaxns.utils

Module Contents

resample(key, samples, log_weights, S=None, replace=False)

Resample the weighted samples into uniformly weighted samples.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • samples (Union[jaxns.internals.types.XType, jaxns.internals.types.UType]) – samples from nested sampling results

  • log_weights (jax.numpy.ndarray) – log-posterior weights

  • S (int) – number of samples to generate. Will use Kish’s estimate of ESS if None.

  • replace (bool) – whether to sample with replacement

Returns:

equally weighted samples

Return type:

jaxns.internals.types.XType
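
The resampling step can be illustrated with a small plain-Python sketch. It is not the jaxns implementation: `kish_ess` and `resample_sketch` are hypothetical names, only sampling with replacement is shown, and the default sample count is assumed to be Kish's estimate rounded down.

```python
import math
import random


def kish_ess(log_weights):
    """Kish's effective sample size: (sum w)^2 / sum(w^2)."""
    m = max(log_weights)
    # Subtract the maximum before exponentiating for numerical stability;
    # the ratio is invariant to this rescaling.
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / sum(x * x for x in w)


def resample_sketch(rng, samples, log_weights, S=None):
    """Draw S equally weighted samples, with replacement."""
    if S is None:
        # Assumption: fall back to Kish's ESS, rounded down, at least 1.
        S = max(1, int(kish_ess(log_weights)))
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return rng.choices(samples, weights=w, k=S)


rng = random.Random(0)
samples = [0.0, 1.0, 2.0, 3.0]
log_weights = [math.log(0.25)] * 4  # four equally weighted samples
ess = kish_ess(log_weights)         # equal weights: ESS equals the sample count
draws = resample_sketch(rng, samples, log_weights)
```

With equal weights the ESS equals the number of samples; as the weights concentrate on fewer points the ESS falls towards 1.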

marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS is static.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • U_samples (jaxns.internals.types.UType) – array of U samples

  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (int) – static effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation over resampled samples

Return type:

_V

marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS can be dynamic.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • U_samples (jaxns.internals.types.UType) – array of U samples

  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (jax.numpy.ndarray) – dynamic effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation of fun over resampled samples

Return type:

_V

marginalise_static(key, samples, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS is static.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • samples (dict) – dict of batched array of nested sampling samples

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (int) – static effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation over resampled samples

Return type:

_V
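
A minimal plain-Python sketch of marginalising over resampled points, under the assumption that the expectation is a simple average of `fun` over `ESS` equally weighted draws. `marginalise_sketch` is a hypothetical name and the code is not taken from jaxns.

```python
import math
import random


def marginalise_sketch(rng, samples, log_weights, ess, fun):
    """Average fun(**sample) over `ess` equally weighted resampled points."""
    n = len(log_weights)
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    idx = rng.choices(range(n), weights=w, k=ess)
    # Each sample is rebuilt as keyword arguments, matching callable(**kwargs).
    values = [
        fun(**{name: batch[i] for name, batch in samples.items()})
        for i in idx
    ]
    return sum(values) / len(values)


# Dict of batched samples: one entry per variable, one element per sample.
samples = {"x": [1.0, 2.0, 3.0], "y": [10.0, 20.0, 30.0]}
# All posterior mass sits on the second sample, so the average is exact.
log_weights = [float("-inf"), 0.0, float("-inf")]
result = marginalise_sketch(
    random.Random(0), samples, log_weights, ess=5, fun=lambda x, y: x + y
)
```

Here `ess` is an ordinary Python integer, which corresponds to the static case: the number of draws is fixed before the computation runs.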

marginalise_dynamic(key, samples, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS can be dynamic.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • samples (dict) – dict of batched array of nested sampling samples

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (jax.numpy.ndarray) – dynamic effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation of fun over resampled samples

Return type:

_V

maximum_a_posteriori_point(results)

Get the maximum a posteriori (MAP) point of a nested sampling result by choosing the sample with the largest L(x) p(x), the product of likelihood and prior.

Parameters:

results (NestedSamplerResults) – Nested sampler result

Returns:

dict of samples at MAP-point.

Return type:

jaxns.internals.types.XType
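
The selection rule can be sketched in plain Python by working with logs, where maximising L(x) p(x) is the same as maximising log L(x) + log p(x). The name `map_index` and the toy arrays are illustrative assumptions, not part of jaxns.

```python
def map_index(log_L, log_prior):
    """Index of the sample maximising log L(x) + log p(x)."""
    log_joint = [l + p for l, p in zip(log_L, log_prior)]
    return max(range(len(log_joint)), key=log_joint.__getitem__)


samples = {"x": [0.1, 0.5, 0.9]}   # dict of batched samples
log_L = [-3.0, -0.5, -2.0]         # log-likelihood of each sample
log_prior = [-1.0, -1.2, -0.1]     # log-prior density of each sample

i = map_index(log_L, log_prior)
# Pick the same index from every variable to get the dict at the MAP point.
map_point = {name: batch[i] for name, batch in samples.items()}
```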

evaluate_map_estimate(results, fun)

Evaluates a function at the MAP sample point of a nested sampling result.

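Parameters:
  • results (NestedSamplerResults) – Nested sampler result

  • fun (callable(**kwargs)) – function to evaluate

Returns: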

estimate at MAP sample point

Return type:

_V

summary(results, with_parametrised=False, f_obj=None)

Gives a summary of the results of a nested sampling run.

Parameters:
  • results (NestedSamplerResults) – Nested sampler result

  • with_parametrised (bool) – whether to include parametrised samples

  • f_obj (Optional[Union[str, TextIO]]) – file-like object to write summary to. If None, prints to stdout.

sample_evidence(key, num_live_points_per_sample, log_L_samples, S=100)

Sample the evidence distribution by stochastically simulating the shrinkage distribution.

Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • num_live_points_per_sample (jaxns.internals.types.IntArray) – the number of live points for each sample

  • log_L_samples (jaxns.internals.types.FloatArray) – the log-L of samples

  • S (int) – The number of samples to produce

Returns:

samples of log(Z)

Return type:

jaxns.internals.types.FloatArray
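
The idea of simulating shrinkage can be sketched with the textbook nested sampling scheme: each step shrinks the enclosed prior volume X by a factor t ~ Beta(n, 1), where n is the number of live points, and Z is accumulated as the sum of L_i (X_{i-1} - X_i). This is an assumption about the general technique, not a copy of the jaxns code; `simulate_log_Z` is a hypothetical name.

```python
import math
import random


def simulate_log_Z(rng, num_live_points_per_sample, log_L_samples, S=100):
    """Return S draws of log(Z) under randomly simulated shrinkage."""
    draws = []
    for _ in range(S):
        log_X_prev = 0.0  # enclosed prior volume starts at X = 1
        terms = []
        for n, log_L in zip(num_live_points_per_sample, log_L_samples):
            t = rng.betavariate(n, 1.0)  # shrinkage factor, t ~ Beta(n, 1)
            # log of L_i * (X_{i-1} - X_i), using X_i = t * X_{i-1}
            terms.append(log_L + log_X_prev + math.log1p(-t))
            log_X_prev += math.log(t)
        m = max(terms)
        # logsumexp over the per-sample contributions
        draws.append(m + math.log(sum(math.exp(v - m) for v in terms)))
    return draws


rng = random.Random(0)
n_live = [5] * 20    # five live points at every step
log_L = [0.0] * 20   # constant likelihood L = 1, so Z = 1 - X_N < 1
draws = simulate_log_Z(rng, n_live, log_L, S=100)
```

The spread of `draws` reflects only the randomness of the shrinkage factors, which is the approximation described in the note above.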

bruteforce_posterior_samples(model, S=60)

Compute the posterior with brute-force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – resolution of the grid

Returns:

samples, and log-weight

Return type:

Tuple[jaxns.internals.types.XType, jax.numpy.ndarray]

bruteforce_evidence(model, S=60)

Compute the evidence with brute-force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – resolution of the grid

Returns:

log(Z)
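
As an illustration of brute-force evidence on a regular grid, the sketch below integrates a one-dimensional toy likelihood against a Uniform(0, 1) prior with the midpoint rule. The function name, the one-dimensional setting, and the midpoint placement of grid points are assumptions made for the example; the jaxns function takes a model instead.

```python
import math


def bruteforce_log_Z(log_likelihood, S=60):
    """Midpoint-rule estimate of log(Z) for a 1-D Uniform(0, 1) prior."""
    # Evaluate the log-likelihood at the centre of each of the S grid cells.
    log_terms = [log_likelihood((k + 0.5) / S) for k in range(S)]
    m = max(log_terms)
    # logsumexp of the terms, then subtract log(S) for the cell width 1/S.
    return m + math.log(sum(math.exp(v - m) for v in log_terms)) - math.log(S)


# Toy likelihood L(u) = 2u, whose exact evidence is Z = 1, so log(Z) = 0.
log_Z = bruteforce_log_Z(lambda u: math.log(2.0 * u))
```

The cost of a grid grows as S to the power of the number of dimensions, so this approach is only practical for low-dimensional checks.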

analytic_posterior_samples(model, S=60)

Compute the evidence with brute-force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – resolution of the grid

Returns:

log(Z)

save_pytree(pytree, save_file)

Saves results of nested sampler in an npz file.

Parameters:
  • pytree (NamedTuple) – Nested sampler result

  • save_file (str) – filename

save_results(results, save_file)

Saves results of nested sampler in an npz file.

Parameters:
  • results (NestedSamplerResults) – Nested sampler result

  • save_file (str) – filename

load_pytree(save_file)

Loads saved nested sampler results from an npz file.

Parameters:

save_file (str) – filename

Returns:

NestedSamplerResults

load_results(save_file)

Loads saved nested sampler results from an npz file.

Parameters:

save_file (str) – filename

Returns:

NestedSamplerResults

Return type:

jaxns.internals.types.NestedSamplerResults