🥽 Masking API
- sepes.is_masked(value)
Returns True if value is a masked (frozen) wrapper, such as a leaf wrapped by tree_mask.
- Parameters:
value (Any) – The value to check.
- Return type:
bool
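The following plain-Python sketch makes the check concrete and runs without JAX or sepes. It assumes the mask wrapper is a single class, here a hypothetical frozen dataclass named `Masked`, so that `is_masked` reduces to an `isinstance` test. sepes's actual wrapper classes are internal and may differ.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Masked:
    """Hypothetical stand-in for sepes's internal mask wrapper."""
    value: Any


def is_masked(value: Any) -> bool:
    # Only the wrapper itself counts as masked, not the value it holds.
    return isinstance(value, Masked)


print(is_masked(Masked(1)))  # True
print(is_masked(1))          # False
```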
- sepes.tree_mask(tree, cond=<function is_nondiff>, *, is_leaf=None)
Mask leaves of a pytree according to the cond callable. Masked leaves are placed in a wrapper that yields no leaves when tree_flatten is called on it.
- Parameters:
tree (T) – A pytree of values.
cond (Callable[[Any], bool]) – A callable that accepts a leaf and returns a boolean marking the leaf for masking. Defaults to is_nondiff, which masks non-differentiable leaves, that is, leaves that are not instances of Python float, Python complex, or inexact array types.
is_leaf (Callable[[Any], bool] | None) – A callable that accepts a node and returns a boolean. If provided, it determines whether a value is treated as a leaf. For example, is_leaf=lambda x: isinstance(x, list) treats lists as leaves and does not recurse into them.
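The following plain-Python sketch illustrates the masking mechanism described above. It does not reproduce sepes's implementation and makes several assumptions. `Masked` is a hypothetical wrapper. The default rule treats only Python `float` and `complex` as differentiable, and array handling is omitted. Only lists, tuples, and dicts are treated as containers. A hand-written `tree_leaves` stands in for `jax.tree_util.tree_leaves` and returns nothing for masked nodes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Masked:
    """Hypothetical wrapper; tree_leaves below treats it as having no leaves."""
    value: Any

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def is_nondiff(x: Any) -> bool:
    # Simplified default: only Python float/complex count as differentiable.
    # (The real default also leaves inexact arrays unmasked; omitted here.)
    return not isinstance(x, (float, complex))


def tree_mask(tree: Any,
              cond: Callable[[Any], bool] = is_nondiff,
              is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    # Recurse into containers unless is_leaf says to stop at this node.
    if is_leaf is None or not is_leaf(tree):
        if isinstance(tree, list):
            return [tree_mask(x, cond, is_leaf) for x in tree]
        if isinstance(tree, tuple):
            return tuple(tree_mask(x, cond, is_leaf) for x in tree)
        if isinstance(tree, dict):
            return {k: tree_mask(v, cond, is_leaf) for k, v in tree.items()}
    # At a leaf: wrap it if cond marks it for masking.
    return Masked(tree) if cond(tree) else tree


def tree_leaves(tree: Any) -> list:
    # Stand-in for jax.tree_util.tree_leaves: masked nodes contribute no leaves.
    if isinstance(tree, Masked):
        return []
    if isinstance(tree, (list, tuple)):
        return [leaf for x in tree for leaf in tree_leaves(x)]
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in tree_leaves(v)]
    return [tree]


masked = tree_mask([1, 2, {"a": 3, "b": 4.0}])
print(masked)               # [#1, #2, {'a': #3, 'b': 4.0}]
print(tree_leaves(masked))  # [4.0]

# With is_leaf, the inner list is treated as one leaf and masked as a whole.
print(tree_mask(([1, 2], 3.0), is_leaf=lambda x: isinstance(x, list)))
# (#[1, 2], 3.0)
```

Because the wrapper contributes no leaves, anything that only sees the flattened leaves (such as a gradient transform) never encounters the masked values.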
Example
>>> import sepes as sp
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = sp.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> sp.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
Pass non-differentiable values to jax.grad:
>>> import sepes as sp
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = sp.tree_unmask(tree)
...     return tree[0] ** 2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
Example
Define a custom masking wrapper for a specific type.
>>> import sepes as sp
>>> import jax
>>> import dataclasses as dc
>>> @dc.dataclass
... class MyInt:
...     value: int
>>> @dc.dataclass
... class MaskedInt:
...     value: MyInt
>>> # define a rule for how to mask a MyInt
>>> @sp.tree_mask.def_type(MyInt)
... def mask_int(value):
...     return MaskedInt(value)
>>> # define a rule for how to unmask the wrapper
>>> @sp.tree_unmask.def_type(MaskedInt)
... def unmask_int(value):
...     return value.value
>>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
>>> masked_tree = sp.tree_mask(tree, cond=lambda _: True)
>>> masked_tree
[MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}]
>>> sp.tree_unmask(masked_tree)
[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
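Per-type rules registered with `def_type` behave much like single dispatch on the leaf's type. The sketch below assumes that design and builds it with the standard library's `functools.singledispatch`. The names `mask_leaf`, `unmask_leaf`, and `Masked` are hypothetical and are not sepes internals. `MyInt` and `MaskedInt` mirror the example above.

```python
import dataclasses as dc
import functools
from typing import Any


@dc.dataclass(frozen=True)
class Masked:
    """Hypothetical generic wrapper used when no type-specific rule exists."""
    value: Any


@functools.singledispatch
def mask_leaf(value: Any) -> Any:
    # Fallback rule: wrap any unregistered type in the generic wrapper.
    return Masked(value)


@functools.singledispatch
def unmask_leaf(value: Any) -> Any:
    # Fallback rule: unwrap the generic wrapper; pass anything else through.
    return value.value if isinstance(value, Masked) else value


@dc.dataclass
class MyInt:
    value: int


@dc.dataclass
class MaskedInt:
    value: MyInt


# Analogue of @sp.tree_mask.def_type(MyInt): register a type-specific rule.
@mask_leaf.register(MyInt)
def _(value: MyInt) -> MaskedInt:
    return MaskedInt(value)


# Analogue of @sp.tree_unmask.def_type(MaskedInt).
@unmask_leaf.register(MaskedInt)
def _(value: MaskedInt) -> MyInt:
    return value.value


print(mask_leaf(MyInt(1)))               # MaskedInt(value=MyInt(value=1))
print(mask_leaf(5))                      # Masked(value=5)
print(unmask_leaf(MaskedInt(MyInt(1))))  # MyInt(value=1)
```

Dispatching on type keeps the masking rules open for extension: a new type needs one registered function and no changes to the traversal code.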
- sepes.tree_unmask(tree, cond=<function <lambda>>)
Undo the masking of tree leaves according to cond. Defaults to unmasking all leaves.
- Parameters:
tree (T) – A pytree of values.
cond (Callable[[Any], bool]) – A callable that accepts a leaf and returns a boolean marking the leaf to be unmasked. Defaults to lambda _: True, which unmasks every leaf.
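The following plain-Python sketch shows selective unmasking. It approximates the described behavior and is not sepes's implementation. It assumes a hypothetical `Masked` wrapper and that `cond` is called on each wrapper, which is treated as a leaf. Ordinary leaves pass through unchanged.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Masked:
    """Hypothetical stand-in for sepes's mask wrapper."""
    value: Any


def tree_unmask(tree: Any, cond: Callable[[Any], bool] = lambda _: True) -> Any:
    # A wrapper is treated as a leaf; unwrap it only if cond approves.
    if isinstance(tree, Masked):
        return tree.value if cond(tree) else tree
    if isinstance(tree, list):
        return [tree_unmask(x, cond) for x in tree]
    if isinstance(tree, tuple):
        return tuple(tree_unmask(x, cond) for x in tree)
    if isinstance(tree, dict):
        return {k: tree_unmask(v, cond) for k, v in tree.items()}
    return tree  # ordinary leaves are left unchanged


masked = [Masked(1), Masked("tag"), 2.0]
print(tree_unmask(masked))  # [1, 'tag', 2.0]

# Unmask only wrappers around ints; the string stays masked.
print(tree_unmask(masked, cond=lambda m: isinstance(m.value, int)))
# [1, Masked(value='tag'), 2.0]
```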
Example
See the sepes.tree_mask examples above; each one also demonstrates sepes.tree_unmask.