jaxns
User Guide
Installation
Change Log
API Reference
jaxns
Examples
Inference of Jones scalars observables (noisy angular quantities)
Lennard-Jones Potentials for modelling phase transitions in materials
Constant Likelihood
Dual Moons likelihood
Efficient parameter estimation
First the normal nested sampler (parameter_estimation=False)
Now with parameter estimation enabled
Egg-box Likelihood with Uniform Prior
Poisson likelihood and Gamma prior
Gaussian processes with outliers
Thin Gaussian Shells with Uniform Prior
Using JAXNS to globally optimise Neural Networks
Gradient Guided
Multivariate Normal Likelihood with Multivariate Normal Prior
OU process
Self-Exciting process (Hawkes process)
jaxns
Index
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
J
|
K
|
L
|
M
|
N
|
O
|
P
|
Q
|
R
|
S
|
T
|
U
|
V
|
W
|
X
|
Y
_
__add__() (LogSpace method)
(PyTree method)
__and__() (TerminationCondition method)
,
[1]
,
[2]
,
[3]
,
[4]
__call__() (GlobalOptimisation method)
,
[1]
(InterpolatedArray method)
(Model method)
,
[1]
,
[2]
(NestedSampler method)
,
[1]
__contains__() (ScopedDict method)
__ge__() (LogSpace method)
__getitem__() (LogSpace method)
(ScopedDict method)
__gt__() (LogSpace method)
__hash__() (AbstractModel method)
(Model method)
,
[1]
,
[2]
__iter__() (ScopedDict method)
__le__() (LogSpace method)
__len__() (ScopedDict method)
__lt__() (LogSpace method)
__mul__() (LogSpace method)
(PyTree method)
__neg__() (LogSpace method)
(PyTree method)
__or__() (TerminationCondition method)
,
[1]
,
[2]
,
[3]
,
[4]
__post_init__() (EvidenceMaximisation method)
,
[1]
(GlobalOptimisation method)
,
[1]
(InterpolatedArray method)
(MultiDimSliceSampler method)
,
[1]
(MultiEllipsoidalSampler method)
,
[1]
(NestedSampler method)
,
[1]
(ShardedStaticNestedSampler method)
,
[1]
,
[2]
,
[3]
,
[4]
(SimpleGlobalOptimisation method)
,
[1]
(UniDimSliceSampler method)
,
[1]
(UniformSampler method)
,
[1]
__pow__() (LogSpace method)
(PyTree method)
__repr__() (LogSpace method)
(Model method)
,
[1]
,
[2]
(ScopedDict method)
(WrappedTFPDistribution method)
__setitem__() (ScopedDict method)
__sub__() (LogSpace method)
(PyTree method)
__truediv__() (LogSpace method)
(PyTree method)
A
abs() (LogSpace method)
absolute_spread (GlobalOptimisationResults attribute)
,
[1]
(GlobalOptimisationState attribute)
,
[1]
AbstractDistribution (class in jaxns.framework.abc)
AbstractModel (class in jaxns.framework.abc)
AbstractNestedSampler (class in jaxns.nested_samplers.abc)
AbstractPrior (class in jaxns.framework.abc)
AbstractSampler (class in jaxns.samplers.abc)
accepted (NewtonDiagnostic attribute)
adaptive_shrink (UniDimSliceSampler attribute)
,
[1]
add_chunk_dim() (in module jaxns.internals.maps)
analytic_posterior_samples() (in module jaxns)
(in module jaxns.utils)
apply_interp() (in module jaxns.internals.interp_utils)
argmax() (LogSpace method)
atol (GlobalOptimisationTerminationCondition attribute)
,
[1]
(TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
axis (InterpolatedArray attribute)
B
BaseAbstractMarkovSampler (class in jaxns.samplers.bases)
BaseAbstractRejectionSampler (class in jaxns.samplers.bases)
batch_size (EvidenceMaximisation attribute)
,
[1]
Bernoulli (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
Beta (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
block_until_ready() (in module jaxns.internals.maps)
BoolArray (in module jaxns.internals.types)
broadcast_dtypes() (in module jaxns.internals.shapes)
broadcast_shapes() (in module jaxns.internals.shapes)
bruteforce_evidence() (in module jaxns)
(in module jaxns.utils)
bruteforce_posterior_samples() (in module jaxns)
(in module jaxns.utils)
build_hvp() (in module jaxns.experimental.solvers.ad_utils)
BUX (in module jaxns.internals.maps)
C
c (NestedSampler attribute)
,
[1]
cast_to_count() (Policy method)
cast_to_index() (Policy method)
cast_to_measure() (Policy method)
Categorical (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
cdf_diff (TruncationWrapper attribute)
,
[1]
,
[2]
cdf_low (TruncationWrapper attribute)
,
[1]
,
[2]
cg_iters (NewtonDiagnostic attribute)
cg_solve() (in module jaxns.experimental.solvers.cg)
CGDiagnostics (class in jaxns.experimental.solvers.cg)
chunked_pmap() (in module jaxns.internals.maps)
chunked_vmap() (in module jaxns.internals.maps)
cluster_id (MultEllipsoidState attribute)
complex_type (in module jaxns.internals.mixed_precision)
compute_enclosed_prior_volume() (in module jaxns.internals.shrinkage_statistics)
compute_evidence_stats() (in module jaxns.internals.shrinkage_statistics)
compute_num_live_points_from_unit_threads() (in module jaxns.internals.tree_structure)
concatenate() (LogSpace method)
concatenate_sample_trees() (in module jaxns.internals.tree_structure)
convert_external_params() (in module jaxns.framework.context)
convert_to_array() (in module jaxns.internals.shapes)
convert_to_real() (in module jaxns.experimental.solvers.gauss_newton_cg)
count_crossed_edges() (in module jaxns.internals.tree_structure)
count_crossed_edges_less_fast() (in module jaxns.internals.tree_structure)
count_dtype (Policy attribute)
count_intervals_naive() (in module jaxns.internals.tree_structure)
count_old() (in module jaxns.internals.tree_structure)
create_init_evidence_calc() (in module jaxns.internals.shrinkage_statistics)
create_init_state() (in module jaxns.nested_samplers.common.initialisation)
create_init_termination_register() (in module jaxns.nested_samplers.common.initialisation)
create_mesh() (in module jaxns.internals.maps)
CT (in module jaxns.experimental.solvers.gauss_newton_cg)
cumprod() (LogSpace method)
cumsum() (LogSpace method)
cumulative_logsumexp() (in module jaxns.internals.log_semiring)
cumulative_op_dynamic() (in module jaxns.internals.cumulative_ops)
cumulative_op_static() (in module jaxns.internals.cumulative_ops)
D
damping (NewtonDiagnostic attribute)
ddelta_x_norm (NewtonDiagnostic attribute)
DefaultGlobalOptimisation (in module jaxns.experimental)
(in module jaxns.experimental.public)
delta_f_actual (NewtonDiagnostic attribute)
delta_f_pred (NewtonDiagnostic attribute)
delta_x_norm (NewtonDiagnostic attribute)
density_estimation() (in module jaxns.internals.stats)
deprecated() (in module jaxns.warnings)
depth (MultiEllipsoidalSampler attribute)
,
[1]
deserialise_jax_ndarray() (in module jaxns.internals.namedtuple_utils)
deserialise_namedtuple() (in module jaxns.internals.namedtuple_utils)
deserialise_ndarray() (in module jaxns.internals.namedtuple_utils)
determine_termination() (in module jaxns.nested_samplers.common.termination)
devices (GlobalOptimisation attribute)
,
[1]
(NestedSampler attribute)
,
[1]
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
(SimpleGlobalOptimisation attribute)
,
[1]
dict (ScopedDict attribute)
diff() (LogSpace method)
difficult_model (NestedSampler attribute)
,
[1]
dist (Bernoulli attribute)
,
[1]
,
[2]
(Categorical attribute)
,
[1]
,
[2]
(Poisson attribute)
,
[1]
,
[2]
(Prior property)
,
[1]
,
[2]
dist_chain (WrappedTFPDistribution attribute)
dlogZ (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
DomainType (in module jaxns.experimental.solvers.cg)
(in module jaxns.experimental.solvers.gauss_newton_cg)
draw_uniform_samples() (in module jaxns.nested_samplers.common.uniform_sample)
dtype (LogSpace property)
E
e_step() (EvidenceMaximisation method)
,
[1]
(in module jaxns.samplers.multi_ellipsoid.em_gmm)
effective_sample_size_kish() (in module jaxns.internals.stats)
efficiency_threshold (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
ellipsoid_clustering() (in module jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils)
em_gmm() (in module jaxns.samplers.multi_ellipsoid.em_gmm)
Empirical (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
EphemeralState (class in jaxns.samplers.abc)
ESS (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
ess (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
evaluate_map_estimate() (in module jaxns)
(in module jaxns.utils)
evidence_uncert (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
EvidenceMaximisation (class in jaxns.experimental)
(class in jaxns.experimental.evidence_maximisation)
EvidenceUpdateVariables (class in jaxns.internals.shrinkage_statistics)
exp() (LogSpace method)
expansion_factor (MultiEllipsoidalSampler attribute)
,
[1]
ExplicitDensityPrior (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
F
F (in module jaxns.internals.maps)
f (NewtonDiagnostic attribute)
f_prop (NewtonDiagnostic attribute)
f_quad (NewtonDiagnostic attribute)
fast_perfect_live_point_computation_jax() (in module jaxns.internals.tree_structure)
final_res_norm (CGDiagnostics attribute)
fix_left (ForcedIdentifiability attribute)
,
[1]
,
[2]
fix_right (ForcedIdentifiability attribute)
,
[1]
,
[2]
float_type (in module jaxns.internals.mixed_precision)
FloatArray (in module jaxns.internals.types)
ForcedIdentifiability (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
forward() (AbstractModel method)
(Model method)
,
[1]
,
[2]
from_signed_value() (LogSpace static method)
FV (in module jaxns.internals.maps)
G
g_norm (NewtonDiagnostic attribute)
gain_ratio (NewtonDiagnostic attribute)
get_grandparent_info() (in module jaxns.internals.logging)
get_index() (in module jaxns.internals.maps)
get_interp_indices_and_weights() (in module jaxns.internals.interp_utils)
get_parameter() (in module jaxns.framework.context)
get_sample() (AbstractSampler method)
get_sample_from_seed() (BaseAbstractMarkovSampler method)
(MultiDimSliceSampler method)
,
[1]
(UniDimSliceSampler method)
,
[1]
get_seed_point() (BaseAbstractMarkovSampler method)
(MultiDimSliceSampler method)
,
[1]
(UniDimSliceSampler method)
,
[1]
get_state() (in module jaxns.framework.context)
GlobalOptimisation (class in jaxns.experimental)
(class in jaxns.experimental.public)
GlobalOptimisationResults (class in jaxns.experimental)
(class in jaxns.experimental.global_optimisation)
GlobalOptimisationState (class in jaxns.experimental)
(class in jaxns.experimental.global_optimisation)
GlobalOptimisationTerminationCondition (class in jaxns.experimental)
(class in jaxns.experimental.global_optimisation)
grad_and_hvp() (in module jaxns.experimental.solvers.ad_utils)
gradient_guided (NestedSampler attribute)
,
[1]
(UniDimSliceSampler attribute)
,
[1]
gradient_slice (GlobalOptimisation attribute)
,
[1]
(UniDimSliceSampler attribute)
,
[1]
gtol (EvidenceMaximisation attribute)
,
[1]
H
H_mean (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
high (ForcedIdentifiability attribute)
,
[1]
,
[2]
(TruncationWrapper attribute)
,
[1]
,
[2]
hvp_forward_over_reverse() (in module jaxns.experimental.solvers.ad_utils)
hvp_linearized() (in module jaxns.experimental.solvers.ad_utils)
hvp_reverse_over_forward() (in module jaxns.experimental.solvers.ad_utils)
hvp_reverse_over_reverse() (in module jaxns.experimental.solvers.ad_utils)
I
in_trust_region (NewtonDiagnostic attribute)
index_dtype (Policy attribute)
init_efficiency_threshold (NestedSampler attribute)
,
[1]
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
init_params() (Model method)
,
[1]
,
[2]
initialize_params() (in module jaxns.samplers.multi_ellipsoid.em_gmm)
int_type (in module jaxns.internals.mixed_precision)
IntArray (in module jaxns.internals.types)
InterpolatedArray (class in jaxns.internals.interp_utils)
InvalidDistribution
InvalidPriorName
,
[1]
,
[2]
is_complex() (in module jaxns.internals.log_semiring)
isinstance_namedtuple() (in module jaxns.internals.namedtuple_utils)
issubclass_namedtuple() (in module jaxns.internals.namedtuple_utils)
items() (ScopedDict method)
iteration (NewtonDiagnostic attribute)
iterations (CGDiagnostics attribute)
J
jaxify_likelihood() (in module jaxns)
(in module jaxns.framework)
(in module jaxns.framework.jaxify)
jaxns
module
jaxns.experimental
module
jaxns.experimental.evidence_maximisation
module
jaxns.experimental.global_optimisation
module
jaxns.experimental.public
module
jaxns.experimental.solvers
module
jaxns.experimental.solvers.ad_utils
module
jaxns.experimental.solvers.cg
module
jaxns.experimental.solvers.gauss_newton_cg
module
jaxns.experimental.solvers.test_cg
module
jaxns.framework
module
jaxns.framework.abc
module
jaxns.framework.bases
module
jaxns.framework.context
module
jaxns.framework.jaxify
module
jaxns.framework.model
module
jaxns.framework.ops
module
jaxns.framework.prior
module
jaxns.framework.special_priors
module
jaxns.framework.wrapped_tfp_distribution
module
jaxns.internals
module
jaxns.internals.constraint_bijections
module
jaxns.internals.cumulative_ops
module
jaxns.internals.interp_utils
module
jaxns.internals.linalg
module
jaxns.internals.log_semiring
module
jaxns.internals.logging
module
jaxns.internals.maps
module
jaxns.internals.mixed_precision
module
jaxns.internals.namedtuple_utils
module
jaxns.internals.prefix_sum
module
jaxns.internals.pytree_utils
module
jaxns.internals.random
module
jaxns.internals.shapes
module
jaxns.internals.shrinkage_statistics
module
jaxns.internals.stats
module
jaxns.internals.tree_structure
module
jaxns.internals.types
module
jaxns.nested_samplers
module
jaxns.nested_samplers.abc
module
jaxns.nested_samplers.common
module
jaxns.nested_samplers.common.initialisation
module
jaxns.nested_samplers.common.termination
module
jaxns.nested_samplers.common.types
module
jaxns.nested_samplers.common.uniform_sample
module
jaxns.nested_samplers.sharded
module
jaxns.nested_samplers.sharded.sharded_static
module
jaxns.plotting
module
jaxns.public
module
jaxns.samplers
module
jaxns.samplers.abc
module
jaxns.samplers.bases
module
jaxns.samplers.multi_ellipsoid
module
jaxns.samplers.multi_ellipsoid.em_gmm
module
jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils
module
jaxns.samplers.multi_ellipsoidal_samplers
module
jaxns.samplers.multi_slice_sampler
module
jaxns.samplers.uni_slice_sampler
module
jaxns.samplers.uniform_samplers
module
jaxns.utils
module
jaxns.warnings
module
K
k (GlobalOptimisation attribute)
,
[1]
(NestedSampler attribute)
,
[1]
key (EphemeralState attribute)
(GlobalOptimisationState attribute)
,
[1]
(NestedSamplerState attribute)
,
[1]
,
[2]
,
[3]
,
[4]
keys() (ScopedDict method)
L
left_broadcast_multiply() (in module jaxns.internals.interp_utils)
LikelihoodInputType (in module jaxns.internals.types)
LikelihoodType (in module jaxns.internals.types)
linear_to_log_stats() (in module jaxns.internals.stats)
live_evidence_frac (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
live_points_collection (EphemeralState attribute)
load_pytree() (in module jaxns)
(in module jaxns.utils)
load_results() (in module jaxns)
(in module jaxns.utils)
log() (LogSpace method)
log_abs_val (LogSpace property)
log_dp_mean (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_efficiency (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_L (SampleTreeGraph attribute)
log_L0 (SeedPoint attribute)
log_L_contour (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_L_next (EvidenceUpdateVariables attribute)
log_L_progress (GlobalOptimisationResults attribute)
,
[1]
log_L_samples (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_L_solution (GlobalOptimisationResults attribute)
,
[1]
log_likelihood_contour (GlobalOptimisationTerminationCondition attribute)
,
[1]
log_posterior_density (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_prob_joint() (AbstractModel method)
log_prob_likelihood() (AbstractModel method)
log_prob_prior() (AbstractModel method)
(Model method)
,
[1]
,
[2]
log_X_mean (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_Z_atol (EvidenceMaximisation attribute)
,
[1]
log_Z_ftol (EvidenceMaximisation attribute)
,
[1]
log_Z_mean (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
log_Z_uncert (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
logaddexp() (in module jaxns.internals.log_semiring)
logger (in module jaxns.internals.logging)
LogSpace (class in jaxns.internals.log_semiring)
low (ForcedIdentifiability attribute)
,
[1]
,
[2]
(TruncationWrapper attribute)
,
[1]
,
[2]
M
m_step() (EvidenceMaximisation method)
,
[1]
(in module jaxns.samplers.multi_ellipsoid.em_gmm)
marginalise_dynamic() (in module jaxns)
(in module jaxns.utils)
marginalise_dynamic_from_U() (in module jaxns)
(in module jaxns.utils)
marginalise_static() (in module jaxns)
(in module jaxns.utils)
marginalise_static_from_U() (in module jaxns)
(in module jaxns.utils)
max() (LogSpace method)
max_likelihood_evals (UniformSampler attribute)
,
[1]
max_likelihood_evaluations (GlobalOptimisationTerminationCondition attribute)
,
[1]
max_num_ellipsoids (MultiEllipsoidalSampler property)
,
[1]
max_num_epochs (EvidenceMaximisation attribute)
,
[1]
max_num_likelihood_evaluations (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
max_samples (NestedSampler attribute)
,
[1]
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
(TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
maximum() (LogSpace method)
maximum_a_posteriori_point() (in module jaxns)
(in module jaxns.utils)
mean() (LogSpace method)
measure_dtype (Policy attribute)
MeasureType (in module jaxns.internals.types)
midpoint_shrink (UniDimSliceSampler attribute)
,
[1]
min() (LogSpace method)
min_efficiency (GlobalOptimisationTerminationCondition attribute)
,
[1]
minimum() (LogSpace method)
Model (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.model)
model (EvidenceMaximisation attribute)
,
[1]
(GlobalOptimisation attribute)
,
[1]
(MultiDimSliceSampler attribute)
,
[1]
(MultiEllipsoidalSampler attribute)
,
[1]
(NestedSampler attribute)
,
[1]
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
(SimpleGlobalOptimisation attribute)
,
[1]
(UniDimSliceSampler attribute)
,
[1]
(UniformSampler attribute)
,
[1]
module
jaxns
jaxns.experimental
jaxns.experimental.evidence_maximisation
jaxns.experimental.global_optimisation
jaxns.experimental.public
jaxns.experimental.solvers
jaxns.experimental.solvers.ad_utils
jaxns.experimental.solvers.cg
jaxns.experimental.solvers.gauss_newton_cg
jaxns.experimental.solvers.test_cg
jaxns.framework
jaxns.framework.abc
jaxns.framework.bases
jaxns.framework.context
jaxns.framework.jaxify
jaxns.framework.model
jaxns.framework.ops
jaxns.framework.prior
jaxns.framework.special_priors
jaxns.framework.wrapped_tfp_distribution
jaxns.internals
jaxns.internals.constraint_bijections
jaxns.internals.cumulative_ops
jaxns.internals.interp_utils
jaxns.internals.linalg
jaxns.internals.log_semiring
jaxns.internals.logging
jaxns.internals.maps
jaxns.internals.mixed_precision
jaxns.internals.namedtuple_utils
jaxns.internals.prefix_sum
jaxns.internals.pytree_utils
jaxns.internals.random
jaxns.internals.shapes
jaxns.internals.shrinkage_statistics
jaxns.internals.stats
jaxns.internals.tree_structure
jaxns.internals.types
jaxns.nested_samplers
jaxns.nested_samplers.abc
jaxns.nested_samplers.common
jaxns.nested_samplers.common.initialisation
jaxns.nested_samplers.common.termination
jaxns.nested_samplers.common.types
jaxns.nested_samplers.common.uniform_sample
jaxns.nested_samplers.sharded
jaxns.nested_samplers.sharded.sharded_static
jaxns.plotting
jaxns.public
jaxns.samplers
jaxns.samplers.abc
jaxns.samplers.bases
jaxns.samplers.multi_ellipsoid
jaxns.samplers.multi_ellipsoid.em_gmm
jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils
jaxns.samplers.multi_ellipsoidal_samplers
jaxns.samplers.multi_slice_sampler
jaxns.samplers.uni_slice_sampler
jaxns.samplers.uniform_samplers
jaxns.utils
jaxns.warnings
mp_policy (in module jaxns.internals.mixed_precision)
msqrt() (in module jaxns.internals.linalg)
mu (NewtonDiagnostic attribute)
MultEllipsoidState (class in jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils)
MultiDimSliceSampler (class in jaxns.samplers)
(class in jaxns.samplers.multi_slice_sampler)
MultiEllipsoidalSampler (class in jaxns.samplers)
(class in jaxns.samplers.multi_ellipsoidal_samplers)
N
n (ForcedIdentifiability attribute)
,
[1]
,
[2]
name (Prior attribute)
,
[1]
,
[2]
nansum() (LogSpace method)
nested_sampler (NestedSampler property)
,
[1]
NestedSampler (class in jaxns)
(class in jaxns.public)
NestedSamplerResults (class in jaxns)
(class in jaxns.nested_samplers)
(class in jaxns.nested_samplers.common)
(class in jaxns.nested_samplers.common.types)
(class in jaxns.samplers)
NestedSamplerState (class in jaxns)
(class in jaxns.nested_samplers)
(class in jaxns.nested_samplers.common)
(class in jaxns.nested_samplers.common.types)
(class in jaxns.samplers)
newton_cg_solver() (in module jaxns.experimental.solvers.gauss_newton_cg)
NewtonDiagnostic (class in jaxns.experimental.solvers.gauss_newton_cg)
next_rng_key() (in module jaxns.framework.context)
next_sample_idx (NestedSamplerState attribute)
,
[1]
,
[2]
,
[3]
,
[4]
normal_to_lognormal() (in module jaxns.internals.stats)
normalise_log_space() (in module jaxns.internals.log_semiring)
ns_kwargs (EvidenceMaximisation attribute)
,
[1]
num_dynamic_refinement_iterations (ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
num_likelihood_evaluations (GlobalOptimisationResults attribute)
,
[1]
(GlobalOptimisationState attribute)
,
[1]
num_likelihood_evaluations_per_sample (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
num_live_points (EvidenceUpdateVariables attribute)
(NestedSampler attribute)
,
[1]
(SampleLivePointCounts attribute)
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
num_live_points_per_sample (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
num_params (Model property)
,
[1]
,
[2]
num_phantom() (AbstractSampler method)
(MultiDimSliceSampler method)
,
[1]
(MultiEllipsoidalSampler method)
,
[1]
(UniDimSliceSampler method)
,
[1]
(UniformSampler method)
,
[1]
num_phantom_save (MultiDimSliceSampler attribute)
,
[1]
(UniDimSliceSampler attribute)
,
[1]
num_restrict_dims (MultiDimSliceSampler attribute)
,
[1]
num_samples (GlobalOptimisationResults attribute)
,
[1]
(GlobalOptimisationState attribute)
,
[1]
(NestedSamplerState attribute)
,
[1]
,
[2]
,
[3]
,
[4]
num_search_chains (GlobalOptimisation attribute)
,
[1]
(SimpleGlobalOptimisation attribute)
,
[1]
num_slices (MultiDimSliceSampler attribute)
,
[1]
(NestedSampler attribute)
,
[1]
(UniDimSliceSampler attribute)
,
[1]
O
ObjectiveRet (in module jaxns.experimental.solvers.gauss_newton_cg)
P
parameter_estimation (NestedSampler attribute)
,
[1]
parametrised() (Prior method)
,
[1]
,
[2]
parametrised_samples (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
params (Model property)
,
[1]
,
[2]
(MultEllipsoidState attribute)
peak_XL_frac (TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
perfect (UniDimSliceSampler attribute)
,
[1]
plot_cornerplot() (in module jaxns)
(in module jaxns.plotting)
plot_diagnostics() (in module jaxns)
(in module jaxns.plotting)
plot_tree() (in module jaxns.internals.tree_structure)
Poisson (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
Policy (class in jaxns.internals.mixed_precision)
pop_scope() (ScopedDict method)
post_process() (AbstractSampler method)
pre_process() (AbstractSampler method)
prepad() (in module jaxns.internals.maps)
prepare_func_args() (in module jaxns.internals.maps)
prepare_input() (AbstractModel method)
(Model method)
,
[1]
,
[2]
Prior (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.prior)
prior (TruncationWrapper attribute)
,
[1]
,
[2]
PriorModelGen (in module jaxns)
(in module jaxns.framework)
(in module jaxns.framework.bases)
PriorModelType (in module jaxns)
(in module jaxns.framework)
(in module jaxns.framework.bases)
PRNGKey (in module jaxns.internals.types)
PT (in module jaxns.internals.maps)
push_scope() (ScopedDict method)
PV (in module jaxns.internals.maps)
PyTree (class in jaxns.internals.maps)
pytree_unpack() (in module jaxns.internals.maps)
pytree_unravel() (in module jaxns.internals.maps)
Q
quick_unit() (in module jaxns.internals.constraint_bijections)
quick_unit_inverse() (in module jaxns.internals.constraint_bijections)
R
random_ortho_matrix() (in module jaxns.internals.random)
RandomVariableType (in module jaxns.internals.types)
refine_threshold (ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
regular_grid (InterpolatedArray attribute)
relative_spread (GlobalOptimisationResults attribute)
,
[1]
(GlobalOptimisationState attribute)
,
[1]
remove_chunk_dim() (in module jaxns.internals.maps)
replace_index() (in module jaxns.internals.maps)
resample() (in module jaxns)
(in module jaxns.utils)
resample_indicies() (in module jaxns.internals.random)
rtol (GlobalOptimisationTerminationCondition attribute)
,
[1]
(TerminationCondition attribute)
,
[1]
,
[2]
,
[3]
,
[4]
S
s (GlobalOptimisation attribute)
,
[1]
(NestedSampler attribute)
,
[1]
sample_collection (NestedSamplerState attribute)
,
[1]
,
[2]
,
[3]
,
[4]
sample_evidence() (in module jaxns)
(in module jaxns.utils)
sample_multi_ellipsoid() (in module jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils)
sample_U() (AbstractModel method)
(Model method)
,
[1]
,
[2]
sample_W() (Model method)
,
[1]
,
[2]
SampleLivePointCounts (class in jaxns.internals.tree_structure)
sampler (ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
(SimpleGlobalOptimisation attribute)
,
[1]
SamplerState (in module jaxns.samplers.abc)
samples (GlobalOptimisationState attribute)
,
[1]
(NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
samples_indices (SampleLivePointCounts attribute)
SampleTreeGraph (class in jaxns.internals.tree_structure)
sanity_check() (AbstractModel method)
(Model method)
,
[1]
,
[2]
save_pytree() (in module jaxns)
(in module jaxns.utils)
save_results() (in module jaxns)
(in module jaxns.utils)
scan_associative() (in module jaxns.internals.prefix_sum)
scan_associative_cumulative_op() (in module jaxns.internals.cumulative_ops)
scope() (in module jaxns.framework.context)
scope_prefix (ScopedDict property)
ScopedDict (class in jaxns.framework.context)
scopes (ScopedDict attribute)
SeedPoint (class in jaxns.samplers.bases)
sender_node_idx (SampleTreeGraph attribute)
serialise_jax_ndarray() (in module jaxns.internals.namedtuple_utils)
serialise_namedtuple() (in module jaxns.internals.namedtuple_utils)
serialise_ndarray() (in module jaxns.internals.namedtuple_utils)
set_params() (Model method)
,
[1]
,
[2]
set_state() (in module jaxns.framework.context)
shape (InterpolatedArray property)
ShardedStaticNestedSampler (class in jaxns)
(class in jaxns.nested_samplers)
(class in jaxns.nested_samplers.sharded)
(class in jaxns.nested_samplers.sharded.sharded_static)
(class in jaxns.samplers)
shell_frac (GlobalOptimisation attribute)
,
[1]
(SimpleGlobalOptimisation attribute)
,
[1]
shell_fraction (NestedSampler attribute)
,
[1]
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
sign (LogSpace property)
signed_log (LogSpace property)
signed_logaddexp() (in module jaxns.internals.log_semiring)
SimpleGlobalOptimisation (class in jaxns.experimental)
(class in jaxns.experimental.global_optimisation)
simulate_prior_model() (in module jaxns.framework.ops)
size (LogSpace property)
solution (GlobalOptimisationResults attribute)
,
[1]
SPT (in module jaxns.internals.maps)
sqrt() (LogSpace method)
square() (LogSpace method)
squared_norm() (in module jaxns.internals.linalg)
sum() (LogSpace method)
summary() (in module jaxns)
(in module jaxns.utils)
T
T (in module jaxns.internals.maps)
(in module jaxns.internals.mixed_precision)
(in module jaxns.samplers.bases)
termination_cond (EvidenceMaximisation attribute)
,
[1]
termination_reason (GlobalOptimisationResults attribute)
,
[1]
(NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
termination_register (EphemeralState attribute)
TerminationCondition (class in jaxns)
(class in jaxns.nested_samplers)
(class in jaxns.nested_samplers.common)
(class in jaxns.nested_samplers.common.types)
(class in jaxns.samplers)
test_cg() (in module jaxns.experimental.solvers.test_cg)
to_dict() (ScopedDict method)
to_results() (NestedSampler method)
,
[1]
total_num_likelihood_evaluations (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
total_num_samples (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
total_phantom_samples (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
train() (EvidenceMaximisation method)
,
[1]
transform() (AbstractModel method)
(in module jaxns.framework.context)
(Model method)
,
[1]
,
[2]
transform_parametrised() (AbstractModel method)
(Model method)
,
[1]
,
[2]
transform_with_state() (in module jaxns.framework.context)
tree (PyTree attribute)
tree_add() (in module jaxns.experimental.solvers.ad_utils)
tree_device_put() (in module jaxns.internals.maps)
tree_div() (in module jaxns.experimental.solvers.ad_utils)
(in module jaxns.internals.pytree_utils)
tree_dot() (in module jaxns.experimental.solvers.ad_utils)
(in module jaxns.internals.pytree_utils)
tree_mul() (in module jaxns.experimental.solvers.ad_utils)
(in module jaxns.internals.pytree_utils)
tree_neg() (in module jaxns.experimental.solvers.ad_utils)
tree_norm() (in module jaxns.experimental.solvers.ad_utils)
(in module jaxns.internals.pytree_utils)
tree_scalar_mul() (in module jaxns.experimental.solvers.ad_utils)
tree_sub() (in module jaxns.experimental.solvers.ad_utils)
(in module jaxns.internals.pytree_utils)
tree_vdot() (in module jaxns.experimental.solvers.ad_utils)
tree_vdot_real_part() (in module jaxns.experimental.solvers.ad_utils)
trim_results() (NestedSampler static method)
,
[1]
TruncationWrapper (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
tuple_prod() (in module jaxns.internals.shapes)
U
U0 (SeedPoint attribute)
U_samples (NestedSamplerResults attribute)
,
[1]
,
[2]
,
[3]
,
[4]
U_solution (GlobalOptimisationResults attribute)
,
[1]
UniDimSliceSampler (class in jaxns.samplers)
(class in jaxns.samplers.uni_slice_sampler)
UniformSampler (class in jaxns.samplers)
(class in jaxns.samplers.uniform_samplers)
UnnormalisedDirichlet (class in jaxns)
(class in jaxns.framework)
(class in jaxns.framework.special_priors)
update_evicence_calculation() (in module jaxns.internals.shrinkage_statistics)
UType (in module jaxns.internals.types)
V
V (in module jaxns.internals.cumulative_ops)
value (LogSpace property)
(Prior property)
,
[1]
,
[2]
values (InterpolatedArray attribute)
values() (ScopedDict method)
var() (LogSpace method)
verbose (EvidenceMaximisation attribute)
,
[1]
(GlobalOptimisation attribute)
,
[1]
(NestedSampler attribute)
,
[1]
(ShardedStaticNestedSampler attribute)
,
[1]
,
[2]
,
[3]
,
[4]
(SimpleGlobalOptimisation attribute)
,
[1]
W
wrap_random() (in module jaxns.framework.context)
WrappedTFPDistribution (class in jaxns.framework.wrapped_tfp_distribution)
X
X (in module jaxns.internals.cumulative_ops)
(in module jaxns.internals.mixed_precision)
(in module jaxns.internals.prefix_sum)
x (InterpolatedArray attribute)
X_solution (GlobalOptimisationResults attribute)
,
[1]
XType (in module jaxns.internals.types)
Y
Y (in module jaxns.internals.cumulative_ops)