🗂️ Misc recipes

[1]:
!pip install sepes

This section introduces some miscellaneous recipes that are not covered in the previous sections.

[1] Lazy layers

In this example, we create a Linear layer whose weight parameter depends on the shape of the input. Because the parameters (weight and bias) are created on the first call, after instance initialization, we use value_and_tree to run the call and get back a new instance that carries the added parameters.

[2]:
import sepes as sp
from typing import Any
import jax
import jax.numpy as jnp


class LazyLinear(sp.TreeClass):
    def __init__(self, out_features: int):
        self.out_features = out_features

    def param(self, name: str, value: Any):
        # return the value if it exists, otherwise set it and return it
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, input: jax.Array) -> jax.Array:
        weight = self.param("weight", jnp.ones((self.out_features, input.shape[-1])))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return input @ weight.T + bias


input = jnp.ones([10, 1])

lazy = LazyLinear(out_features=1)

print(f"Layer before param is set:\t{lazy}")

# `value_and_tree` executes any mutating method in a functional way
_, material = sp.value_and_tree(lambda layer: layer(input))(lazy)

print(f"Layer after param is set:\t{material}")
# subsequent calls will not set the parameters again
material(input)
Layer before param is set:      LazyLinear(out_features=1)
Layer after param is set:       LazyLinear(out_features=1, weight=[[1.]], bias=[0.])
[2]:
Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)
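The copy-then-mutate idea behind this recipe can be illustrated without sepes. The following is a minimal conceptual sketch, not sepes's actual implementation: `PlainLazyLinear` and `call_on_copy` are hypothetical names introduced here, and a plain `copy.deepcopy` stands in for whatever copying `value_and_tree` does internally.

```python
import copy

import jax.numpy as jnp


class PlainLazyLinear:
    """Hypothetical plain-Python stand-in for the LazyLinear layer above."""

    def __init__(self, out_features: int):
        self.out_features = out_features

    def __call__(self, input):
        # create the weight on the first call, reuse it on later calls
        if "weight" not in vars(self):
            self.weight = jnp.ones((self.out_features, input.shape[-1]))
        return input @ self.weight.T


def call_on_copy(layer, *args):
    """Hypothetical helper: run a mutating call on a copy, return (output, copy)."""
    new_layer = copy.deepcopy(layer)
    return new_layer(*args), new_layer


lazy = PlainLazyLinear(out_features=1)
output, material = call_on_copy(lazy, jnp.ones([10, 1]))

print("weight" in vars(lazy))      # the original instance is left untouched
print("weight" in vars(material))  # the copy carries the new parameter
print(output.shape)
```

The caller keeps working with the returned copy (`material`), which is the same pattern as unpacking the second result of `sp.value_and_tree` in the cell above.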

[3] Regularization

The following code showcases how to use the at functionality to select some leaves of a model, based on a boolean mask and/or a name condition, and apply weight regularization to them. Using .at[...], this can be achieved concisely:

Boolean-based mask

Array entries or whole leaves are selected with a tree of the same structure whose leaves are boolean (True/False). A True entry marks a place where the operation is applied, while a False entry marks a place that is left untouched.

[6]:
import sepes as sp
import jax.numpy as jnp
import jax


class Net(sp.TreeClass):
    def __init__(self):
        self.weight = jnp.array([-1, -2, -3, 1, 2, 3])
        self.bias = jnp.array([-1, 1])


def negative_entries_l2_loss(net: Net):
    return (
        # select all positive array entries
        net.at[jax.tree_util.tree_map(lambda x: x > 0, net)]
        # set them to zero to exclude their loss
        .set(0)
        # select all leaves
        .at[...]
        # finally reduce with l2 loss
        .reduce(lambda x, y: x + jnp.mean(y**2), initializer=0)
    )


net = Net()
print(negative_entries_l2_loss(net))
2.8333335
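To make explicit what the masked reduction computes, here is a plain-JAX sketch of the same calculation. It assumes the parameters are stored in an ordinary dict rather than a TreeClass, and uses only `jax.tree_util` functions; it is meant as a cross-check of the recipe, not as a replacement for the `.at` API.

```python
import jax
import jax.numpy as jnp

# same values as the Net above, stored in a plain dict (assumption for illustration)
params = {
    "weight": jnp.array([-1, -2, -3, 1, 2, 3]),
    "bias": jnp.array([-1, 1]),
}

# boolean mask with the same tree structure: True where an entry is positive
mask = jax.tree_util.tree_map(lambda x: x > 0, params)

# zero out the entries selected by the mask, leave the rest untouched
masked = jax.tree_util.tree_map(lambda x, m: jnp.where(m, 0, x), params, mask)

# add up the mean squared value of every leaf
loss = sum(jnp.mean(leaf**2) for leaf in jax.tree_util.tree_leaves(masked))
print(loss)
```

Working it out by hand, the weight leaf contributes (1 + 4 + 9) / 6 and the bias leaf contributes 1 / 2, which agrees with the value printed by the recipe.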

Name-based mask

Here, leaves are selected by name, that is, by their path in the tree.

[7]:
import sepes as sp
import jax
import jax.numpy as jnp
import jax.random as jr


class Linear(sp.TreeClass):
    def __init__(self, in_features: int, out_features: int, key: jax.Array):
        self.weight = jr.normal(key=key, shape=(out_features, in_features))
        self.bias = jnp.zeros((out_features,))


class Net(sp.TreeClass):
    def __init__(self, key: jax.Array) -> None:
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.linear1 = Linear(1, 20, key=k1)
        self.linear2 = Linear(20, 20, key=k2)
        self.linear3 = Linear(20, 20, key=k3)
        self.linear4 = Linear(20, 1, key=k4)


def linear_12_weight_l1_loss(net: Net):
    return (
        # select desired branches (linear1, linear2 in this example)
        # and the desired leaves (weight)
        net.at["linear1", "linear2"]["weight"]
        # alternatively, a regex can be used to achieve the same functionality
        # >>> import re
        # >>> net.at[re.compile("linear[12]")]["weight"]
        # finally apply l1 loss
        .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)
    )


net = Net(key=jr.PRNGKey(0))
print(linear_12_weight_l1_loss(net))
331.84155
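For comparison, path-based selection can also be sketched in plain JAX with `jax.tree_util.tree_leaves_with_path`. The nested dict, its small constant values, and the helper name `l1_on_selected_weights` are assumptions made for illustration; the layout mirrors the `linearN.weight` / `linearN.bias` naming of the Net above.

```python
import jax
import jax.numpy as jnp

# hypothetical parameters in a nested dict, mirroring the layer/leaf names above
params = {
    "linear1": {"weight": jnp.full((2, 1), -1.0), "bias": jnp.zeros(2)},
    "linear2": {"weight": jnp.full((2, 2), 0.5), "bias": jnp.zeros(2)},
    "linear3": {"weight": jnp.full((1, 2), 3.0), "bias": jnp.zeros(1)},
}


def l1_on_selected_weights(params, layers=("linear1", "linear2")):
    """Sum |w| over the `weight` leaves of the named layers only."""
    total = 0.0
    for path, leaf in jax.tree_util.tree_leaves_with_path(params):
        # for nested dicts, each path entry is a DictKey holding the dict key
        layer_name, leaf_name = path[0].key, path[1].key
        if layer_name in layers and leaf_name == "weight":
            total = total + jnp.sum(jnp.abs(leaf))
    return total


print(l1_on_selected_weights(params))
```

Only the two selected weight leaves contribute: two entries of magnitude 1 and four entries of magnitude 0.5. The `linear3` weights and all biases are skipped.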

The linear_12_weight_l1_loss function can then be included inside the loss function, for example:

def loss_fnc(net, x, y):
    loss = ...  # placeholder: base loss computed from net, x, and y
    l1_loss = linear_12_weight_l1_loss(net)
    loss += l1_loss
    ...
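As a fuller illustration of adding a regularization term to a loss and differentiating the result, here is a minimal plain-JAX sketch. Everything in it is an assumption made for the example: the dict-based parameters, the `predict` helper, the mean squared error base loss, and the `l1_coeff` weighting factor do not come from the recipe above.

```python
import jax
import jax.numpy as jnp

# hypothetical two-layer parameters stored in a plain dict
params = {
    "linear1": {"weight": jnp.array([[2.0]]), "bias": jnp.array([0.0])},
    "linear2": {"weight": jnp.array([[1.0]]), "bias": jnp.array([0.0])},
}


def predict(params, x):
    # apply the two linear layers in sequence (no activation, for brevity)
    for name in ("linear1", "linear2"):
        x = x @ params[name]["weight"].T + params[name]["bias"]
    return x


def loss_fn(params, x, y, l1_coeff=0.1):
    # base loss: mean squared error between prediction and target
    mse = jnp.mean((predict(params, x) - y) ** 2)
    # regularization: l1 penalty on the weights of both layers
    l1 = sum(jnp.sum(jnp.abs(params[n]["weight"])) for n in ("linear1", "linear2"))
    return mse + l1_coeff * l1


x = jnp.ones((4, 1))
y = 2.0 * x
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
print(loss)
print(grads["linear1"]["weight"], grads["linear2"]["weight"])
```

With these values the prediction already matches the target, so the loss and the weight gradients come entirely from the l1 term, which makes the effect of the penalty easy to inspect.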