✂️ Surgery#

This notebook provides tree editing (surgery) recipes using at. at encapsulates a pytree and provides a way to edit it in an out-of-place manner. Under the hood, at uses jax.tree_util or optree to traverse the tree and apply the provided function to the selected nodes.

import sepes as sp
import re
tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])
where = re.compile("key_.*")  # select all keys starting with "key_"
func = lambda node: sum(map(abs, node))  # sum of absolute values
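# treat lists as leaves so that func receives each selected list as a whole
sp.at(tree)[where].apply(func, is_leaf=lambda x: isinstance(x, list))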
# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}
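The result above shows that func sees each selected list as a single node. The following is a minimal pure-Python sketch, not how sepes is implemented, that reproduces this result. It uses a hypothetical helper, edit_matching_keys, that only inspects top-level dict keys, assumes a full regex match on the key, and passes each matching value to func unchanged.

```python
import re
from typing import Any, Callable


def edit_matching_keys(tree: dict, pattern: re.Pattern, func: Callable[[Any], Any]) -> dict:
    """Hypothetical helper: return a new dict in which every value whose key
    fully matches `pattern` is replaced by func(value); other values are kept."""
    return {
        key: func(value) if pattern.fullmatch(key) else value
        for key, value in tree.items()
    }


tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])
where = re.compile("key_.*")  # keys starting with "key_"
func = lambda node: sum(map(abs, node))  # sum of absolute values of a list

result = edit_matching_keys(tree, where, func)
print(result)          # {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}
print(tree["key_1"])   # [1, -2, 3] (the original tree is untouched)
```

Because the helper builds a new dict, the input tree is left as it was, matching the out-of-place behavior of at.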
[1]:
!pip install sepes

Out-of-place editing#

Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of at. The following example demonstrates this behavior:

[2]:
import sepes as sp

pytree1 = [1, [2, 3], 4]
pytree2 = sp.at(pytree1)[...].get()  # get the whole pytree using ...
print(f"{pytree1=}, {pytree2=}")
# pytree1 and pytree2 are equal, but they are not the same object
# because pytree2 is a newly built copy of pytree1
print(f"pytree1 is pytree2 = {pytree1 is pytree2}")
pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]
pytree1 is pytree2 = False
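The out-of-place behavior can be illustrated without sepes. The sketch below uses a hypothetical helper, tree_rebuild, as a plain-Python stand-in for the traversal that jax.tree_util or optree perform. It recursively rebuilds lists and dicts and applies a function to every leaf, so every container in the result is new and the input is never mutated.

```python
from typing import Any, Callable


def tree_rebuild(tree: Any, func: Callable[[Any], Any] = lambda leaf: leaf) -> Any:
    """Hypothetical helper: recursively rebuild lists and dicts, applying
    `func` to every leaf. Containers are always freshly created, so the
    input tree is never modified."""
    if isinstance(tree, list):
        return [tree_rebuild(child, func) for child in tree]
    if isinstance(tree, dict):
        return {key: tree_rebuild(child, func) for key, child in tree.items()}
    return func(tree)  # anything that is not a list or dict is a leaf


pytree1 = [1, [2, 3], 4]
pytree2 = tree_rebuild(pytree1)       # identity on leaves, new containers
print(pytree1 == pytree2)             # True: same structure and values
print(pytree1 is pytree2)             # False: a new outer list
print(pytree1[1] is pytree2[1])       # False: nested containers are new too

pytree3 = tree_rebuild(pytree1, lambda x: x * 10)
print(pytree1, pytree3)               # [1, [2, 3], 4] [10, [20, 30], 40]
```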

Matching keys#

Match all#

Use ... to match all keys.

The following example applies the plus_one function to every leaf of the tree. This is equivalent to pytree2 = jax.tree_util.tree_map(plus_one, pytree1).

[3]:
import sepes as sp

pytree1 = [1, [2, 3], 4]
plus_one = lambda x: x + 1
pytree2 = sp.at(pytree1)[...].apply(plus_one)
pytree2
[3]:
[2, [3, 4], 5]

Integer indexing#

at can edit pytrees by integer paths.

[4]:
import sepes as sp

pytree1 = [1, [2, 3], 4]
pytree2 = sp.at(pytree1)[1][0].set(100)  # like pytree1[1][0] = 100, but returns a new tree
pytree2
[4]:
[1, [100, 3], 4]
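A chain such as [1][0] can be read as a path from the root to one node. The sketch below is a hypothetical pure-Python helper, set_at_path, that performs an out-of-place set along such a path. As a design choice of this sketch (not necessarily what sepes does), only the containers along the path are copied; untouched siblings are shared with the original tree.

```python
from typing import Any, Sequence


def set_at_path(tree: Any, path: Sequence[Any], value: Any) -> Any:
    """Hypothetical helper: return a copy of `tree` with the node at `path`
    replaced by `value`. Only the containers along the path are copied."""
    if not path:
        return value  # reached the target node
    head, *rest = path
    if isinstance(tree, list):
        new = list(tree)  # shallow copy of this level
    elif isinstance(tree, dict):
        new = dict(tree)
    else:
        raise TypeError(f"cannot index into leaf {tree!r} with {head!r}")
    new[head] = set_at_path(tree[head], rest, value)
    return new


pytree1 = [1, [2, 3], 4]
pytree2 = set_at_path(pytree1, [1, 0], 100)
print(pytree2)  # [1, [100, 3], 4]
print(pytree1)  # [1, [2, 3], 4] (unchanged)
```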

Named path indexing#

at can edit pytrees by named paths.

[5]:
import sepes as sp

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4, "f": {"g": 7, "h": 8}}
pytree2 = sp.at(pytree1)["b"].set(100)  # set every leaf under "b" to 100
pytree2
[5]:
{'a': -1, 'b': {'c': 100, 'd': 100}, 'e': -4, 'f': {'g': 7, 'h': 8}}
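A named path is the sequence of dict keys (or list indices) leading from the root to a node. The pure-Python sketch below uses a hypothetical helper, leaf_paths, that lists every leaf together with its path. This makes it easy to see which leaves sit under the key "b".

```python
from typing import Any, Iterator, Tuple


def leaf_paths(tree: Any, prefix: Tuple[Any, ...] = ()) -> Iterator[Tuple[Tuple[Any, ...], Any]]:
    """Hypothetical helper: yield (path, leaf) pairs, where path is the tuple
    of dict keys / list indices leading from the root to the leaf."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from leaf_paths(child, prefix + (key,))
    elif isinstance(tree, list):
        for index, child in enumerate(tree):
            yield from leaf_paths(child, prefix + (index,))
    else:
        yield prefix, tree


pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4, "f": {"g": 7, "h": 8}}
paths = [path for path, _ in leaf_paths(pytree1)]
print(paths)
# [('a',), ('b', 'c'), ('b', 'd'), ('e',), ('f', 'g'), ('f', 'h')]

under_b = [path for path in paths if path[0] == "b"]
print(under_b)  # [('b', 'c'), ('b', 'd')]
```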

Regular expression indexing#

at can edit pytrees by regular expressions applied to keys.

[6]:
import sepes as sp
import re

pytree1 = {
    "key_1": 1,
    "key_2": {"key_3": 3, "key_4": 4},
    "key_5": 5,
    "key_6": {"key_7": 7, "key_8": 8},
}
# select keys key_1 through key_5 and set every leaf under them to 100
pattern = re.compile(r"key_[1-5]")
pytree2 = sp.at(pytree1)[pattern].set(100)
pytree2
[6]:
{'key_1': 100, 'key_2': {'key_3': 100, 'key_4': 100}, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}
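To check which top-level keys the pattern selects, the keys can be tested directly with Python's re module. Whether sepes uses a full match or a prefix match is not stated in this notebook. The sketch assumes re.fullmatch. With re.match, a key such as "key_10" would also be accepted, because its prefix "key_1" matches.

```python
import re

pattern = re.compile(r"key_[1-5]")
pytree1 = {
    "key_1": 1,
    "key_2": {"key_3": 3, "key_4": 4},
    "key_5": 5,
    "key_6": {"key_7": 7, "key_8": 8},
}

# Top-level keys that fully match the pattern (fullmatch is an assumption here).
top_level = [key for key in pytree1 if pattern.fullmatch(key)]
print(top_level)  # ['key_1', 'key_2', 'key_5']

# "key_6" is not selected, and none of its nested keys match either.
print([key for key in pytree1["key_6"] if pattern.fullmatch(key)])  # []

# Prefix matching would differ from full matching for longer keys.
print(bool(pattern.match("key_10")), bool(pattern.fullmatch("key_10")))  # True False
```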

Mask indexing#

at can edit pytree entries by a boolean mask: given a mask with the same structure as the pytree, leaves marked True are edited and leaves marked False are left untouched.

The following example sets all negative leaves to zero.

[7]:
import sepes as sp
import jax

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
# mask defines all desired entries to apply the function
mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)
pytree2 = sp.at(pytree1)[mask].set(0)
pytree2
[7]:
{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
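Mask editing can be mimicked in pure Python by mapping over the tree and its mask in lockstep. The sketch below defines a hypothetical tree_map that, like the multi-tree form of jax.tree_util.tree_map, accepts extra trees of the same structure. It then uses that function both to build the mask and to apply it. It handles only nested dicts and lists, so it runs without jax.

```python
from typing import Any, Callable


def tree_map(func: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Hypothetical helper: apply `func` to corresponding leaves of `tree`
    and of every tree in `rest` (all assumed to share the same structure)."""
    if isinstance(tree, dict):
        return {key: tree_map(func, tree[key], *(r[key] for r in rest)) for key in tree}
    if isinstance(tree, list):
        return [tree_map(func, *children) for children in zip(tree, *rest)]
    return func(tree, *rest)


pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
mask = tree_map(lambda x: x < 0, pytree1)
print(mask)  # {'a': True, 'b': {'c': False, 'd': False}, 'e': True}

# Replace selected leaves with 0 and keep the others.
pytree2 = tree_map(lambda leaf, selected: 0 if selected else leaf, pytree1, mask)
print(pytree2)  # {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
```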

Composition#

at can chain integer paths, named paths, regular expressions, and masks to select the nodes to edit.

The following example sets every node that satisfies both of these conditions to 100:

  • Name starting with “key_”

  • Positive value

[8]:
import sepes as sp
import re
import jax

pytree1 = {"key_1": 1, "key_2": -2, "value_1": 1, "value_2": 2}
pattern = re.compile(r"key_.*")
mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)
pytree2 = sp.at(pytree1)[pattern][mask].set(100)
pytree2
[8]:
{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}
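For the flat dictionary above, the combined selection can be reproduced by hand. The sketch below builds one boolean mask from the key names and one from the values, then combines them with a logical AND. Treating a chain of selectors as an AND of masks is an assumption of this sketch, though it agrees with the output shown for this example.

```python
import re

pytree1 = {"key_1": 1, "key_2": -2, "value_1": 1, "value_2": 2}
pattern = re.compile(r"key_.*")

# One boolean per leaf for each selector (the dict is flat, so no recursion).
name_mask = {key: bool(pattern.fullmatch(key)) for key in pytree1}
value_mask = {key: value > 0 for key, value in pytree1.items()}

# Chaining selectors is modelled here as a logical AND of the two masks.
combined = {key: name_mask[key] and value_mask[key] for key in pytree1}
pytree2 = {key: 100 if combined[key] else value for key, value in pytree1.items()}

print(combined)  # {'key_1': True, 'key_2': False, 'value_1': False, 'value_2': False}
print(pytree2)   # {'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}
```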