
jaxns.internals.random

Module Contents

random_ortho_matrix(key, n, special_orthogonal=False)

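Samples a random orthonormal n×n matrix, i.e. an element of the orthogonal group O(n), which coincides with the Stiefel manifold of n orthonormal vectors in ℝⁿ. Adapted from https://stackoverflow.com/a/38430739.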

Parameters:
  • key – PRNG key.

  • n – Size of the matrix; samples are drawn from the group O(n).

  • special_orthogonal (bool) – If True, restrict samples to the special orthogonal group SO(n), i.e. determinant +1.

Returns: random [n, n] orthogonal matrix with determinant ±1 (+1 if special_orthogonal=True).
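A common way to draw such a matrix is to QR-decompose a matrix of i.i.d. standard normal entries and fix the sign ambiguity of the factorisation by making the diagonal of R positive. The sketch below assumes that technique; `sample_orthogonal` is a hypothetical name, and this is not necessarily how jaxns implements random_ortho_matrix.

```python
import jax
import jax.numpy as jnp


def sample_orthogonal(key, n, special_orthogonal=False):
    """Hypothetical sketch: random n x n orthogonal matrix via QR."""
    # Matrix with i.i.d. standard normal entries.
    a = jax.random.normal(key, (n, n))
    # a = q @ r with q orthogonal and r upper triangular.
    q, r = jnp.linalg.qr(a)
    # QR is unique only up to column signs; rescale column j of q by
    # sign(r[j, j]) so the result does not depend on the QR convention.
    d = jnp.sign(jnp.diag(r))
    q = q * d[None, :]
    if special_orthogonal:
        # det(q) is +1 or -1; flipping the first column when it is -1
        # moves the sample into SO(n).
        q = q.at[:, 0].multiply(jnp.sign(jnp.linalg.det(q)))
    return q
```

Here special_orthogonal is a plain Python bool, so under jax.jit it would need to be marked static.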

resample_indicies(key, log_weights=None, S=None, replace=False, num_total=None)

Get resampling indices according to the given weights, with or without replacement.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • log_weights (Optional[jaxns.internals.types.FloatArray]) – Optional log-weights of the samples.

  • S (Optional[int]) – Optional number of samples to draw. If not given, the effective sample size computed from log_weights is used.

  • replace (bool) – Whether to sample with replacement.

  • num_total (Optional[int]) – Optional total sample size to draw from; must be given if replace=False and log_weights=None.

Returns:

Index array giving the indices at which to resample (e.g. for use with take).

Return type:

jaxns.internals.types.IntArray
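Weighted resampling of this kind can be sketched with jax.random.choice, using the effective sample size ESS = (Σ w)² / Σ w² as the default number of draws. This is a minimal sketch under stated assumptions: `sample_resample_indices` and `effective_sample_size` are hypothetical names, and requiring num_total whenever log_weights is None (and defaulting S to num_total in that case) is an assumption, not documented jaxns behaviour.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def effective_sample_size(log_weights):
    """ESS = (sum w)^2 / sum w^2, evaluated stably in log space."""
    log_w = log_weights - logsumexp(log_weights)  # normalise so sum w = 1
    return jnp.exp(-logsumexp(2.0 * log_w))       # 1 / sum w^2


def sample_resample_indices(key, log_weights=None, S=None, replace=False,
                            num_total=None):
    """Hypothetical sketch of weighted resampling of indices."""
    if log_weights is None:
        # Assumption: uniform weights over num_total items.
        if num_total is None:
            raise ValueError("num_total is required when log_weights is None")
        population, p = num_total, None
        if S is None:
            S = num_total  # ESS of uniform weights equals the population size
    else:
        population = log_weights.shape[0]
        # Normalised probabilities from log-weights.
        p = jnp.exp(log_weights - logsumexp(log_weights))
        if S is None:
            # Round ESS down, but always draw at least one sample.
            S = max(1, int(effective_sample_size(log_weights)))
    # Integer population means sampling from arange(population).
    return jax.random.choice(key, population, shape=(S,), replace=replace, p=p)
```

Because S fixes the output shape, it must be a concrete Python int; computing it from the ESS therefore happens outside jax.jit.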