🪡 Sharing/Tie Weights

Because sharing weights converts a pytree into a graph by pointing one leaf to another, careful handling is needed to avoid breaking the tree assumptions.

In sepes, sharing/tying weights is done inside methods. That is, instead of sharing the reference within the __init__ method, the reference is shared inside the method in which the call is made.

From

class TiedAutoEncoder:
    """Auto-encoder whose decoder weight is aliased to the encoder weight in ``__init__``."""

    def __init__(self, input_dim, hidden_dim):
        encoder = Linear(input_dim, hidden_dim)
        decoder = Linear(hidden_dim, input_dim)
        # The reference is shared at construction time: one leaf now points
        # to another leaf of the same tree.
        decoder.weight = encoder.weight
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, x):
        """Encode ``x`` and decode the result."""
        hidden = self.encoder(x)
        return self.decoder(hidden)

To

class TiedAutoEncoder:
    """Auto-encoder that ties the decoder weight to the encoder weight at call time."""

    def __init__(self, input_dim, hidden_dim):
        # No sharing here: both layers are built independently.
        self.encoder = Linear(input_dim, hidden_dim)
        self.decoder = Linear(hidden_dim, input_dim)

    def __call__(self, x):
        """Tie the weights, then encode ``x`` and decode the result."""
        encoder, decoder = self.encoder, self.decoder
        # The tie is made inside the method that is being called.
        decoder.weight = encoder.weight.T
        hidden = encoder(x)
        return decoder(hidden)
[1]:
!pip install sepes

This example demonstrates a simple AutoEncoder with weights shared between the encoder and the decoder.

[2]:
import sepes as sp
import jax
import jax.numpy as jnp
import jax.random as jr
import functools as ft


def sharing(method):
    """Decorate ``method`` so it runs on a copy of the instance.

    The wrapped method may mutate ``self`` freely; only its output is
    returned and the original instance is **not** modified.
    """

    def wrapper(self, *args, **kwargs):
        # `value_and_tree` runs a mutating method functionally: it copies
        # `self`, executes the method on the copy, and hands back the output
        # together with the new state.
        functional_method = sp.value_and_tree(method)
        result, _new_state = functional_method(self, *args, **kwargs)
        # The mutated copy is dropped on purpose; callers only see the output.
        return result

    return ft.wraps(method)(wrapper)


class Linear(sp.TreeClass):
    """Affine layer computing ``input @ weight.T + bias``."""

    def __init__(self, in_features: int, out_features: int, key: jax.Array):
        # weight is stored as (out_features, in_features) and drawn from a
        # normal distribution; bias starts at zero.
        weight_shape = (out_features, in_features)
        bias_shape = (out_features,)
        self.weight = jr.normal(key, shape=weight_shape)
        self.bias = jnp.zeros(bias_shape)

    def __call__(self, input):
        """Apply the affine map to ``input``."""
        projected = input @ self.weight.T
        return projected + self.bias


class AutoEncoder(sp.TreeClass):
    """Four-layer auto-encoder (1 -> 10 -> 20 -> 10 -> 1) with optional weight tying."""

    def __init__(self, *, key: jax.Array):
        # One independent PRNG key per layer.
        enc1_key, enc2_key, dec2_key, dec1_key = jr.split(key, 4)
        self.enc1 = Linear(1, 10, key=enc1_key)
        self.enc2 = Linear(10, 20, key=enc2_key)
        self.dec2 = Linear(20, 10, key=dec2_key)
        self.dec1 = Linear(10, 1, key=dec1_key)

    @sharing
    def tied_call(self, input: jax.Array) -> jax.Array:
        """Forward pass with each decoder weight tied to its encoder's transpose.

        The assignments below mutate ``self``; the ``sharing`` decorator runs
        this method on a copy, so the caller's instance is left untouched.
        """
        # Tie at call time rather than in `__init__`.
        self.dec2.weight = self.enc2.weight.T
        self.dec1.weight = self.enc1.weight.T

        activation = input
        for layer in (self.enc1, self.enc2, self.dec2, self.dec1):
            activation = layer(activation)
        return activation

    def non_tied_call(self, x):
        """Forward pass using every layer's own weights (no tying)."""
        activation = x
        for layer in (self.enc1, self.enc2, self.dec2, self.dec1):
            activation = layer(activation)
        return activation
[3]:
@jax.jit
@jax.grad
def tied_loss_func(net, x, y):
    """Mean squared error of the tied forward pass against ``y``.

    Because of the ``jax.grad`` decorator, calling this returns the gradient
    with respect to ``net`` rather than the loss value itself.
    """
    # `net` arrives masked; restore it before calling its methods.
    model = sp.tree_unmask(net)
    # Map the single-sample `tied_call` over the leading axis of `x`.
    predictions = jax.vmap(model.tied_call)(x)
    residual = predictions - y
    return jnp.mean(residual ** 2)


# Ten samples of shape (1,): inputs are all ones, targets are all twos.
x = jnp.ones([10, 1]) + 0.0
y = jnp.ones([10, 1]) * 2.0
# Mask the freshly built model before handing it to the jitted gradient
# function, which unmasks it again internally.
tree = sp.tree_mask(AutoEncoder(key=jr.PRNGKey(0)))
grads: AutoEncoder = tied_loss_func(tree, x, y)
# note that the shared weights have 0 gradient
print(f"{grads.dec1!r} {grads.dec2!r}")
Linear(
  weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]),
  bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])
) Linear(
  weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]),
  bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])
)
[4]:
# check for non-tied call
@jax.jit
@jax.grad
def non_tied_loss_func(net, x, y):
    """Mean squared error of the non-tied forward pass against ``y``.

    Because of the ``jax.grad`` decorator, calling this returns the gradient
    with respect to ``net`` rather than the loss value itself.
    """
    # `net` arrives masked; restore it before calling its methods.
    model = sp.tree_unmask(net)
    # Map the single-sample `non_tied_call` over the leading axis of `x`.
    predictions = jax.vmap(model.non_tied_call)(x)
    residual = predictions - y
    return jnp.mean(residual ** 2)


# Reuse the model from the tied example.
# NOTE(review): `tree` was already masked when it was created, so this second
# `tree_mask` looks redundant -- confirm before removing it.
tree = sp.tree_mask(tree)
x = jnp.ones([10, 1]) + 0.0
y = jnp.ones([10, 1]) * 2.0
# Fix: differentiate through `non_tied_call`. This cell previously called
# `tied_loss_func` again, so it repeated the tied result instead of checking
# the non-tied path.
# NOTE(review): the output recorded for this cell appears to predate the fix
# (it matches the tied output); re-run the cell to refresh it.
grads: AutoEncoder = non_tied_loss_func(tree, x, y)

# note that the decoder weights, no longer shared, have non-zero gradients
print(repr(grads.dec1), repr(grads.dec2))
Linear(
  weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]),
  bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])
) Linear(
  weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]),
  bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])
)