🗂️ Misc recipes
[1]:
!pip install sepes
This section introduces some miscellaneous recipes that are not covered in the previous sections.
[1] Lazy layers
In this example, we create a Linear layer whose weight shape is inferred from the shape of the input. Because this requires creating parameters (e.g., weight) after instance initialization, we use value_and_tree to obtain a new instance with the added parameters.
[2]:
import sepes as sp
from typing import Any
import jax
import jax.numpy as jnp
class LazyLinear(sp.TreeClass):
    def __init__(self, out_features: int):
        self.out_features = out_features

    def param(self, name: str, value: Any):
        # return the value if it exists, otherwise set it and return it
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, input: jax.Array) -> jax.Array:
        weight = self.param("weight", jnp.ones((self.out_features, input.shape[-1])))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return input @ weight.T + bias
input = jnp.ones([10, 1])
lazy = LazyLinear(out_features=1)
print(f"Layer before param is set:\t{lazy}")
# `value_and_tree` executes any mutating method in a functional way
_, material = sp.value_and_tree(lambda layer: layer(input))(lazy)
print(f"Layer after param is set:\t{material}")
# subsequent calls will not set the parameters again
material(input)
Layer before param is set: LazyLinear(out_features=1)
Layer after param is set: LazyLinear(out_features=1, weight=[[1.]], bias=[0.])
[2]:
Array([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]], dtype=float32)
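The "functional" execution of a mutating method can be pictured with a small plain-Python sketch. This is only a conceptual stand-in under stated assumptions, not how sepes implements value_and_tree: the helper `value_and_copy` and the `Counter` class are hypothetical names introduced for illustration. The idea shown is that the mutation happens on a copy, so the caller gets back both the return value and the updated object while the original stays untouched.

```python
import copy


def value_and_copy(func):
    """Hypothetical helper: run `func` on a deep copy of its first argument.

    Returns a pair (value, mutated_copy); the caller's object is not modified.
    """

    def wrapper(obj, *args, **kwargs):
        obj_copy = copy.deepcopy(obj)
        value = func(obj_copy, *args, **kwargs)
        return value, obj_copy

    return wrapper


class Counter:
    """Hypothetical class with a method that mutates its own state."""

    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count


counter = Counter()
value, new_counter = value_and_copy(lambda c: c.increment())(counter)

# the original is unchanged; the returned copy carries the mutation
print(counter.count, new_counter.count, value)
```

As in the LazyLinear example, the second element of the returned pair is the object to keep using afterwards.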
[3] Regularization
The following code showcases how to use the at functionality to select some leaves of a model, based on a boolean mask and/or a name condition, and apply weight regularization to them. For example, using the .at[...] functionality, the following can be achieved concisely:
Boolean-based mask
The array entries or leaves are selected using a tree of the same structure but with boolean (True/False) leaves. A True leaf marks where the operation is applied, while a False leaf indicates that the entry should not be touched.
[6]:
import sepes as sp
import jax.numpy as jnp
import jax
class Net(sp.TreeClass):
    def __init__(self):
        self.weight = jnp.array([-1, -2, -3, 1, 2, 3])
        self.bias = jnp.array([-1, 1])

def negative_entries_l2_loss(net: Net):
    return (
        # select all positive array entries
        # (`jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias)
        net.at[jax.tree_util.tree_map(lambda x: x > 0, net)]
        # set them to zero to exclude them from the loss
        .set(0)
        # select all leaves
        .at[...]
        # finally reduce with l2 loss
        .reduce(lambda x, y: x + jnp.mean(y**2), initializer=0)
    )
net = Net()
print(negative_entries_l2_loss(net))
2.8333335
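For comparison, the same masking idea can be sketched in plain JAX without the .at[...] syntax. This is a minimal sketch under the assumption that the parameters live in an ordinary dict (the `params` dict is a hypothetical stand-in for Net); it builds the boolean mask tree explicitly, zeroes the entries where the mask is True, and sums the per-leaf mean squares.

```python
import jax
import jax.numpy as jnp

# hypothetical stand-in for `Net`: a plain dict pytree with the same values
params = {
    "weight": jnp.array([-1.0, -2.0, -3.0, 1.0, 2.0, 3.0]),
    "bias": jnp.array([-1.0, 1.0]),
}

# boolean tree with the same structure: True marks the positive entries
mask = jax.tree_util.tree_map(lambda x: x > 0, params)

# zero out the entries where the mask is True, keep the rest untouched
negatives_only = jax.tree_util.tree_map(
    lambda x, m: jnp.where(m, 0.0, x), params, mask
)

# l2-style loss: sum over leaves of the mean squared entry
loss = sum(jnp.mean(leaf**2) for leaf in jax.tree_util.tree_leaves(negatives_only))
print(loss)
```

The explicit version makes the two steps visible (mask construction, then masked update), which the .at[mask].set(0) chain expresses in one expression.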
Name-based mask
Here, the mask is based on the path (name) of the leaf.
[7]:
import sepes as sp
import jax
import jax.numpy as jnp
import jax.random as jr
class Linear(sp.TreeClass):
    def __init__(self, in_features: int, out_features: int, key: jax.Array):
        self.weight = jr.normal(key=key, shape=(out_features, in_features))
        self.bias = jnp.zeros((out_features,))

class Net(sp.TreeClass):
    def __init__(self, key: jax.Array) -> None:
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.linear1 = Linear(1, 20, key=k1)
        self.linear2 = Linear(20, 20, key=k2)
        self.linear3 = Linear(20, 20, key=k3)
        self.linear4 = Linear(20, 1, key=k4)

def linear_12_weight_l1_loss(net: Net):
    return (
        # select desired branches (linear1, linear2 in this example)
        # and the desired leaves (weight)
        net.at["linear1", "linear2"]["weight"]
        # alternatively, a regex can be used to achieve the same functionality
        # >>> import re
        # >>> net.at[re.compile("linear[12]")]["weight"]
        # finally apply l1 loss
        .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)
    )
net = Net(key=jr.PRNGKey(0))
print(linear_12_weight_l1_loss(net))
331.84155
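Path-based selection can also be illustrated in plain JAX. The sketch below assumes the parameters are stored in a nested dict (the `params` dict, its small hand-picked values, and the function `l1_of_selected_weights` are hypothetical, introduced only for illustration); it walks the leaves together with their paths and accumulates the l1 norm of the leaves whose path matches the chosen layer names and ends in "weight".

```python
import jax
import jax.numpy as jnp

# hypothetical nested-dict model with small, hand-picked values
params = {
    "linear1": {"weight": jnp.array([[1.0, -2.0]]), "bias": jnp.zeros(1)},
    "linear2": {"weight": jnp.array([[3.0], [-4.0]]), "bias": jnp.zeros(2)},
    "linear3": {"weight": jnp.array([[5.0]]), "bias": jnp.zeros(1)},
}


def l1_of_selected_weights(params, layer_names=("linear1", "linear2")):
    # each item is (path, leaf); for dicts the path entries are DictKey objects
    leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(params)
    total = 0.0
    for path, leaf in leaves_with_path:
        keys = [entry.key for entry in path]
        # keep only the `weight` leaves of the requested layers
        if keys[0] in layer_names and keys[-1] == "weight":
            total = total + jnp.sum(jnp.abs(leaf))
    return total


selected_l1 = l1_of_selected_weights(params)
print(selected_l1)
```

Only the weights of linear1 and linear2 contribute here; linear3 and all biases are skipped because their paths do not match.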
This recipe can then be included inside the loss function, for example:
def loss_fnc(net, x, y):
    ...  # compute the data term `loss` from net, x, and y
    l1_loss = linear_12_weight_l1_loss(net)
    loss += l1_loss
    ...
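To make the pattern concrete, here is a self-contained sketch of a regularized loss in plain JAX. Everything in it is an assumption for illustration rather than part of the recipe above: the dict-based `params`, the mean-squared-error data term, the helper names `predict` and `l1_penalty`, and the `l1_strength` weighting factor are all hypothetical.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # a single linear layer: x @ W^T + b
    return x @ params["weight"].T + params["bias"]


def l1_penalty(params):
    # l1 norm of the weight only; the bias is left unregularized
    return jnp.sum(jnp.abs(params["weight"]))


def loss_fn(params, x, y, l1_strength=1e-2):
    # data term (mean squared error) plus the scaled l1 penalty
    data_loss = jnp.mean((predict(params, x) - y) ** 2)
    return data_loss + l1_strength * l1_penalty(params)


params = {"weight": jnp.array([[2.0]]), "bias": jnp.array([0.0])}
x = jnp.ones((4, 1))
y = jnp.zeros((4, 1))

# the penalty is differentiated together with the data term
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
print(loss, grads)
```

Because the penalty is part of the scalar returned by the loss function, its gradient contribution is included automatically when the loss is differentiated.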