💫 Transformations#

This section introduces common function transformations that are used in conjunction with pytrees. Examples include a function transformation that wraps JAX transforms and one that wraps NumPy functions.

[1]:
!pip install sepes

[1] Broadcasting transformations#

Using bcmap to apply a function over pytree leaves with automatic broadcasting for scalar arguments.

bcmap + numpy#

In this recipe, numpy functions will operate directly on TreeClass instances.

[2]:
import sepes as sp
import jax
import jax.numpy as jnp


@sp.leafwise  # enable math operations on leaves
@sp.autoinit  # generate __init__ from type annotations
class Tree(sp.TreeClass):
    # example pytree mixing a Python int, a tuple of floats and a JAX array;
    # every leaf here is a numeric, JAX-compatible value
    a: int = 1
    b: tuple[float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])


tree = Tree()

# `sp.bcmap` lifts `jnp.where` so it is applied leaf-by-leaf across pytree
# arguments, broadcasting the scalar `0` to every leaf
tree_where = sp.bcmap(jnp.where)

# leaves greater than 2 become `leaf + 100`; every other leaf becomes 0
is_large = tree > 2
shifted = tree + 100
print(tree_where(is_large, shifted, 0))
Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

bcmap on pytrees with non-jaxtype leaves

In case the tree has some non-jaxtype leaves, the above will fail, but we can use tree_mask/tree_unmask to fix it.

[3]:
# in case the tree has some non-jaxtype leaves
# the above will fail, but we can use `tree_mask`/`tree_unmask` to fix it
import sepes as sp
import jax  # needed for the `jax.Array` annotation below; keeps this cell self-contained
import jax.numpy as jnp
from typing import Callable


@sp.leafwise  # enable math operations on leaves
@sp.autoinit  # generate __init__ from type annotations
class Tree(sp.TreeClass):
    # numeric leaves as in the previous example, plus two non-jaxtype fields
    a: float = 1.0
    b: tuple[float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])
    name: str = "tree"  # non-jaxtype
    func: Callable = lambda x: x  # non-jaxtype


tree = Tree()

# building the lifted function cannot fail; only applying it to `tree` can
tree_where = sp.bcmap(jnp.where)

try:
    # for values > 2, add 100, else set to 0 -- `tree > 2` raises here because
    # the `name` string leaf cannot be compared with an int
    print(tree_where(tree > 2, tree + 100, 0))
except TypeError as e:
    print("bcmap fail", e)
    # hide the non-jaxtype leaves, run the same computation on the masked
    # tree, then put the original non-jaxtype leaves back
    masked = sp.tree_mask(tree)
    masked = tree_where(masked > 2, masked + 100, 0)
    print(sp.tree_unmask(masked))
bcmap fail '>' not supported between instances of 'str' and 'int'
Tree(a=0.0, b=(0.0, 103.0), c=[104. 105. 106.], name=tree, func=<lambda>(x))

bcmap + configs#

The next example shows how to use sepes.bcmap (sp.bcmap) to loop over a configuration dictionary that defines the creation of simple linear layers.

[4]:
import sepes as sp
import jax
import jax.numpy as jnp  # used by `jnp.zeros` below; keeps this cell self-contained


class Linear(sp.TreeClass):
    """Affine layer computing ``x @ weight + bias``.

    ``weight`` is drawn with ``jax.random.normal`` using shape
    ``(in_dim, out_dim)``; ``bias`` starts as zeros of shape ``(out_dim,)``.
    """

    def __init__(self, in_dim: int, out_dim: int, *, key: jax.Array):
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jnp.zeros((out_dim,))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight + self.bias


# per-layer settings: `in_dim` and `key` hold one entry per layer (4 layers),
# while `out_dim` is a single value shared by every layer
config = dict(
    in_dim=[1, 2, 3, 4],
    out_dim=1,
    key=[*jax.random.split(jax.random.PRNGKey(0), 4)],
)


# `bcmap` turns a function written for single inputs into one that accepts
# arbitrary pytree inputs. It maps over the structure of the first argument
# (here a list of 4 entries), and a single input such as `out_dim` is
# broadcast to match that structure.


@sp.bcmap
def build_layer(in_dim, out_dim, *, key: jax.Array):
    return Linear(in_dim, out_dim, key=key)


build_layer(config["in_dim"], config["out_dim"], key=config["key"])
[4]:
[Linear(
   weight=f32[1,1](μ=0.31, σ=0.00, ∈[0.31,0.31]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[2,1](μ=-1.27, σ=0.33, ∈[-1.59,-0.94]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[3,1](μ=0.24, σ=0.53, ∈[-0.48,0.77]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[4,1](μ=-0.28, σ=0.21, ∈[-0.64,-0.08]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 )]

[2] Masked transformations#

As an alternative to using sp.tree_unmask on pytrees before calling the function (as seen throughout the training examples and recipes), another approach is to wrap the transformation itself rather than the pytrees (e.g. jit) to make the masking/unmasking automatic. However, this approach incurs more overhead than applying sp.tree_unmask before the function call.

The following examples demonstrate how to wrap jit, vmap, make_jaxpr, lax.scan, and eval_shape.

[3]:
import sepes as sp
import functools as ft
import jax
import jax.random as jr
import jax.numpy as jnp
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec

# type variables for annotating that the wrappers below mirror the call
# signature (P) and return type (T) of the transformation they wrap
T = TypeVar("T")
P = ParamSpec("P")


def automask(jax_transform: Callable[P, T]) -> Callable[P, T]:
    """Wrap a function-to-function JAX transformation so it tolerates non-JAX leaves.

    Meant for transformations that take a function and return a function,
    e.g. ``jax.jit``, ``jax.vmap``, ``jax.grad`` or ``jax.make_jaxpr``;
    call it as ``automask(jax.jit)(func, **transformation_kwargs)``.

    Arguments are masked with ``sp.tree_mask`` before crossing into the
    transformation and unmasked inside it; ``func``'s result is masked inside
    the transformation and unmasked again before being returned.
    """

    def out_transform(func, **transformation_kwargs):
        # runs under `jax_transform`: sees the real leaves, returns masked ones
        def jax_boundary(*args, **kwargs):
            real_args, real_kwargs = sp.tree_unmask((args, kwargs))
            result = func(*real_args, **real_kwargs)
            return sp.tree_mask(result)

        # apply the transformation (with any user-supplied options) to the boundary
        jax_boundary = jax_transform(jax_boundary, **transformation_kwargs)

        @ft.wraps(func)
        def outer_wrapper(*args, **kwargs):
            # hide non-JAX leaves from the transformation, then restore them
            # in whatever it hands back
            masked_args, masked_kwargs = sp.tree_mask((args, kwargs))
            return sp.tree_unmask(jax_boundary(*masked_args, **masked_kwargs))

        return outer_wrapper

    return out_transform


def inline_automask(jax_transform: Callable[P, T]) -> Callable[P, T]:
    """Wrap a JAX transformation called as ``jax_transform(func, *args, **kwargs)``.

    Meant for transformations that take the function *and* its arguments in a
    single call and return JAX values, e.g. ``jax.lax.scan`` or
    ``jax.eval_shape``. Non-JAX leaves in the arguments are masked with
    ``sp.tree_mask`` before the call and unmasked inside ``func``; ``func``'s
    result is masked inside the transformation, and the final output is
    unmasked before being returned.
    """

    # keep the wrapped transformation's name/docstring, mirroring how
    # `automask` applies `ft.wraps` to the user function
    @ft.wraps(jax_transform)
    def outer_wrapper(func, *args, **kwargs):
        # NOTE(review): kwargs are masked as well and then forwarded to
        # `jax_transform` itself, so keyword options meant for the
        # transformation (e.g. `jax.lax.scan`'s `length`/`reverse`) may end up
        # masked too -- confirm before passing such options.
        args, kwargs = sp.tree_mask((args, kwargs))

        def func_masked(*args, **kwargs):
            # inside the transformation: restore the real leaves, call the
            # user function, and mask any non-JAX leaves of its result
            args, kwargs = sp.tree_unmask((args, kwargs))
            return sp.tree_mask(func(*args, **kwargs))

        output = jax_transform(func_masked, *args, **kwargs)
        return sp.tree_unmask(output)

    return outer_wrapper

automask(jit)#

[4]:
x = jnp.ones([5, 5])

# `params` mixes JAX arrays with a plain string, which `jax.jit` rejects
params = dict(w1=jnp.ones([5, 5]), w2=jnp.ones([5, 5]), name="layer")


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Return ``tanh(x @ w1) @ w2``; the ``name`` entry is not used."""
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


try:
    forward_jit = jax.jit(forward)
    print(forward_jit(params, x))
except TypeError as e:
    print("`jit error`:", e)
    # now with `automask` the function can accept non-jax types (e.g. string)
    forward_jit = automask(jax.jit)(forward)
    print("\nUsing automask:")
    print(f"{forward_jit(params, x)=}")
`jit error`: Argument 'layer' of type <class 'str'> is not a valid JAX type

Using automask:
forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
       [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
       [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
       [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
       [4.999546, 4.999546, 4.999546, 4.999546, 4.999546]], dtype=float32)

automask(vmap)#

[6]:
def make_params(key: jax.Array):
    # `key` arrives as float32 (see `keys` below); cast it back to uint32
    # before using it as a PRNG key.
    # NOTE(review): float32 cannot represent every uint32 exactly, so the
    # float round-trip may alter the key bits -- fine for a demo, but confirm
    # before reusing this pattern where exact keys matter.
    k1, k2 = jax.random.split(key.astype(jnp.uint32))
    w1 = jr.uniform(k1, (5, 5))
    w2 = jr.uniform(k2, (5, 5))
    return dict(w1=w1, w2=w2, name="layer")


# keys are stored as float32, presumably so `sp.tree_mask` (used by
# `automask`) keeps them as mapped leaves instead of masking them -- TODO confirm
keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)

try:
    params = jax.vmap(make_params)(keys)
    print(params)
except TypeError as e:
    # plain `vmap` rejects the string leaf in the returned dict
    print("`vmap error`:", e)
    # `automask` hides the string from `vmap` and restores it in the result
    params = automask(jax.vmap)(make_params)(keys)
    print("\nUsing automask:")
    print(sp.tree_repr(params))
`vmap error`: Output from batched function 'layer' with type <class 'str'> is not a valid JAX type

Using automask:
dict(
  name=layer,
  w1=f32[4,5,5](μ=0.50, σ=0.28, ∈[0.02,1.00]),
  w2=f32[4,5,5](μ=0.46, σ=0.27, ∈[0.01,0.99])
)

automask(make_jaxpr)#

[8]:
def make_params(key: jax.Array):
    # `key` arrives as float32; cast back to uint32 before using it as a PRNG key
    k1, k2 = jax.random.split(key.astype(jnp.uint32))
    return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name="layer")


keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)
# four stacked layers: w1/w2 become (4, 5, 5) while `name` stays a string
params = automask(jax.vmap)(make_params)(keys)


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Return ``tanh(x @ w1) @ w2``; the ``name`` entry is not used."""
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


try:
    jax.make_jaxpr(forward)(params, jnp.ones((10, 5)))
except TypeError as error:
    print(f"`jax.make_jaxpr` failed: {error=}")
    # balanced backticks, matching the other "Using `...`:" labels
    print("\nUsing `automask`:")
    print(automask(jax.make_jaxpr)(forward)(params, jnp.ones((10, 5))))
`jax.make_jaxpr` failed: error=TypeError("Argument 'layer' of type <class 'str'> is not a valid JAX type")

Using `automask:
{ lambda ; a:f32[4,5,5] b:f32[4,5,5] c:f32[10,5]. let
    d:f32[10,4,5] = dot_general[
      dimension_numbers=(([1], [1]), ([], []))
      preferred_element_type=float32
    ] c a
    e:f32[4,10,5] = transpose[permutation=(1, 0, 2)] d
    f:f32[4,10,5] = tanh e
    g:f32[4,10,5] = dot_general[
      dimension_numbers=(([2], [1]), ([0], [0]))
      preferred_element_type=float32
    ] f b
  in (g,) }

inline_automask(scan)#

[14]:
def make_params(key: jax.Array):
    # `key` arrives as float32; cast back to uint32 before using it as a PRNG key
    k1, k2 = jax.random.split(key.astype(jnp.uint32))
    return dict(w1=jr.uniform(k1, (1, 3)), w2=jr.uniform(k2, (3, 1)), name="layer")


keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)
# four stacked layers: w1 is (4, 1, 3), w2 is (4, 3, 1), `name` stays a string
params = automask(jax.vmap)(make_params)(keys)


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Return ``tanh(x @ w1) @ w2``; the ``name`` entry is not used."""
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def scan_func(params, x):
    """Scan body: carry ``params`` through unchanged and emit ``forward(params, x)``."""
    # `params` contains a non-jax leaf (the "layer" string), which plain
    # `jax.lax.scan` rejects.
    # NOTE(review): `forward` is not vmapped over the 4 stacked layers, so the
    # matmuls broadcast across the layer axis and each step emits a (4, 4, 1)
    # array mixing layers -- confirm this cross-layer result is intended.
    output = forward(params, x)
    return params, output


try:
    jax.lax.scan(scan_func, params, jnp.ones((1, 1)))
except TypeError as error:
    print(f"`jax.lax.scan` Failed: {error=}")
    print("\nUsing `inline_automask`:")
    print(inline_automask(jax.lax.scan)(scan_func, params, jnp.ones((1, 1))))
`jax.lax.scan` Failed: error=TypeError("Value 'layer' with type <class 'str'> is not a valid JAX type")

Using `inline_automask`:
({'name': 'layer', 'w1': Array([[[0.6022109 , 0.06545091, 0.7613505 ]],

       [[0.33657324, 0.3744743 , 0.12130237]],

       [[0.51550114, 0.17686307, 0.6407058 ]],

       [[0.9101157 , 0.9690273 , 0.36771262]]], dtype=float32), 'w2': Array([[[0.2678218 ],
        [0.3963921 ],
        [0.7078583 ]],

       [[0.18808937],
        [0.8475715 ],
        [0.04241407]],

       [[0.74411213],
        [0.6318574 ],
        [0.58551705]],

       [[0.34456158],
        [0.5347049 ],
        [0.3992592 ]]], dtype=float32)}, Array([[[[0.62451595],
         [0.3141999 ],
         [0.59660065],
         [0.7389193 ]],

        [[0.1839285 ],
         [0.36948383],
         [0.26153624],
         [0.7847949 ]],

        [[0.81791794],
         [0.53822035],
         [0.7945141 ],
         [1.2155443 ]],

        [[0.4768083 ],
         [0.35134616],
         [0.48272693],
         [0.78913575]]]], dtype=float32))

inline_automask(eval_shape)#

[18]:
def make_params(key: jax.Array):
    # ensure a uint32 key before splitting it
    as_uint = key.astype(jnp.uint32)
    k1, k2 = jax.random.split(as_uint)
    layer = {"w1": jr.uniform(k1, (1, 3)), "w2": jr.uniform(k2, (3, 1))}
    layer["name"] = "layer"
    return layer


params = make_params(jr.PRNGKey(0))


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Return ``tanh(x @ w1) @ w2``; the ``name`` entry is not used."""
    hidden = jnp.tanh(x @ params["w1"])
    return hidden @ params["w2"]


batch = jnp.ones((10, 1))

try:
    # plain `eval_shape` rejects the string leaf in `params`
    jax.eval_shape(forward, params, batch)
except TypeError as error:
    print(f"`jax.eval_shape` Failed: {error=}")
    print("\nUsing `inline_automask`:")
    print(inline_automask(jax.eval_shape)(forward, params, batch))
`jax.eval_shape` Failed: error=TypeError("Argument 'layer' of type <class 'str'> is not a valid JAX type")

Using `inline_automask`:
ShapeDtypeStruct(shape=(10, 1), dtype=float32)