Intermediates handling
This notebook demonstrates how to capture a model's intermediate values and intermediate gradients. This is useful for debugging, understanding the model, and visualizing its internal representations.
!pip install sepes
Capture intermediate values
In this example, we will capture the intermediate values in a method by simply returning them as part of the output.
import sepes as sp
class Foo(sp.TreeClass):
    def __init__(self):
        self.a = 1.0

    def __call__(self, x):
        capture = {}
        b = self.a + x
        capture["b"] = b
        c = 2 * b
        capture["c"] = c
        e = 4 * c
        return e, capture
foo = Foo()
_, inter_values = foo(1.0)
inter_values
{'b': 2.0, 'c': 4.0}
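The captured intermediates are an ordinary Python dict, which is itself a pytree, so this pattern should compose with JAX transformations. The following is a minimal sketch that uses plain JAX without sepes. The function `forward` is a hypothetical stand-in for `Foo.__call__`, with `self.a` passed in as the argument `a`. Passing `has_aux=True` to `jax.grad` makes it differentiate only the first return value and pass the captured dict through unchanged.

```python
import jax


def forward(a, x):
    # Hypothetical plain-function version of Foo.__call__ (self.a -> a)
    capture = {}
    b = a + x
    capture["b"] = b
    c = 2 * b
    capture["c"] = c
    e = 4 * c
    return e, capture


# has_aux=True: forward returns (output, aux); only `output` is differentiated
# and the aux dict of intermediates is returned alongside the gradient.
grad_a, inter_values = jax.grad(forward, has_aux=True)(1.0, 1.0)
print(grad_a)        # de/da = 4 * 2 * 1 = 8
print(inter_values)  # intermediates b and c, as JAX arrays
```

This lets one forward pass serve both training (the gradient) and inspection (the intermediates).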
Capture intermediate gradients
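In this example, we capture the intermediate gradients of a method in two steps. First, we add a zero-valued perturbation to each intermediate value of interest. Second, we use `argnums` in `jax.grad` to differentiate with respect to those perturbations. Each perturbation enters its intermediate additively, so the gradient with respect to the perturbation equals the gradient with respect to the intermediate itself.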
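The reason this works can be written out for the example below. Here $\delta_b$ and $\delta_c$ stand for `perturb["b"]` and `perturb["c"]`; this notation is introduced only for illustration. Since $\partial b / \partial \delta_b = 1$ and $\partial c / \partial \delta_c = 1$, evaluating at $\delta = 0$ recovers the gradients with respect to $b$ and $c$:

```latex
b = a + x + \delta_b, \qquad c = 2b + \delta_c, \qquad e = 4c

\frac{\partial e}{\partial \delta_c}\Big|_{\delta = 0}
  = \frac{\partial e}{\partial c}\,\frac{\partial c}{\partial \delta_c}
  = 4 \cdot 1 = 4

\frac{\partial e}{\partial \delta_b}\Big|_{\delta = 0}
  = \frac{\partial e}{\partial c}\,\frac{\partial c}{\partial b}\,\frac{\partial b}{\partial \delta_b}
  = 4 \cdot 2 \cdot 1 = 8
```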
import sepes as sp
import jax
class Foo(sp.TreeClass):
    def __init__(self):
        self.a = 1.0

    def __call__(self, x, perturb):
        # pass in the perturbations as a pytree
        b = self.a + x + perturb["b"]
        c = 2 * b + perturb["c"]
        e = 4 * c
        return e
foo = Foo()
# de/dc = 4
# de/db = de/dc * dc/db = 4 * 2 = 8
# take gradient with respect to the perturbations pytree
# by setting `argnums=1` in `jax.grad`
inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))
inter_grads
{'b': Array(8., dtype=float32, weak_type=True),
'c': Array(4., dtype=float32, weak_type=True)}
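The two techniques can also be combined, so that a single call returns both the intermediate gradients and the intermediate values. The following is a minimal sketch in plain JAX. It assumes a hypothetical function `forward` that merges the two `Foo.__call__` methods above, with `self.a` passed in as `a`. Here `argnums=2` selects the perturbation pytree, and `has_aux=True` returns the captured dict alongside the gradients.

```python
import jax


def forward(a, x, perturb):
    # Hypothetical plain-function merge of both Foo.__call__ variants:
    # intermediates are perturbed (for gradients) and captured (for values).
    capture = {}
    b = a + x + perturb["b"]
    capture["b"] = b
    c = 2 * b + perturb["c"]
    capture["c"] = c
    e = 4 * c
    return e, capture


# Zero perturbations leave the forward computation unchanged.
perturb = {"b": 0.0, "c": 0.0}

# Differentiate w.r.t. argument 2 (perturb); pass the captured dict through.
inter_grads, inter_values = jax.grad(forward, argnums=2, has_aux=True)(
    1.0, 1.0, perturb
)
print(inter_grads)   # de/db = 8, de/dc = 4
print(inter_values)  # b = 2, c = 4
```

The perturbation dict must have the same structure as the intermediates it perturbs. For array-valued intermediates, each zero entry would need the matching shape, for example one built with `jax.numpy.zeros_like`.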