🪡 Sharing/Tie Weights
Because sharing weights converts a pytree into a graph (one leaf points to another), careful handling is needed to avoid breaking the tree assumptions.
In sepes, weights are shared/tied inside methods: instead of sharing the reference in the `__init__` method, the reference is shared inside the method that is being called.
From (sharing the reference in `__init__`):
class TiedAutoEncoder:
    """Tie weights by pointing the decoder at the encoder's weight in ``__init__``.

    Sharing the reference at construction time makes both attributes refer to
    the same leaf, which turns the pytree into a graph.
    """

    def __init__(self, input_dim, hidden_dim):
        self.encoder = Linear(input_dim, hidden_dim)
        self.decoder = Linear(hidden_dim, input_dim)
        # Both layers now hold the very same weight object (a shared leaf).
        # NOTE(review): with an (out_features, in_features) weight layout the
        # shapes only agree when input_dim == hidden_dim — confirm intent.
        self.decoder.weight = self.encoder.weight

    def __call__(self, x):
        hidden = self.encoder(x)
        return self.decoder(hidden)
To (sharing inside the called method):
class TiedAutoEncoder:
    """Tie weights inside the called method instead of in ``__init__``.

    Each layer keeps its own weight leaf after construction; the tie is
    (re)applied every time the model is called.
    """

    def __init__(self, input_dim, hidden_dim):
        self.encoder = Linear(input_dim, hidden_dim)
        self.decoder = Linear(hidden_dim, input_dim)

    def __call__(self, x):
        # Re-tie on each call: the decoder uses the transposed encoder weight.
        self.decoder.weight = self.encoder.weight.T
        encoded = self.encoder(x)
        return self.decoder(encoded)
[1]:
!pip install sepes
This example demonstrates a simple AutoEncoder whose encoder and decoder share weights.
[2]:
import sepes as sp
import jax
import jax.numpy as jnp
import jax.random as jr
import functools as ft
def sharing(method):
    """Decorate ``method`` so that it runs on a throwaway copy of the instance.

    The wrapped method may freely mutate ``self`` (for example, to tie one
    layer's weight to another's). ``sp.value_and_tree`` executes the call
    functionally: it copies ``self``, runs the method, and returns the output
    together with the updated copy. That copy is discarded here, so the
    caller's instance is left unmodified and only the method's output is
    returned.
    """

    @ft.wraps(method)
    def wrapper(self, *args, **kwargs):
        functional_call = sp.value_and_tree(method)
        value, _discarded_tree = functional_call(self, *args, **kwargs)
        return value

    return wrapper
class Linear(sp.TreeClass):
    """Affine layer computing ``input @ weight.T + bias``.

    ``weight`` has shape ``(out_features, in_features)`` and is drawn from a
    standard normal distribution using ``key``; ``bias`` has shape
    ``(out_features,)`` and starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, key: jax.Array):
        weight_shape = (out_features, in_features)
        self.weight = jr.normal(key=key, shape=weight_shape)
        self.bias = jnp.zeros((out_features,))

    def __call__(self, input):
        # Contract the trailing axis of `input` against the in_features axis.
        projected = input @ self.weight.T
        return projected + self.bias
class AutoEncoder(sp.TreeClass):
    """Four-layer autoencoder with feature sizes 1 -> 10 -> 20 -> 10 -> 1.

    ``tied_call`` ties each decoder weight to the transpose of the matching
    encoder weight for the duration of the call; ``non_tied_call`` uses every
    layer's own weights.
    """

    def __init__(self, *, key: jax.Array):
        enc1_key, enc2_key, dec2_key, dec1_key = jr.split(key, 4)
        self.enc1 = Linear(1, 10, key=enc1_key)
        self.enc2 = Linear(10, 20, key=enc2_key)
        self.dec2 = Linear(20, 10, key=dec2_key)
        self.dec1 = Linear(10, 1, key=dec1_key)

    def _forward(self, x):
        # Apply the layers in order: enc1 -> enc2 -> dec2 -> dec1.
        for layer in (self.enc1, self.enc2, self.dec2, self.dec1):
            x = layer(x)
        return x

    @sharing
    def tied_call(self, input: jax.Array) -> jax.Array:
        # `sharing` runs this method on a copy of `self`, so these in-place
        # weight assignments never leak back to the caller's instance.
        self.dec1.weight = self.enc1.weight.T
        self.dec2.weight = self.enc2.weight.T
        return self._forward(input)

    def non_tied_call(self, x):
        return self._forward(x)
[3]:
@jax.jit
@jax.grad
def tied_loss_func(net, x, y):
    """Mean-squared error of the *tied* forward pass.

    Wrapped by ``jax.grad``, so calling it returns the gradient of this loss
    with respect to ``net``.
    """
    # The model is passed in masked (see `sp.tree_mask` below); unmask it first.
    model = sp.tree_unmask(net)
    prediction = jax.vmap(model.tied_call)(x)
    return jnp.mean((prediction - y) ** 2)


tree = sp.tree_mask(AutoEncoder(key=jr.PRNGKey(0)))
x = jnp.ones((10, 1)) + 0.0
y = jnp.ones((10, 1)) * 2.0
grads: AutoEncoder = tied_loss_func(tree, x, y)
# `tied_call` overwrites the dec1/dec2 weights with the transposed encoder
# weights, so the loss does not depend on the decoder weights: their
# gradients are zero, and the gradient flows through the encoder weights.
print(repr(grads.dec1), repr(grads.dec2))
Linear(
weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]),
bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])
) Linear(
weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]),
bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])
)
[4]:
# check for non-tied call
@jax.jit
@jax.grad
def non_tied_loss_func(net, x, y):
    """Mean-squared error of the *untied* forward pass.

    Wrapped by ``jax.grad``, so calling it returns the gradient of this loss
    with respect to ``net``.
    """
    # The model is passed in masked (see `sp.tree_mask` below); unmask it first.
    net = sp.tree_unmask(net)
    return jnp.mean((jax.vmap(net.non_tied_call)(x) - y) ** 2)


# Same model as the tied example (same constructor, same PRNG key), masked the
# same way. Building it fresh avoids re-masking an already-masked tree.
tree = sp.tree_mask(AutoEncoder(key=jr.PRNGKey(0)))
x = jnp.ones([10, 1]) + 0.0
y = jnp.ones([10, 1]) * 2.0
# Use the non-tied loss here; calling `tied_loss_func` would only repeat the
# tied experiment and print zero decoder-weight gradients again.
grads: AutoEncoder = non_tied_loss_func(tree, x, y)
# In the non-tied call the decoder weights are used as-is, so they receive
# their own (generally non-zero) gradients.
print(repr(grads.dec1), repr(grads.dec2))
Linear(
weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]),
bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])
) Linear(
weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]),
bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])
)