💫 Transformations#
This section introduces common function transformations that are used in conjunction with pytrees. Examples include function transformations that wrap JAX transforms and function transformations that wrap NumPy functions.
[1]:
!pip install sepes
[1] Broadcasting transformations#
Using bcmap to apply a function over pytree leaves with automatic broadcasting for scalar arguments.
bcmap + numpy#
In this recipe, numpy functions will operate directly on TreeClass instances.
[2]:
import sepes as sp
import jax
import jax.numpy as jnp
@sp.leafwise  # enable math operations on leaves
@sp.autoinit  # generate __init__ from type annotations
class Tree(sp.TreeClass):
    """Toy pytree with an int leaf, a tuple of float leaves and an array leaf."""

    a: int = 1
    # variable-length tuple of floats (``tuple[float]`` would mean a 1-tuple)
    b: tuple[float, ...] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])
tree = Tree()
# `bcmap` lifts `jnp.where` so it maps over the pytree leaves, broadcasting
# the scalar argument (`0`) across the tree structure
tree_where = sp.bcmap(jnp.where)
# keep `leaf + 100` wherever the leaf exceeds 2, otherwise substitute 0
condition = tree > 2
shifted = tree + 100
print(tree_where(condition, shifted, 0))
Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])
bcmap on pytrees with non-jaxtype leaves
If the tree has some non-jaxtype leaves, the above will fail, but we can use tree_mask/tree_unmask to fix it.
[3]:
# in case the tree has some non-jaxtype leaves
# the above will fail, but we can use `tree_mask`/`tree_unmask` to fix it
import sepes as sp
import jax.numpy as jnp
from typing import Callable
@sp.leafwise  # enable math operations on leaves
@sp.autoinit  # generate __init__ from type annotations
class Tree(sp.TreeClass):
    """Pytree mixing numeric leaves with non-jaxtype leaves (a str and a callable)."""

    a: float = 1.0
    # variable-length tuple of floats (``tuple[float]`` would mean a 1-tuple)
    b: tuple[float, ...] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])
    name: str = "tree"  # non-jaxtype
    func: Callable = lambda x: x  # non-jaxtype
tree = Tree()
# make `where` work with arbitrary pytrees. It is created outside the `try`
# so that it is still bound when it is reused on the masked tree below.
tree_where = sp.bcmap(jnp.where)
try:
    # for values > 2, add 100, else set to 0.
    # The leafwise `>` raises on the non-jaxtype `name` leaf (a str).
    print(tree_where(tree > 2, tree + 100, 0))
except TypeError as e:
    print("bcmap fail", e)
# now we can use `tree_mask`/`tree_unmask` to fix it:
# `tree_mask` masks the non-jaxtype leaves (`name`, `func`) so the operations
# no longer fail on them, and `tree_unmask` restores their original values.
masked_tree = sp.tree_mask(tree)
masked_tree = tree_where(masked_tree > 2, masked_tree + 100, 0)
unmasked_tree = sp.tree_unmask(masked_tree)
print(unmasked_tree)
bcmap fail '>' not supported between instances of 'str' and 'int'
Tree(a=0.0, b=(0.0, 103.0), c=[104. 105. 106.], name=tree, func=<lambda>(x))
bcmap + configs#
The next example shows how to use sp.bcmap to loop over a configuration dictionary that defines the creation of simple linear layers.
[4]:
import sepes as sp
import jax
class Linear(sp.TreeClass):
    """Affine layer computing ``x @ weight + bias``."""

    def __init__(self, in_dim: int, out_dim: int, *, key: jax.Array):
        # weight is drawn from a standard normal with shape (in_dim, out_dim);
        # bias starts at zero with shape (out_dim,)
        weight_shape = (in_dim, out_dim)
        self.weight = jax.random.normal(key, weight_shape)
        self.bias = jnp.zeros((out_dim,))

    def __call__(self, x: jax.Array) -> jax.Array:
        projected = x @ self.weight
        return projected + self.bias
# per-layer settings: list entries are consumed one per layer, while the
# scalar `out_dim` is broadcast to every layer
config = dict(
    in_dim=[1, 2, 3, 4],  # each layer gets a different input dimension
    out_dim=1,  # broadcast to all layers
    key=list(jax.random.split(jax.random.PRNGKey(0), 4)),  # one key per layer
)


# `bcmap` transforms a function that takes single inputs into a function that
# accepts arbitrary pytree inputs. A non-tree input is broadcast to match the
# tree structure of the first argument (in our example, a list of 4 inputs).
@sp.bcmap
def build_layer(in_dim, out_dim, *, key: jax.Array):
    return Linear(in_dim, out_dim, key=key)


build_layer(config["in_dim"], config["out_dim"], key=config["key"])
[4]:
[Linear(
weight=f32[1,1](μ=0.31, σ=0.00, ∈[0.31,0.31]),
bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
),
Linear(
weight=f32[2,1](μ=-1.27, σ=0.33, ∈[-1.59,-0.94]),
bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
),
Linear(
weight=f32[3,1](μ=0.24, σ=0.53, ∈[-0.48,0.77]),
bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
),
Linear(
weight=f32[4,1](μ=-0.28, σ=0.21, ∈[-0.64,-0.08]),
bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
)]
[2] Masked transformations#
As an alternative to calling sp.tree_unmask on pytrees before calling the function (as seen throughout the training examples and recipes), another approach is to wrap the transformation itself rather than the pytrees (e.g. jit) to make the masking/unmasking automatic. However, this approach incurs more overhead than applying sp.tree_unmask before the function call.
The following examples demonstrate how to wrap jit, vmap, make_jaxpr, lax.scan, and eval_shape.
[3]:
import sepes as sp
import functools as ft
import jax
import jax.random as jr
import jax.numpy as jnp
from typing import Any
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec
T = TypeVar("T")
P = ParamSpec("P")


def automask(jax_transform: Callable[P, T]) -> Callable[P, T]:
    """Enable jax transformations to accept non-jax types. e.g. ``jax.grad``.

    Works with transformations that take a function (plus optional
    configuration) and return a transformed function, e.g. ``jax.jit``,
    ``jax.vmap``, ``jax.grad`` or ``jax.make_jaxpr``.

    Args:
        jax_transform: the transformation to wrap.

    Returns:
        A replacement for ``jax_transform`` with signature
        ``(func, *transformation_args, **transformation_kwargs)``. The
        function it returns masks its inputs with ``sp.tree_mask`` before
        entering the transformation and unmasks them inside. It also masks
        ``func``'s outputs inside the transformation and unmasks them on the
        way out.
    """

    def out_transform(func, *transformation_args, **transformation_kwargs):
        def jax_boundary(*args, **kwargs):
            # inside the transformation: restore the original (unmasked) values
            args, kwargs = sp.tree_unmask((args, kwargs))
            # mask the outputs so that non-jax outputs can cross the boundary
            return sp.tree_mask(func(*args, **kwargs))

        # Positional configuration (e.g. ``in_axes`` for ``jax.vmap`` or
        # ``argnums`` for ``jax.grad``) must follow the function argument, so
        # it cannot be pre-bound with ``functools.partial``.
        transformed = jax_transform(
            jax_boundary, *transformation_args, **transformation_kwargs
        )

        @ft.wraps(func)
        def outer_wrapper(*args, **kwargs):
            # outside the transformation: hide non-jax leaves from jax
            args, kwargs = sp.tree_mask((args, kwargs))
            output = transformed(*args, **kwargs)
            return sp.tree_unmask(output)

        return outer_wrapper

    return out_transform
def inline_automask(jax_transform: Callable[P, T]) -> Callable[P, T]:
    """Enable jax transformations to accept non-jax types e.g. ``jax.lax.scan``.

    For transformations that take a function *and* its arguments, and return
    values directly (e.g. ``jax.lax.scan``, ``jax.eval_shape``).

    Args:
        jax_transform: the transformation to wrap.

    Returns:
        A callable ``(func, *args, **kwargs)``. It masks ``args``/``kwargs``
        with ``sp.tree_mask`` and calls ``jax_transform`` on a version of
        ``func`` that unmasks its inputs and masks its outputs. Finally, it
        unmasks the result.

    NOTE(review): ``kwargs`` are masked and then forwarded to ``jax_transform``
    itself. Transformation options passed by keyword (e.g. ``length=`` for
    ``jax.lax.scan``) may therefore be masked as well, depending on
    ``sp.tree_mask``'s default predicate. Confirm this is intended.
    """

    def outer_wrapper(func, *args, **kwargs):
        def func_masked(*inner_args, **inner_kwargs):
            # inside the transformation: recover the real values, run the
            # user function, then hide any non-jax outputs again
            inner_args, inner_kwargs = sp.tree_unmask((inner_args, inner_kwargs))
            return sp.tree_mask(func(*inner_args, **inner_kwargs))

        masked_args, masked_kwargs = sp.tree_mask((args, kwargs))
        result = jax_transform(func_masked, *masked_args, **masked_kwargs)
        return sp.tree_unmask(result)

    return outer_wrapper
automask(jit)#
[4]:
x = jnp.ones([5, 5])
params = dict(w1=jnp.ones([5, 5]), w2=jnp.ones([5, 5]), name="layer")


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Compute ``tanh(x @ w1) @ w2``; other entries of ``params`` are unused."""
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


try:
    # plain `jax.jit` rejects the string leaf ("layer") in `params`
    forward_jit = jax.jit(forward)
    print(forward_jit(params, x))
except TypeError as e:
    print("`jit error`:", e)

# now with `automask` the function can accept non-jax types (e.g. string)
forward_jit = automask(jax.jit)(forward)
print("\nUsing automask:")
print(f"{forward_jit(params, x)=}")
`jit error`: Argument 'layer' of type <class 'str'> is not a valid JAX type
Using automask:
forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],
[4.999546, 4.999546, 4.999546, 4.999546, 4.999546]], dtype=float32)
automask(vmap)#
[6]:
def make_params(key: jax.Array):
    """Build one layer's params from a PRNG key carried as float32 bits.

    ``bitcast_convert_type`` restores the exact uint32 key. ``astype`` would
    round every key word above 2**24 (float32 has a 24-bit significand) and
    silently change the key.
    """
    k1, k2 = jax.random.split(jax.lax.bitcast_convert_type(key, jnp.uint32))
    return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name="layer")


# The keys are passed as float32 rather than uint32, presumably so that
# `automask` does not mask them and `vmap` can map over them. A bit-cast keeps
# the key bits intact.
keys = jax.lax.bitcast_convert_type(jr.split(jr.PRNGKey(0), 4), jnp.float32)
try:
    params = jax.vmap(make_params)(keys)
    print(params)
except TypeError as e:
    print("`vmap error`:", e)

# now with `automask` the function can accept non-jax types (e.g. string)
params = automask(jax.vmap)(make_params)(keys)
print("\nUsing automask:")
print(sp.tree_repr(params))
`vmap error`: Output from batched function 'layer' with type <class 'str'> is not a valid JAX type
Using automask:
dict(
name=layer,
w1=f32[4,5,5](μ=0.50, σ=0.28, ∈[0.02,1.00]),
w2=f32[4,5,5](μ=0.46, σ=0.27, ∈[0.01,0.99])
)
automask(make_jaxpr)#
[8]:
def make_params(key: jax.Array):
    """Build one layer's params from a PRNG key carried as float32 bits.

    ``bitcast_convert_type`` restores the exact uint32 key. ``astype`` would
    round every key word above 2**24 and silently change the key.
    """
    k1, k2 = jax.random.split(jax.lax.bitcast_convert_type(key, jnp.uint32))
    return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name="layer")


# float32 keys, presumably so `automask` leaves them unmasked for `vmap`;
# a bit-cast keeps the key bits intact
keys = jax.lax.bitcast_convert_type(jr.split(jr.PRNGKey(0), 4), jnp.float32)
params = automask(jax.vmap)(make_params)(keys)


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Compute ``tanh(x @ w1) @ w2``; other entries of ``params`` are unused."""
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


try:
    # `params` still holds the "layer" string, which `make_jaxpr` rejects
    jax.make_jaxpr(forward)(params, jnp.ones((10, 5)))
except TypeError as error:
    print(f"`jax.make_jaxpr` failed: {error=}")

print("\nUsing `automask:")
print(automask(jax.make_jaxpr)(forward)(params, jnp.ones((10, 5))))
`jax.make_jaxpr` failed: error=TypeError("Argument 'layer' of type <class 'str'> is not a valid JAX type")
Using `automask:
{ lambda ; a:f32[4,5,5] b:f32[4,5,5] c:f32[10,5]. let
d:f32[10,4,5] = dot_general[
dimension_numbers=(([1], [1]), ([], []))
preferred_element_type=float32
] c a
e:f32[4,10,5] = transpose[permutation=(1, 0, 2)] d
f:f32[4,10,5] = tanh e
g:f32[4,10,5] = dot_general[
dimension_numbers=(([2], [1]), ([0], [0]))
preferred_element_type=float32
] f b
in (g,) }
inline_automask(scan)#
[14]:
def make_params(key: jax.Array):
    """Build one small layer's params from a PRNG key carried as float32 bits.

    ``bitcast_convert_type`` restores the exact uint32 key. ``astype`` would
    round every key word above 2**24 and silently change the key.
    """
    k1, k2 = jax.random.split(jax.lax.bitcast_convert_type(key, jnp.uint32))
    return dict(w1=jr.uniform(k1, (1, 3)), w2=jr.uniform(k2, (3, 1)), name="layer")


# float32 keys, presumably so `automask` leaves them unmasked for `vmap`;
# a bit-cast keeps the key bits intact
keys = jax.lax.bitcast_convert_type(jr.split(jr.PRNGKey(0), 4), jnp.float32)
params = automask(jax.vmap)(make_params)(keys)


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Compute ``tanh(x @ w1) @ w2``; other entries of ``params`` are unused."""
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def scan_func(params, x):
    # `params` contains non-jax types (the "layer" name string); it is
    # threaded through unchanged as the scan carry
    output = forward(params, x)
    return params, output


try:
    jax.lax.scan(scan_func, params, jnp.ones((1, 1)))
except TypeError as error:
    print(f"`jax.lax.scan` Failed: {error=}")

print("\nUsing `inline_automask`:")
print(inline_automask(jax.lax.scan)(scan_func, params, jnp.ones((1, 1))))
`jax.lax.scan` Failed: error=TypeError("Value 'layer' with type <class 'str'> is not a valid JAX type")
Using `inline_automask`:
({'name': 'layer', 'w1': Array([[[0.6022109 , 0.06545091, 0.7613505 ]],
[[0.33657324, 0.3744743 , 0.12130237]],
[[0.51550114, 0.17686307, 0.6407058 ]],
[[0.9101157 , 0.9690273 , 0.36771262]]], dtype=float32), 'w2': Array([[[0.2678218 ],
[0.3963921 ],
[0.7078583 ]],
[[0.18808937],
[0.8475715 ],
[0.04241407]],
[[0.74411213],
[0.6318574 ],
[0.58551705]],
[[0.34456158],
[0.5347049 ],
[0.3992592 ]]], dtype=float32)}, Array([[[[0.62451595],
[0.3141999 ],
[0.59660065],
[0.7389193 ]],
[[0.1839285 ],
[0.36948383],
[0.26153624],
[0.7847949 ]],
[[0.81791794],
[0.53822035],
[0.7945141 ],
[1.2155443 ]],
[[0.4768083 ],
[0.35134616],
[0.48272693],
[0.78913575]]]], dtype=float32))
inline_automask(eval_shape)#
[18]:
def make_params(key: jax.Array):
    """Build a tiny two-matrix layer (plus a non-jax ``name``) from ``key``."""
    # derive two subkeys, one per weight matrix
    subkeys = jax.random.split(key.astype(jnp.uint32))
    w1 = jr.uniform(subkeys[0], (1, 3))
    w2 = jr.uniform(subkeys[1], (3, 1))
    return dict(w1=w1, w2=w2, name="layer")


params = make_params(jr.PRNGKey(0))


def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:
    """Compute ``tanh(x @ w1) @ w2``; other entries of ``params`` are unused."""
    hidden = jnp.tanh(x @ params["w1"])
    return hidden @ params["w2"]


sample_input = jnp.ones((10, 1))
try:
    # plain `eval_shape` rejects the "layer" string inside `params`
    jax.eval_shape(forward, params, sample_input)
except TypeError as error:
    print(f"`jax.eval_shape` Failed: {error=}")

print("\nUsing `inline_automask`:")
print(inline_automask(jax.eval_shape)(forward, params, sample_input))
`jax.eval_shape` Failed: error=TypeError("Argument 'layer' of type <class 'str'> is not a valid JAX type")
Using `inline_automask`:
ShapeDtypeStruct(shape=(10, 1), dtype=float32)