🔧 Intermediates handling

This notebook demonstrates how to capture the intermediate values and intermediate gradients of a model. This is useful for debugging, for understanding how a model computes its output, and for visualizing its internal representations.

!pip install sepes

Capture intermediate values

In this example, we capture the intermediate values computed inside a method by collecting them in a dictionary and returning that dictionary alongside the method's output.

import sepes as sp


class Foo(sp.TreeClass):
    def __init__(self):
        self.a = 1.0

    def __call__(self, x):
        capture = {}
        b = self.a + x
        capture["b"] = b
        c = 2 * b
        capture["c"] = c
        e = 4 * c
        return e, capture


foo = Foo()

_, inter_values = foo(1.0)
inter_values
{'b': 2.0, 'c': 4.0}
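Because the captured values are returned as ordinary output, they should also compose with JAX transformations. A minimal sketch, assuming a hypothetical plain function `forward` that mirrors `Foo.__call__` (so sepes is not needed here), passes `has_aux=True` to `jax.grad` so the capture dictionary is returned next to the gradient, and wraps the result in `jax.jit`:

```python
import jax


def forward(a, x):
    # Hypothetical plain-function version of Foo.__call__ above.
    capture = {}
    b = a + x
    capture["b"] = b
    c = 2 * b
    capture["c"] = c
    e = 4 * c
    # Return (output, auxiliary data), the structure has_aux=True expects.
    return e, capture


# Differentiate e with respect to x (argnums=1). has_aux=True makes jax.grad
# pass the capture dict through untouched instead of differentiating it.
grad_fn = jax.jit(jax.grad(forward, argnums=1, has_aux=True))
de_dx, captured = grad_fn(1.0, 1.0)
# de/dx = 4 * 2 * 1 = 8; captured holds b = 2.0 and c = 4.0 as JAX arrays
```

The same pattern lets a training step return both a gradient and the activations that produced it in a single call.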

Capture intermediate gradients

In this example, we capture the intermediate gradients in a method by (1) adding a zero-valued perturbation to each intermediate value of interest and (2) differentiating with respect to those perturbations by setting the `argnums` argument of `jax.grad`.
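The trick works because each perturbation is added with coefficient 1: at zero perturbation the forward values are unchanged, and the gradient with respect to a perturbation equals the gradient with respect to the intermediate it is added to. Writing the perturbations in the example that follows as $\epsilon_b$ and $\epsilon_c$:

```latex
\begin{aligned}
b &= a + x + \epsilon_b, \qquad c = 2b + \epsilon_c, \qquad e = 4c \\
\frac{\partial e}{\partial \epsilon_c} &= \frac{\partial e}{\partial c}\,\frac{\partial c}{\partial \epsilon_c} = 4 \cdot 1 = 4 = \frac{\partial e}{\partial c} \\
\frac{\partial e}{\partial \epsilon_b} &= \frac{\partial e}{\partial c}\,\frac{\partial c}{\partial b}\,\frac{\partial b}{\partial \epsilon_b} = 4 \cdot 2 \cdot 1 = 8 = \frac{\partial e}{\partial b}
\end{aligned}
```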

import sepes as sp
import jax


class Foo(sp.TreeClass):
    def __init__(self):
        self.a = 1.0

    def __call__(self, x, perturb):
        # pass in the perturbations as a pytree
        b = self.a + x + perturb["b"]
        c = 2 * b + perturb["c"]
        e = 4 * c
        return e


foo = Foo()

# de/dc = 4
# de/db = de/dc * dc/db = 4 * 2 = 8

# take gradient with respect to the perturbations pytree
# by setting `argnums=1` in `jax.grad`
inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))
inter_grads
{'b': Array(8., dtype=float32, weak_type=True),
 'c': Array(4., dtype=float32, weak_type=True)}
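When the intermediates are arrays rather than scalars, each perturbation should have the same shape as the value it perturbs; a scalar perturbation would broadcast, and its gradient would be the sum over all elements instead of a per-element gradient. A minimal sketch in plain JAX, assuming a hypothetical two-layer function `forward` with made-up weights, first captures the intermediates (as in the first section) to build zero perturbations of matching shape, then differentiates with respect to them:

```python
import jax
import jax.numpy as jnp


def forward(params, x, perturb):
    # Hypothetical two-layer function whose intermediates h and y are arrays.
    # Each perturbation is added where its intermediate is produced.
    h = jnp.tanh(params["w1"] @ x) + perturb["h"]
    y = params["w2"] @ h + perturb["y"]
    loss = jnp.sum(y**2)
    # Also return the intermediates so their shapes can be read off.
    return loss, {"h": h, "y": y}


params = {
    "w1": jnp.array([[1.0, 0.0], [0.0, 1.0]]),
    "w2": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
}
x = jnp.array([0.5, -0.5])

# 1) Forward pass with scalar zero perturbations (they broadcast) to capture
#    the intermediates, then build zero perturbations of matching shape.
_, inter = forward(params, x, {"h": 0.0, "y": 0.0})
zero_perturb = jax.tree_util.tree_map(jnp.zeros_like, inter)

# 2) Gradient with respect to the perturbation pytree (argnums=2);
#    has_aux=True returns the intermediates alongside the gradients.
inter_grads, _ = jax.grad(forward, argnums=2, has_aux=True)(params, x, zero_perturb)
# inter_grads["y"] is dloss/dy = 2 * y
# inter_grads["h"] is dloss/dh = w2.T @ (2 * y)
```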