🏟️ Fields#

[1]:
!pip install sepes

This section introduces common recipes for fields. A sepes.field is a class variable that adds certain functionality to the class attribute it defines. The examples here use jax and numpy, but the same approach works with any other framework.

A field is declared like this:

from typing import Any
import sepes

class MyClass:
    my_field: Any = sepes.field()

For example, a field can be used to validate the input data, or to provide a default value. The notebook provides examples for common use cases.

sepes.field is implemented as a Python descriptor, which means it can be used in any class, not necessarily a sepes class. Refer to the Python documentation for more information on descriptors and how they work.
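
To make the descriptor idea concrete, here is a minimal plain-Python sketch of how a field with on_setattr and on_getattr callbacks could be built. SketchField, positive, and Layer are hypothetical names for illustration only; this is not the sepes implementation, just the general descriptor pattern it relies on.

```python
from typing import Any, Callable, Sequence


class SketchField:
    """Toy descriptor (hypothetical, NOT the sepes implementation)."""

    def __init__(
        self,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
    ):
        self.on_setattr = on_setattr
        self.on_getattr = on_getattr

    def __set_name__(self, owner, name):
        # Python calls this once, when the owning class body is executed.
        self.name = name

    def __set__(self, instance, value):
        # Run every setter callback, then store the result on the instance.
        for func in self.on_setattr:
            value = func(value)
        instance.__dict__[self.name] = value

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class itself
        value = instance.__dict__[self.name]
        # Run every getter callback on the stored value before returning it.
        for func in self.on_getattr:
            value = func(value)
        return value


def positive(x):
    """Validator: return x unchanged, or raise if it is not positive."""
    if x <= 0:
        raise ValueError(f"{x} is not positive")
    return x


class Layer:  # a plain class, no special base class required
    width = SketchField(on_setattr=[positive], on_getattr=[float])

    def __init__(self, width):
        self.width = width


layer = Layer(3)
print(layer.width)  # 3.0
```

The stored value stays the integer 3. Only the value returned on access is converted, which is the same split between set-time and get-time behavior that the recipes below use.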

[1] Buffers#

In this example, an array is marked as non-trainable using jax.lax.stop_gradient and sepes.field.

The standard way to mark an array as a buffer (i.e. non-trainable) is to write something like this:

class Tree(sp.TreeClass):
    def __init__(self, buffer: jax.Array):
        self.buffer = buffer

    def __call__(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)

However, if the buffer is accessed from other methods, jax.lax.stop_gradient has to be repeated inside each of them:

class Tree(sp.TreeClass):
    def method_1(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)

    # ... more methods, each repeating the same call ...

    def method_n(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)

Similarly, if the buffer of a Tree instance is accessed from another context, jax.lax.stop_gradient is needed again:

tree = Tree(buffer=...)
def func(tree: Tree):
    buffer = jax.lax.stop_gradient(tree.buffer)
    ...

This becomes cumbersome when the process is repeated many times. Alternatively, sepes.field can apply jax.lax.stop_gradient to the buffer whenever it is accessed. The next example demonstrates this.

[2]:
import sepes as sp
import jax
import jax.numpy as jnp


def buffer_field(**kwargs):
    return sp.field(on_getattr=[jax.lax.stop_gradient], **kwargs)


@sp.autoinit  # autoinit constructs `__init__` from the fields
class Tree(sp.TreeClass):
    buffer: jax.Array = buffer_field()

    def __call__(self, x):
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree: Tree, x: jax.Array):
    return jnp.sum(tree(x))


print(f(tree, 1.0))
print(jax.grad(f)(tree, 1.0))
6.0
Tree(buffer=[0. 0. 0.])
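
The zero gradient printed above comes from jax.lax.stop_gradient itself. The following sketch checks that effect in isolation, without sepes. The function names loss and loss_frozen are made up for this illustration.

```python
import jax
import jax.numpy as jnp


def loss(buffer, x):
    # Gradient flows through `buffer` as usual.
    return jnp.sum(buffer**x)


def loss_frozen(buffer, x):
    # Same computation, but `buffer` is treated as a constant by autodiff.
    return jnp.sum(jax.lax.stop_gradient(buffer) ** x)


buffer = jnp.array([1.0, 2.0, 3.0])

# d/db sum(b**2) = 2*b
grad_plain = jax.grad(loss)(buffer, 2.0)
# stop_gradient blocks the gradient, so every entry is zero
grad_frozen = jax.grad(loss_frozen)(buffer, 2.0)

print(grad_plain)
print(grad_frozen)
```

Wrapping the call in a field with on_getattr gives the second behavior everywhere the attribute is read, without repeating the call.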

[2] Masked field#

sepes provides a simple wrapper for masking data. Masking here means that the data yields no leaves when flattened. This is useful in frameworks like jax for hiding certain values from a transformation.

Flattening a masked value

[3]:
import sepes as sp
import jax

tree = [1, sp.tree_mask(2, cond=lambda _: True)]
print(tree)
print(jax.tree_util.tree_leaves(tree))  # note that 2 is removed from the leaves
[1, #2]
[1]
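
As a rough mental model of "yields no leaves when flattened", here is a toy flattener in plain Python. Masked and toy_leaves are hypothetical names invented for this sketch. They are not part of sepes or jax, and real pytree flattening handles far more cases.

```python
from typing import Any


class Masked:
    """Hypothetical wrapper whose payload is hidden from `toy_leaves`."""

    def __init__(self, value: Any):
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def toy_leaves(tree: Any) -> list:
    """Collect the leaves of nested lists/tuples/dicts, skipping masked values."""
    if isinstance(tree, Masked):
        return []  # a masked node contributes no leaves
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in toy_leaves(item)]
    if isinstance(tree, dict):
        return [leaf for item in tree.values() for leaf in toy_leaves(item)]
    return [tree]  # anything else is a leaf


toy_tree = [1, Masked(2), {"a": 3.0, "b": Masked("text")}]
print(toy_tree)              # [1, #2, {'a': 3.0, 'b': #'text'}]
print(toy_leaves(toy_tree))  # [1, 3.0]
```

The masked values are still stored in the container. They are only invisible to code that works on the flattened leaves.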

Using masking with jax transformations

The next example demonstrates how to use masking to work with data types that are not supported by jax.

[4]:
import sepes as sp
import jax


def mask_field(**kwargs):
    return sp.field(
        # unmask when the value is accessed
        on_getattr=[lambda x: sp.tree_unmask(x, cond=lambda node: True)],
        # mask when the value is set
        on_setattr=[lambda x: sp.tree_mask(x, cond=lambda node: True)],
        **kwargs,
    )

Now we can use this custom field to mark some class attributes as masked. Masking a value will effectively hide it from jax transformations.

Without masking the str type

[5]:
@sp.autoinit
class Tree(sp.TreeClass):
    training_mode: str  # <- will throw error with jax transformations.
    alpha: float

    def __call__(self, x):
        if self.training_mode == "training":
            return x**self.alpha
        return x


@jax.grad
def loss_func(tree, input):
    return tree(input)


tree = Tree(training_mode="training", alpha=2.0)
print(loss_func(tree, 2.0))  # <- will throw error with jax transformations.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 18
     14     return tree(input)
     17 tree = Tree(training_mode="training", alpha=2.0)
---> 18 print(loss_func(tree, 2.0))  # <- will throw error with jax transformations.

    [... skipping hidden 5 frame]

File /opt/homebrew/Caskroom/miniconda/base/envs/dev-jax/lib/python3.12/site-packages/jax/_src/dispatch.py:281, in check_arg(arg)
    279 def check_arg(arg: Any):
    280   if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
--> 281     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
    282                     "JAX type.")

TypeError: Argument 'training' of type <class 'str'> is not a valid JAX type.

The error occurs because jax transformations accept only numerical values. The next example demonstrates how to modify the class to mask the str type.

[ ]:
@sp.autoinit
class Tree(sp.TreeClass):
    training_mode: str = mask_field()  # hide the field from jax transformations
    alpha: float

    def __call__(self, x):
        if self.training_mode == "training":
            return x**self.alpha
        return x


tree = Tree(training_mode="training", alpha=2.0)
print(loss_func(tree, 2.0))

[3] Validator fields#

The following example shows how to use sepes.field to validate input data. A validator function checks whether the input is valid and raises an exception if it is not. This example is inspired by an example in the official Python docs.

Range+Type validator#

[ ]:
import jax
import sepes as sp


# you can use any function
@sp.autoinit
class Range(sp.TreeClass):
    min: int | float = -float("inf")
    max: int | float = float("inf")

    def __call__(self, x):
        if not (self.min <= x <= self.max):
            raise ValueError(f"{x} not in range [{self.min}, {self.max}]")
        return x


@sp.autoinit
class IsInstance(sp.TreeClass):
    klass: type | tuple[type, ...]

    def __call__(self, x):
        if not isinstance(x, self.klass):
            raise TypeError(f"{x} not an instance of {self.klass}")
        return x


@sp.autoinit
class Foo(sp.TreeClass):
    # allow in_dim to be an integer between [1,100]
    in_dim: int = sp.field(on_setattr=[IsInstance(int), Range(1, 100)])


tree = Foo(1)
# no error

try:
    tree = Foo(0)
except ValueError as e:
    print(e)

try:
    tree = Foo(1.0)
except TypeError as e:
    print(e)
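
The callbacks in on_setattr=[IsInstance(int), Range(1, 100)] are listed in a specific order. Assuming each callback receives the value returned by the previous one (the validator-then-converter pairing in the array example of this notebook suggests so), the chain behaves like the plain-Python sketch below. apply_chain, is_int, in_range, and first_error are hypothetical helpers, not sepes functions.

```python
from functools import reduce
from typing import Any, Callable, Optional, Sequence


def apply_chain(funcs: Sequence[Callable[[Any], Any]], value: Any) -> Any:
    """Feed `value` through `funcs` from left to right (hypothetical helper)."""
    return reduce(lambda acc, func: func(acc), funcs, value)


def is_int(x):
    if not isinstance(x, int):
        raise TypeError(f"{x!r} is not an int")
    return x


def in_range(x, lo=1, hi=100):
    if not (lo <= x <= hi):
        raise ValueError(f"{x} not in range [{lo}, {hi}]")
    return x


def first_error(funcs, value) -> Optional[str]:
    """Return the name of the exception the chain raises, or None."""
    try:
        apply_chain(funcs, value)
    except (TypeError, ValueError) as e:
        return type(e).__name__
    return None


print(apply_chain([is_int, in_range], 5))    # 5
# 0.5 fails both checks, so the reported error depends on the order.
print(first_error([is_int, in_range], 0.5))  # TypeError
print(first_error([in_range, is_int], 0.5))  # ValueError
```

Putting the type check first means the range check can assume it receives a comparable number.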

Array validator#

[ ]:
import sepes as sp
from typing import Any
import jax
import jax.numpy as jnp


class ArrayValidator(sp.TreeClass):
    """Validate shape and dtype of input array.

    Args:
        shape: Expected shape of array. available values are int, None, ...
            use int for fixed size, None for any size, and ... for any number
            of dimensions. for example (..., 1) allows any number of dimensions
            with the last dimension being 1. (1, ..., 1) allows any number of
            dimensions with the first and last dimensions being 1.
        dtype: Expected dtype of array.

    Example:
        >>> x = jnp.ones((5, 5))
        >>> # any number of dimensions with last dim=5
        >>> shape = (..., 5)
        >>> dtype = jnp.float32
        >>> validator = ArrayValidator(shape, dtype)
        >>> validator(x)  # no error

        >>> # must be 2 dimensions with first dim unconstrained and last dim=5
        >>> shape = (None, 5)
        >>> validator = ArrayValidator(shape, dtype)
        >>> validator(x)  # no error
    """

    def __init__(self, shape, dtype):
        if shape.count(...) > 1:
            raise ValueError("Only one ellipsis allowed")

        for si in shape:
            if not isinstance(si, (int, type(...), type(None))):
                raise TypeError(f"Expected int, None, or ..., got {si}")

        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        if not (hasattr(x, "shape") and hasattr(x, "dtype")):
            raise TypeError(f"Expected array with shape {self.shape}, got {x}")

        shape = list(self.shape)
        array_shape = list(x.shape)
        array_dtype = x.dtype

        if self.dtype is not None and array_dtype != self.dtype:
            raise TypeError(f"Dtype mismatch, {array_dtype=} != {self.dtype=}")

        if ... in shape:
            index = shape.index(...)
            shape = (
                shape[:index]
                + [None] * (len(array_shape) - len(shape) + 1)
                + shape[index + 1 :]
            )

        if len(shape) != len(array_shape):
            raise ValueError(f"{len(shape)=} != {len(array_shape)=}")

        for i, (li, ri) in enumerate(zip(shape, array_shape)):
            if li is None:
                continue
            if li != ri:
                raise ValueError(f"Size mismatch, {li} != {ri} at dimension {i}")
        return x


# any number of dimensions with first dim=3 and last dim=6
shape = (3, ..., 6)
# dtype must be float32
dtype = jnp.float32

validator = ArrayValidator(shape=shape, dtype=dtype)

# convert to half precision from float32
converter = lambda x: x.astype(jnp.float16)


@sp.autoinit
class Tree(sp.TreeClass):
    array: jax.Array = sp.field(on_setattr=[validator, converter])


x = jnp.ones([3, 1, 2, 6])
tree = Tree(array=x)


try:
    y = jnp.ones([1, 1, 2, 3])
    tree = Tree(array=y)
except ValueError as e:
    print(e, "\n")
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Size mismatch, 3 != 1 at dimension 0

try:
    z = x.astype(jnp.float16)
    tree = Tree(array=z)
except TypeError as e:
    print(e)
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>
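
The ellipsis handling inside ArrayValidator.__call__ can be checked in isolation. The sketch below extracts that step into two standalone functions, expand_ellipsis and matches. Both names are hypothetical and only mirror the logic shown above. No array library is needed because only shape tuples are compared.

```python
def expand_ellipsis(pattern, ndim):
    """Replace a single `...` in `pattern` with `None`s so it has `ndim` entries."""
    pattern = list(pattern)
    if ... in pattern:
        index = pattern.index(...)
        # `...` stands for however many dimensions are left over
        fill = [None] * (ndim - len(pattern) + 1)
        pattern = pattern[:index] + fill + pattern[index + 1 :]
    return pattern


def matches(pattern, shape):
    """True if `shape` satisfies `pattern` (ints fixed, None free, ... variadic)."""
    expanded = expand_ellipsis(pattern, len(shape))
    if len(expanded) != len(shape):
        return False
    return all(p is None or p == s for p, s in zip(expanded, shape))


print(expand_ellipsis((3, ..., 6), 4))     # [3, None, None, 6]
print(matches((3, ..., 6), (3, 1, 2, 6)))  # True
print(matches((3, ..., 6), (1, 1, 2, 3)))  # False
print(matches((3, ..., 6), (3, 6)))        # True, `...` may match zero dims
print(matches((None, 5), (5, 5)))          # True
```

When the array has fewer dimensions than the fixed entries of the pattern, the fill list is empty and the length check rejects the shape.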

[4] Parameterization field#

In this example, the field value is parameterized using on_getattr: the stored array is passed through a function each time it is accessed.

[ ]:
import sepes as sp
import jax
import jax.numpy as jnp


def symmetric(array: jax.Array) -> jax.Array:
    triangle = jnp.triu(array)  # upper triangle
    return triangle + triangle.transpose(-1, -2)


@sp.autoinit
class Tree(sp.TreeClass):
    symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])


tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
print(tree.symmetric_matrix)
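
Because triu keeps the main diagonal, the symmetric function above adds the diagonal twice. The NumPy sketch below reproduces that computation (np.triu follows the same convention as jnp.triu) and shows a variant that keeps the diagonal unchanged. symmetric_keep_diag is a hypothetical alternative, not something the notebook defines.

```python
import numpy as np


def symmetric(array):
    """Same computation as the notebook's function, written with NumPy."""
    triangle = np.triu(array)  # upper triangle, diagonal included
    return triangle + np.swapaxes(triangle, -1, -2)


def symmetric_keep_diag(array):
    """Hypothetical variant: mirror only the strictly upper triangle."""
    upper = np.triu(array)
    strict = np.triu(array, k=1)  # k=1 drops the diagonal
    return upper + np.swapaxes(strict, -1, -2)


x = np.arange(9).reshape(3, 3)
print(symmetric(x))
# [[ 0  1  2]
#  [ 1  8  5]
#  [ 2  5 16]]
print(symmetric_keep_diag(x))
# [[0 1 2]
#  [1 4 5]
#  [2 5 8]]
```

Either function yields a symmetric matrix on every access. Which one is appropriate depends on whether the stored diagonal is meant to be used as-is.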