Source code for jaxns.internals.namedtuple_utils

import base64
import importlib

import jax
import jax.numpy as jnp
import numpy as np


def isinstance_namedtuple(obj) -> bool:
    """
    Determine whether an object is a namedtuple instance.

    The test is structural: the object must be a tuple that also exposes
    both ``_asdict`` and ``_fields`` attributes.

    Args:
        obj: object to inspect

    Returns:
        True if ``obj`` looks like a namedtuple instance, otherwise False.
    """
    if not isinstance(obj, tuple):
        return False
    return hasattr(obj, '_asdict') and hasattr(obj, '_fields')
def issubclass_namedtuple(cls):
    """
    Determine whether a type object derives from a namedtuple.

    Walks the method resolution order of ``cls``. The answer is True as soon
    as one class in it defines both ``_fields`` and ``_asdict``.

    Args:
        cls: a class object (it must provide ``__mro__``)

    Returns:
        True if some class in ``cls.__mro__`` has both attributes, else False.
    """
    for klass in cls.__mro__:
        if hasattr(klass, '_fields') and hasattr(klass, '_asdict'):
            return True
    return False
def serialise_namedtuple(obj):
    """
    Recursively convert a structure into a JSON-friendly representation.

    Conversion rules:
    - Namedtuples become tagged dicts (``'type': '__namedtuple__'``). Each one
      records the class as ``"<module>.<name>"`` and the serialised field values.
    - numpy arrays go to ``serialise_ndarray``.
    - JAX arrays go to ``serialise_jax_ndarray``.
    - Lists and plain tuples both become lists.
    - Dicts are rebuilt with their values serialised.
    - Anything else is returned unchanged.

    Args:
        obj: object to serialise

    Returns:
        the serialised form of ``obj``
    """
    # The namedtuple check must precede the generic tuple check below,
    # because a namedtuple is also a tuple.
    if isinstance_namedtuple(obj):
        klass = obj.__class__
        class_path = f"{klass.__module__}.{klass.__name__}"
        fields = {name: serialise_namedtuple(value) for name, value in obj._asdict().items()}
        return {'type': '__namedtuple__', '__class__': class_path, '__data__': fields}
    if isinstance(obj, np.ndarray):
        return serialise_ndarray(obj)
    if isinstance(obj, jax.Array):
        return serialise_jax_ndarray(obj)
    if isinstance(obj, (list, tuple)):
        # Plain tuples are emitted as lists (JSON has no tuple type).
        return [serialise_namedtuple(item) for item in obj]
    if isinstance(obj, dict):
        return {key: serialise_namedtuple(value) for key, value in obj.items()}
    return obj
def deserialise_namedtuple(obj):
    """
    Recursively rebuild a structure produced by ``serialise_namedtuple``.

    Dicts are dispatched on their ``'type'`` entry:
    - ``'__namedtuple__'``: import the module named in ``'__class__'``, look
      up the class there, and call it with the deserialised ``'__data__'``
      entries as keyword arguments.
    - ``'__ndarray__'`` / ``'__jax_ndarray__'``: hand off to
      ``deserialise_ndarray`` / ``deserialise_jax_ndarray``.
    - Any other dict: rebuild it with deserialised values.

    Lists stay lists and tuples stay tuples, each with deserialised items.
    Anything else is returned unchanged.

    Security note: the ``'__namedtuple__'`` branch imports whatever module
    the payload names and calls an attribute of it. Do not use this
    function on untrusted input.

    Args:
        obj: serialised object

    Returns:
        the reconstructed object
    """
    if isinstance(obj, dict):
        tag = obj['type'] if 'type' in obj else None
        if tag == '__namedtuple__':
            module_name, class_name = obj['__class__'].rsplit('.', 1)
            target_cls = getattr(importlib.import_module(module_name), class_name)
            fields = {name: deserialise_namedtuple(value) for name, value in obj['__data__'].items()}
            return target_cls(**fields)
        if tag == '__ndarray__':
            return deserialise_ndarray(obj)
        if tag == '__jax_ndarray__':
            return deserialise_jax_ndarray(obj)
        return {key: deserialise_namedtuple(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [deserialise_namedtuple(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(deserialise_namedtuple(item) for item in obj)
    return obj
def serialise_ndarray(obj):
    """
    Encode a numpy array as a JSON-friendly tagged dict.

    The raw bytes (C order, as produced by ``tobytes``) are base64-encoded.
    The dtype string and shape are stored alongside so the array can be
    rebuilt later.

    Args:
        obj: a ``np.ndarray``; any other value is returned unchanged.

    Returns:
        ``{'type': '__ndarray__', '__dtype__', '__data__', '__shape__'}`` for
        arrays, otherwise ``obj`` itself.
    """
    if not isinstance(obj, np.ndarray):
        return obj
    encoded = base64.b64encode(obj.tobytes()).decode('utf-8')
    return {
        'type': '__ndarray__',
        '__dtype__': str(obj.dtype),
        '__data__': encoded,
        '__shape__': obj.shape,
    }
def deserialise_ndarray(obj):
    """
    Rebuild a numpy array from the tagged dict produced by ``serialise_ndarray``.

    Args:
        obj: a dict with ``'type' == '__ndarray__'`` plus the entries
            ``'__data__'`` (base64 of the raw bytes), ``'__dtype__'`` and
            ``'__shape__'``. Any other value is returned unchanged.

    Returns:
        a writable numpy array with the recorded dtype and shape, or ``obj``
        if it is not a tagged ndarray dict.
    """
    if isinstance(obj, dict) and obj.get('type') == '__ndarray__':
        data_bytes = base64.b64decode(obj['__data__'])
        # np.frombuffer returns a view onto the immutable ``bytes`` object,
        # so that view is read-only (and possibly unaligned). Copy it so
        # callers get an ordinary, writable, C-contiguous array.
        return np.frombuffer(data_bytes, dtype=obj['__dtype__']).reshape(obj['__shape__']).copy()
    return obj
def serialise_jax_ndarray(obj):
    """
    Encode a JAX array as a JSON-friendly tagged dict.

    The array is first converted to a numpy array with ``np.asarray``. Its
    bytes are then base64-encoded and stored with the dtype string and shape.

    Args:
        obj: a ``jax.Array``; any other value is returned unchanged.

    Returns:
        ``{'type': '__jax_ndarray__', '__dtype__', '__data__', '__shape__'}``
        for JAX arrays, otherwise ``obj`` itself.
    """
    if not isinstance(obj, jax.Array):
        return obj
    host_copy = np.asarray(obj)
    encoded = base64.b64encode(host_copy.tobytes()).decode('utf-8')
    return {
        'type': '__jax_ndarray__',
        '__dtype__': str(obj.dtype),
        '__data__': encoded,
        '__shape__': obj.shape,
    }
def deserialise_jax_ndarray(obj):
    """
    Rebuild a JAX array from the tagged dict produced by ``serialise_jax_ndarray``.

    Args:
        obj: a dict with ``'type' == '__jax_ndarray__'`` plus the entries
            ``'__data__'`` (base64 of the raw bytes), ``'__dtype__'`` and
            ``'__shape__'``. Any other value is returned unchanged.

    Returns:
        a JAX array built with ``jnp.frombuffer`` and reshaped to the
        recorded shape, or ``obj`` if it is not a tagged JAX array dict.
    """
    if not (isinstance(obj, dict) and obj.get('type') == '__jax_ndarray__'):
        return obj
    raw = base64.b64decode(obj['__data__'])
    flat = jnp.frombuffer(raw, dtype=obj['__dtype__'])
    return flat.reshape(obj['__shape__'])