✂️ Surgery
This notebook provides tree-editing (surgery) recipes using at. at encapsulates a pytree and lets you edit it out-of-place: every edit returns a new tree and leaves the original unchanged. Under the hood, at uses jax.tree_util or optree to traverse the tree and apply the provided function to the selected nodes.
import sepes as sp
import re
tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])
where = re.compile("key_.*") # select all keys starting with "key_"
func = lambda node: sum(map(abs, node)) # sum of absolute values
sp.at(tree)[where].apply(func)
# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}
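To make the "traverse and apply to selected nodes" idea concrete, here is a minimal plain-Python sketch. The helper apply_where is hypothetical. It is not part of sepes, jax.tree_util, or optree, and it only understands dicts, lists, and tuples. It also assumes that a selected node is passed to the function as a whole subtree (which is what the output above suggests) and that the regular expression must match the entire key.

```python
import re
from typing import Any, Callable, Tuple

Path = Tuple[Any, ...]

def apply_where(tree: Any, select: Callable[[Path], bool],
                func: Callable[[Any], Any], path: Path = ()) -> Any:
    """Hypothetical sketch (not the sepes API): return a new tree in which every
    node whose key path satisfies `select` is replaced by `func(node)`."""
    # A selected node is passed to `func` whole; its children are not visited.
    if path and select(path):
        return func(tree)
    # Unselected containers are rebuilt (out-of-place) while recursing.
    if isinstance(tree, dict):
        return {k: apply_where(v, select, func, path + (k,)) for k, v in tree.items()}
    if type(tree) in (list, tuple):
        return type(tree)(apply_where(v, select, func, path + (i,)) for i, v in enumerate(tree))
    return tree  # unselected leaf: kept as-is

tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])
where = re.compile("key_.*")

def select(path: Path) -> bool:
    # Assumption: only top-level keys are tested, and the whole key must match.
    return len(path) == 1 and isinstance(path[0], str) and where.fullmatch(path[0]) is not None

result = apply_where(tree, select, lambda node: sum(map(abs, node)))
print(result)  # {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}
```

The original tree is left untouched, because every container on the way down is rebuilt rather than mutated.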
[1]:
!pip install sepes
Out-of-place editing
Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of at. The following example demonstrates this behavior:
[2]:
import sepes as sp
pytree1 = [1, [2, 3], 4]
pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...
print(f"{pytree1=}, {pytree2=}")
# pytree1 and pytree2 are equal, but they are not the same object,
# because pytree2 is a new tree built from pytree1
print(f"pytree1 is pytree2 = {pytree1 is pytree2}")
pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]
pytree1 is pytree2 = False
Matching keys
Match all
Use ... (Python's Ellipsis) to match all keys.
The following example applies the plus_one function to every leaf of the tree. This is equivalent to tree = jax.tree_util.tree_map(plus_one, tree).
[3]:
import sepes as sp
pytree1 = [1, [2, 3], 4]
plus_one = lambda x: x + 1
pytree2 = sp.at(pytree1)[...].apply(plus_one)
pytree2
[3]:
[2, [3, 4], 5]
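The equivalence with tree_map can be illustrated without JAX. Below is a hypothetical, simplified tree_map written in plain Python, not jax.tree_util.tree_map itself. It only understands dicts, lists, and tuples, and it treats every other value as a leaf.

```python
from typing import Any, Callable

def tree_map(func: Callable[[Any], Any], tree: Any) -> Any:
    """Hypothetical plain-Python stand-in for jax.tree_util.tree_map."""
    if isinstance(tree, dict):
        # Rebuild the dict, mapping over each value.
        return {k: tree_map(func, v) for k, v in tree.items()}
    if type(tree) in (list, tuple):
        # Rebuild the list/tuple with the same container type.
        return type(tree)(tree_map(func, v) for v in tree)
    # Anything else is a leaf: apply the function.
    return func(tree)

pytree1 = [1, [2, 3], 4]
plus_one = lambda x: x + 1
pytree2 = tree_map(plus_one, pytree1)
print(pytree2)  # [2, [3, 4], 5]
print(pytree1)  # [1, [2, 3], 4] (unchanged)
```

Because the sketch rebuilds every container, even an identity map returns an equal but distinct object. This mirrors the pytree1 is pytree2 = False result in the out-of-place example.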
Integer indexing
at can edit pytrees by integer paths.
[4]:
import sepes as sp
pytree1 = [1, [2, 3], 4]
pytree2 = sp.at(pytree1)[1][0].set(100)  # like pytree1[1][0] = 100, but returns a new tree
pytree2
[4]:
[1, [100, 3], 4]
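One way to picture an out-of-place edit at a path is copy-on-write: copy only the containers along the path and reuse everything else. The set_at_path helper below is a hypothetical plain-Python sketch of that idea, not a description of how sepes is implemented. It handles dict and list containers, so it covers both integer paths and named paths.

```python
from typing import Any, Sequence

def set_at_path(tree: Any, path: Sequence[Any], value: Any) -> Any:
    """Hypothetical sketch (not the sepes API): return a new tree with the node
    at `path` replaced by `value`, leaving `tree` itself unmodified."""
    if not path:
        return value
    head, *rest = path
    # Shallow-copy only the container at this level of the path.
    if isinstance(tree, dict):
        new = dict(tree)
    elif isinstance(tree, list):
        new = list(tree)
    else:
        raise TypeError(f"cannot index into {type(tree).__name__}")
    new[head] = set_at_path(tree[head], rest, value)
    return new

pytree1 = [1, [2, 3], 4]
pytree2 = set_at_path(pytree1, (1, 0), 100)  # integer path
print(pytree2)  # [1, [100, 3], 4]
print(pytree1)  # [1, [2, 3], 4] (the original is untouched)

named = set_at_path({"a": -1, "b": {"c": 2, "d": 3}}, ("b",), 100)  # named path
print(named)  # {'a': -1, 'b': 100}
```

In this sketch, siblings off the path are shared with the original rather than copied. For example, set_at_path(pytree1, (0,), 0)[1] is the very same list object as pytree1[1].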
Named path indexing
at can edit pytrees by named paths.
[5]:
import sepes as sp
pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4, "f": {"g": 7, "h": 8}}
pytree2 = sp.at(pytree1)["b"].set(100)  # like pytree1["b"] = 100, but returns a new tree
pytree2
[5]:
{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}
Regular expression indexing
at can edit pytrees by regular expressions applied to keys.
[6]:
import sepes as sp
import re
pytree1 = {
"key_1": 1,
"key_2": {"key_3": 3, "key_4": 4},
"key_5": 5,
"key_6": {"key_7": 7, "key_8": 8},
}
# set keys key_1 through key_5 to 100 (matching key_2 replaces its whole subtree)
pattern = re.compile(r"key_[1-5]")
pytree2 = sp.at(pytree1)[pattern].set(100)
pytree2
[6]:
{'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}
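The regular-expression selection can be sketched for a single dict level. The helper set_matching_keys is hypothetical. It assumes the pattern must match the whole key (re.Pattern.fullmatch) and that, as with chained [...] indexing, one index selects one level of the tree. For the keys in this example, match and fullmatch select the same keys, so that assumption does not change the result.

```python
import re
from typing import Any, Dict

def set_matching_keys(tree: Dict[str, Any], pattern: re.Pattern, value: Any) -> Dict[str, Any]:
    """Hypothetical sketch (not the sepes API): return a new dict in which every
    top-level key that fully matches `pattern` maps to `value`."""
    # A matched key has its whole subtree replaced; unmatched values are reused.
    return {k: (value if pattern.fullmatch(k) else v) for k, v in tree.items()}

pytree1 = {
    "key_1": 1,
    "key_2": {"key_3": 3, "key_4": 4},
    "key_5": 5,
    "key_6": {"key_7": 7, "key_8": 8},
}
pytree2 = set_matching_keys(pytree1, re.compile(r"key_[1-5]"), 100)
print(pytree2)  # {'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}
```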
Mask indexing
at can also edit pytree entries selected by a boolean mask: given a mask with the same structure as the pytree, nodes marked True are edited and nodes marked False are left untouched.
The following example sets all negative leaves to zero.
[7]:
import sepes as sp
import jax
pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
# the mask mirrors the tree structure; True marks the (negative) entries to edit
mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)
pytree2 = sp.at(pytree1)[mask].set(0)
pytree2
[7]:
{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
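Mask indexing amounts to walking the tree and the mask in lockstep. The set_by_mask helper below is a hypothetical plain-Python sketch, not the sepes API. The mask is written out literally; it is the same value that jax.tree_util.tree_map(lambda x: x < 0, pytree1) returns for this tree.

```python
from typing import Any

def set_by_mask(tree: Any, mask: Any, value: Any) -> Any:
    """Hypothetical sketch (not the sepes API): return a new tree where every
    leaf whose mask entry is True is replaced by `value`."""
    if isinstance(tree, dict):
        # Descend into the mask with the same key as the tree.
        return {k: set_by_mask(v, mask[k], value) for k, v in tree.items()}
    if type(tree) in (list, tuple):
        # Pair each child with its mask entry by position.
        return type(tree)(set_by_mask(v, m, value) for v, m in zip(tree, mask))
    # Leaf: replace only if the mask says so.
    return value if mask else tree

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
mask = {"a": True, "b": {"c": False, "d": False}, "e": True}  # True marks negative leaves
pytree2 = set_by_mask(pytree1, mask, 0)
print(pytree2)  # {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
```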
Composition
at can chain integer paths, named paths, regular expressions, and masks to narrow down the nodes to edit.
The following example sets to 100 every node that satisfies both conditions:
- its key starts with “key_”
- its value is positive
[8]:
import sepes as sp
import re
import jax
pytree1 = {"key_1": 1, "key_2": -2, "value_1": 1, "value_2": 2}
pattern = re.compile(r"key_.*")
mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)
pytree2 = sp.at(pytree1)[pattern][mask].set(100)
pytree2
[8]:
{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}
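In this example the chained selections behave like a logical AND: a node is edited only when its key matches the pattern and its mask entry is True. The sketch below spells that out for a flat dict. set_where_key_and_mask is a hypothetical helper, not the sepes API, and the fullmatch semantics is an assumption.

```python
import re
from typing import Any, Dict

def set_where_key_and_mask(tree: Dict[str, Any], pattern: re.Pattern,
                           mask: Dict[str, bool], value: Any) -> Dict[str, Any]:
    """Hypothetical sketch (not the sepes API): replace a value only if its key
    fully matches `pattern` AND its mask entry is True."""
    return {
        k: (value if pattern.fullmatch(k) and mask[k] else v)
        for k, v in tree.items()
    }

pytree1 = {"key_1": 1, "key_2": -2, "value_1": 1, "value_2": 2}
pattern = re.compile(r"key_.*")
mask = {k: v > 0 for k, v in pytree1.items()}  # True for positive values
pytree2 = set_where_key_and_mask(pytree1, pattern, mask, 100)
print(pytree2)  # {'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}
```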