Source code for sepes._src.tree_index

# Copyright 2023 sepes authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Define lens-like indexing for pytrees

This module provides a way to index and mask pytrees (e.g. TreeClass) in an 
out-of-place manner.Out-of-place means that the original pytree is not modified, 
instead a new pytree with the selected leaves are modified.

The indexing is done through two concepts:

1) Selection (Where): Determines parts of the pytree for manipulation via a path or a boolean mask.
2) Operation (What): Defines actions on selected parts, such as setting values or applying functions.

For example, the following code defines a dict pytree with where of same structure 
as the tree. The where (Selection) defines which parts of the tree to select and 
the set (Operation) operation sets the selected parts to 100.

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> where = {"a": True, "b": [False, True, False]}
>>> sp.at(tree)[where].set(100)
{'a': 100, 'b': [1, 100, 3]}
"""

from __future__ import annotations

import abc
import functools as ft
import re
from typing import Any, Callable, Generic, Hashable, Sequence, TypeVar

from typing_extensions import Self

import sepes
import sepes._src.backend.arraylib as arraylib
from sepes._src.backend import is_package_avaiable
from sepes._src.backend.treelib import ParallelConfig
from sepes._src.tree_pprint import tree_repr

T = TypeVar("T")
S = TypeVar("S")
PyTree = Any
Leaf = Any
EllipsisType = TypeVar("EllipsisType")
PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable)
_no_initializer = object()
_no_fill_value = object()


class BaseKey(abc.ABC):
    """Parent class for all match classes."""

    @abc.abstractmethod
    def compare(self, entry: PathKeyEntry, leaf: Leaf) -> bool:
        pass

    @property
    @abc.abstractmethod
    def broadcast(self):
        ...


_INVALID_INDEXER = """\
Indexing with {indexer} is not implemented, supported indexing types are:
  - `str` for mapping keys or class attributes.
  - `int` for positional indexing for sequences.
  - `...` to select all leaves.
  - ``re.Pattern`` to match a leaf level path with a regex pattern.
  - Boolean mask of a compatible structure as the pytree.
  - `tuple` of the above types to match multiple leaves at the same level.
  - Custom matchers defined with `at.def_rule`.
"""

_NO_LEAF_MATCH = """\
No leaf match is found for where={where}, Available keys are {names}
Check the following: 
  - If where is `str` then check if the key exists as a key or attribute.
  - If where is `int` then check if the index is in range.
  - If where is `re.Pattern` then check if the pattern matches any key.
  - If where is a `tuple` of the above types then check if any of the tuple elements match.
"""


def generate_path_mask(tree, where: tuple[BaseKey, ...], *, is_leaf=None):
    # given a pytree `tree` and a `where` path, that is composed of keys
    # generate a boolean mask that will be eventually used to with `tree_map`
    # to mark the leaves at the specified location.
    # for example for a tree = [[1, 2], 3, 4] and where = [0][1] then
    # generate [[False, True], False, False] mask
    match: bool = False
    treelib = sepes._src.backend.treelib

    def one_level_tree_path_map(func, tree):
        # apply func to the immediate children of tree
        def is_leaf_func(node) -> bool:
            # enable immediate children only
            if is_leaf and is_leaf(node) is True:
                return True
            if id(node) == id(tree):
                return False
            return True

        return treelib.path_map(func, tree, is_leaf=is_leaf_func)

    if any(where_i.broadcast for where_i in where):
        # should the selected subtree be broadcasted to the full tree
        # e.g. tree = [[1, 2], 3, 4] and where = [0], then
        # broadcast with True will be [[True, True], False, False]
        # and without broadcast will be [True, False, False]
        # the difference is that with broadcast the user defined value will
        # be broadcasted to the full subtree, for example if the user defined
        # value is 100 then the result will be [[100, 100], 3, 4]
        # and without broadcast the result will be [100, 3, 4]

        def bool_tree(value: bool, tree: Any):
            leaves, treedef = treelib.flatten(tree, is_leaf=is_leaf)
            return treelib.unflatten(treedef, [value] * len(leaves))

        true_tree = ft.partial(bool_tree, True)
        false_tree = ft.partial(bool_tree, False)

    else:
        # no broadcast, the user defined value will be applied to the selected
        # subtree only, for example if the user defined value is 100 then the
        true_tree = lambda _: True
        false_tree = lambda _: False

    def path_map_func(path, leaf):
        nonlocal match, where

        # ensure that the path is not empty
        if len(path) == len(where):
            for wi, pi in zip(where, path):
                if not wi.compare(pi, leaf):
                    return false_tree(leaf)
            match = True
            return true_tree(leaf)

        if len(path) and len(path) < len(where):
            # before traversing deeper into the tree, check if the current
            # path entry matches the current where entry, if not then return
            # a false tree to stop traversing deeper into the tree.
            (cur_where, *rest_where), (cur_path, *_) = where, path
            if cur_where.compare(cur_path, leaf):
                # where is nonlocal to the function
                # so reduce the where path by one level and traverse deeper
                # then restore the where path to the original value before
                # returning the result
                where = rest_where
                # traverse deeper into the tree
                out_tree = one_level_tree_path_map(path_map_func, leaf)
                # return from the traversal
                where = (cur_where, *rest_where)
                return out_tree
            return false_tree(leaf)

        return false_tree(leaf)

    mask = one_level_tree_path_map(path_map_func, tree)

    if not match:
        path_leaf, _ = treelib.path_flatten(tree, is_leaf=is_leaf)
        path = "/".join(str(where_i.input) for where_i in where)
        names = "".join("\n  - " + tree_repr(path) for path, _ in path_leaf)
        raise LookupError(_NO_LEAF_MATCH.format(where=path, names=names))

    return mask


def resolve_where(
    where: list[Any],
    tree: T,
    is_leaf: Callable[[Any], bool] | None = None,
):
    treelib = sepes._src.backend.treelib
    ndarrays = tuple(arraylib.ndarrays)

    def combine_bool_leaves(*leaves):
        # given a list of boolean leaves, combine them using `and`
        # this is used to combine multiple boolean masks resulting from
        # either path mask or boolean mask
        verdict = True
        for leaf in leaves:
            verdict &= leaf
        return verdict

    def is_bool_leaf(leaf: Any) -> bool:
        if isinstance(leaf, ndarrays):
            return arraylib.is_bool(leaf)
        return isinstance(leaf, bool)

    # given a pytree `tree` and a `where` path, that is composed of keys or
    # boolean masks, generate a boolean mask that will be eventually used to
    # with `tree_map` to select the leaves at the specified location.
    mask = None
    bool_masks: list[T] = []
    path_masks: list[BaseKey] = []
    seen_tuple = False  # handle multiple keys at the same level
    level_paths = []

    def verify_and_aggregate_is_leaf(node: Any) -> bool:
        # use is_leaf with non-local to traverse the tree depth-first manner
        # required for verifying if a pytree is a valid indexing pytree
        nonlocal seen_tuple, level_paths, bool_masks
        # used to check if a pytree is a valid indexing pytree
        # used with `is_leaf` argument of any `tree_*` function
        leaves, _ = treelib.flatten(node)

        if all(map(is_bool_leaf, leaves)):
            # if all leaves are boolean then this is maybe a boolean mask.
            # Maybe because the boolean mask can be a valid pytree of same structure
            # as the pytree to be indexed or _compatible_ structure.
            # that can be flattend up to inside tree_map.
            # the following is an example showcase this:
            # >>> tree = [1, 2, [3, 4]]
            # >>> mask = [True, True, False]
            # >>> at(tree)[mask].get()
            # in essence the user can mark full subtrees by `False` without
            # needing to populate the subtree with `False` values. if treedef
            # check is mandated then the user will need to populate the subtree
            # with `False` values. i.e. mask = [True, True, [False, False]]
            # Finally, invalid boolean mask will be caught by `jax.tree_util`
            bool_masks += [node]
            return True

        if isinstance(resolved_key := at.dispatcher(node), BaseKey):
            # valid resolution of `BaseKey` is a valid indexing leaf
            # makes it possible to dispatch on multi-leaf pytree
            level_paths += [resolved_key]
            return True

        if type(node) is tuple and seen_tuple is False:
            # e.g. `at[1,2,3]` but not `at[1,(2,3)]``
            # i.e. inside `__getitem__` mutliple entries are transformed to a tuple
            seen_tuple = True
            return False

        # not a container of other keys or a pytree of same structure
        # emit a descriptive error message to the user by pointing to the
        # the available keys in the pytree.
        raise NotImplementedError(_INVALID_INDEXER.format(indexer=node))

    for level_keys in where:
        # each for loop iteration is a level in the where path
        # this means that if where = ("a", "b", "c") then this means
        # we are travering the tree at level "a" then level "b" then level "c"
        treelib.flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf)
        # if len(level_paths) > 1 then this means that we have multiple keys
        # at the same level, for example where = ("a", ("b", "c")) then this
        # means that for a parent "a", select "b" and "c".
        path_masks += [MultiKey(*level_paths)] if len(level_paths) > 1 else level_paths
        level_paths = []
        seen_tuple = False

    if path_masks:
        mask = generate_path_mask(tree, path_masks, is_leaf=is_leaf)

    if bool_masks:
        all_masks = [mask, *bool_masks] if mask else bool_masks
        mask = treelib.map(combine_bool_leaves, *all_masks)

    return mask


class at(Generic[T]):
    """Operate on a pytree at a given path using a path or mask in out-of-place manner.

    Args:
        tree: pytree to operate on.
        where: one of the following:

            - ``str`` for mapping keys or class attributes.
            - ``int`` for positional indexing for sequences.
            - ``...`` to select all leaves.
            - a boolean mask of the same structure as the tree
            - ``re.Pattern`` to match a leaf level path with a regex pattern.
            - Custom matchers defined with ``at.def_rule``.
            - a tuple of the above to match multiple keys at the same level.

    Example:
        >>> import jax
        >>> import sepes as sp
        >>> tree = {"a": 1, "b": [1, 2, 3]}
        >>> sp.at(tree)["a"].set(100)
        {'a': 100, 'b': [1, 2, 3]}
        >>> sp.at(tree)["b"][0].set(100)
        {'a': 1, 'b': [100, 2, 3]}
        >>> mask = jax.tree_map(lambda x: x > 1, tree)
        >>> sp.at(tree)[mask].set(100)
        {'a': 1, 'b': [1, 100, 100]}
    """

    def __init__(self, tree: T, where: list[Any] | None = None) -> None:
        self.tree = tree
        self.where = [] if where is None else where

    def __getitem__(self, where: Any) -> Self:
        """Index a pytree at a given path using a path or mask."""
        return type(self)(self.tree, [*self.where, where])

    def __repr__(self) -> str:
        return f"{type(self).__name__}({tree_repr(self.tree)}, where={self.where})"

[docs] def get( self, *, is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, fill_value: Any = _no_fill_value, ): """Get the leaf values at the specified location. Args: is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. fill_value: the value to fill the non-selected leaves with. Useful to use with ``jax.jit`` to avoid variable size arrays leaves related errors. Returns: A _new_ pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array. Example: >>> import sepes as sp >>> tree = {"a": 1, "b": [1, 2, 3]} >>> sp.at(tree)["b"][0].get() {'a': None, 'b': [1, None, None]} """ treelib = sepes._src.backend.treelib ndarrays = tuple(arraylib.ndarrays) def leaf_get(where: Any, leaf: Any): # support both array and non-array leaves # for array boolean mask we select **parts** of the array that # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1]) # because of the variable resultant size of the output if isinstance(where, ndarrays) and len(arraylib.shape(where)): if fill_value is not _no_fill_value: return arraylib.where(where, leaf, fill_value) return leaf[where] # non-array boolean mask we select the leaf if the mask is True # and `None` otherwise if fill_value is not _no_fill_value: return leaf if where else fill_value return leaf if where else None return treelib.map( leaf_get, resolve_where(self.where, self.tree, is_leaf), self.tree, is_leaf=is_leaf, is_parallel=is_parallel, )
[docs] def set( self, set_value: Any, *, is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ): """Set the leaf values at the specified location. Args: set_value: the value to set at the specified location. is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. Returns: A pytree with the leaf values at the specified location set to ``set_value``. Example: >>> import sepes as sp >>> tree = {"a": 1, "b": [1, 2, 3]} >>> sp.at(tree)["b"][0].set(100) {'a': 1, 'b': [100, 2, 3]} """ treelib = sepes._src.backend.treelib ndarrays = tuple(arraylib.ndarrays) def leaf_set(where: Any, leaf: Any, set_value: Any): # support both array and non-array leaves # for array boolean mask we select **parts** of the array that # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100]) # with set_value = 100 if isinstance(where, ndarrays): return arraylib.where(where, set_value, leaf) return set_value if where else leaf _, lhsdef = treelib.flatten(self.tree, is_leaf=is_leaf) _, rhsdef = treelib.flatten(set_value, is_leaf=is_leaf) if lhsdef == rhsdef: # do not broadcast set_value if it is a pytree of same structure # for example tree.at[where].set(tree2) will set all tree leaves # to tree2 leaves if tree2 is a pytree of same structure as tree # instead of making each leaf of tree a copy of tree2 # is design is similar to ``numpy`` design `np.at[...].set(Array)` return treelib.map( leaf_set, resolve_where(self.where, self.tree, is_leaf), self.tree, set_value, is_leaf=is_leaf, is_parallel=is_parallel, ) return treelib.map( ft.partial(leaf_set, set_value=set_value), resolve_where(self.where, self.tree, is_leaf), self.tree, is_leaf=is_leaf, is_parallel=is_parallel, )
[docs] def apply( self, func: Callable[[Any], Any], *, is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ): """Apply a function to the leaf values at the specified location. Args: func: the function to apply to the leaf values. is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. Returns: A pytree with the leaf values at the specified location set to the result of applying ``func`` to the leaf values. Example: >>> import sepes as sp >>> tree = {"a": 1, "b": [1, 2, 3]} >>> sp.at(tree)["b"][0].apply(lambda x: x + 100) {'a': 1, 'b': [101, 2, 3]} Example: Read images in parallel >>> import sepes as sp >>> from matplotlib.pyplot import imread >>> path = {"img1": "path1.png", "img2": "path2.png"} >>> is_parallel = dict(max_workers=2) >>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel) # doctest: +SKIP """ treelib = sepes._src.backend.treelib ndarrays = tuple(arraylib.ndarrays) def leaf_apply(where: Any, leaf: Any): # same as `leaf_set` but with `func` applied to the leaf # one thing to note is that, the where mask select an array # then the function needs work properly when applied to the selected # array elements if isinstance(where, ndarrays): return arraylib.where(where, func(leaf), leaf) return func(leaf) if where else leaf return treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, is_leaf=is_leaf, is_parallel=is_parallel, )
[docs] def scan( self, func: Callable[[Any, S], tuple[Any, S]], state: S, *, is_leaf: Callable[[Any], bool] | None = None, ) -> tuple[Any, S]: """Apply a function while carrying a state. Args: func: the function to apply to the leaf values. the function accepts a running state and leaf value and returns a tuple of the new leaf value and the new state. state: the initial state to carry. is_leaf: a predicate function to determine if a value is a leaf. for example, ``lambda x: isinstance(x, list)`` will treat all lists as leaves and will not recurse into list items. Returns: A tuple of the final state and pytree with the leaf values at the specified location set to the result of applying ``func`` to the leaf values. Example: >>> import sepes as sp >>> tree = {"a": 1, "b": [1, 2, 3]} >>> def scan_func(leaf, running_max): ... cur_max = max(leaf, running_max) ... return leaf, cur_max >>> running_max = float("-inf") >>> _, running_max = sp.at(tree)["b"][0, 1].scan(scan_func, state=running_max) >>> running_max # max of b[0] and b[1] 2 Note: ``scan`` applies a binary ``func`` to the leaf values while carrying a state and returning a tree leaves with the the ``func`` applied to them with final state. While ``reduce`` applies a binary ``func`` to the leaf values while carrying a state and returning a single value. """ treelib = sepes._src.backend.treelib ndarrays = tuple(arraylib.ndarrays) running_state = state def stateless_func(leaf): nonlocal running_state leaf, running_state = func(leaf, running_state) return leaf def leaf_apply(where: Any, leaf: Any): if isinstance(where, ndarrays): return arraylib.where(where, stateless_func(leaf), leaf) return stateless_func(leaf) if where else leaf out_tree = treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, is_leaf=is_leaf, ) return out_tree, running_state
[docs] def reduce( self, func: Callable[[Any, Any], Any], *, initializer: Any = _no_initializer, is_leaf: Callable[[Any], bool] | None = None, ) -> Any: """Reduce the leaf values at the specified location. Args: func: the function to reduce the leaf values. initializer: the initializer value for the reduction. is_leaf: a predicate function to determine if a value is a leaf. Returns: The result of reducing the leaf values at the specified location. Note: - If ``initializer`` is not specified, the first leaf value is used as the initializer. - ``reduce`` applies a binary ``func`` to each leaf values while accumulating a state a returns the final result. while ``scan`` applies ``func`` to each leaf value while carrying a state and returns the final state and the leaves of the tree with the result of applying ``func`` to each leaf. Example: >>> import sepes as sp >>> tree = {"a": 1, "b": [1, 2, 3]} >>> sp.at(tree)["b"].reduce(lambda x, y: x + y) 6 """ treelib = sepes._src.backend.treelib tree = self.get(is_leaf=is_leaf) # type: ignore leaves, _ = treelib.flatten(tree, is_leaf=is_leaf) if initializer is _no_initializer: return ft.reduce(func, leaves) return ft.reduce(func, leaves, initializer)
[docs] def pluck( self, count: int | None = None, *, is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ) -> list[Any]: """Extract subtrees at the specified location. Note: ``pluck`` first applies ``get`` to the specified location and then extracts the immediate subtrees of the selected leaves. ``is_leaf`` and ``is_parallel`` are passed to ``get``. Args: count: number of subtrees to extract, Default to ``None`` to extract all subtrees. is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. Returns: A list of subtrees at the specified location. Note: Compared to ``get``, ``pluck`` extracts subtrees at the specified location and returns a list of subtrees. While ``get`` returns a pytree with the leaf values at the specified location and set the non-selected leaf values to ``None``. Example: >>> import sepes as sp >>> tree = {"a": 1, "b": [1, 2, 3]} <BLANKLINE> >>> # `pluck` returns a list of selected subtrees >>> sp.at(tree)["b"].pluck() [[1, 2, 3]] <BLANKLINE> >>> # `get` returns same pytree >>> sp.at(tree)["b"].get() {'a': None, 'b': [1, 2, 3]} Example: ``pluck`` with mask >>> import sepes as sp >>> tree = {"a": 1, "b": [2, 3, 4]} >>> mask = {"a": True, "b": [False, True, False]} >>> sp.at(tree)[mask].pluck() [1, 3] This is equivalent to the following: >>> [tree["a"], tree["b"][1]] # doctest: +SKIP """ treelib = sepes._src.backend.treelib tree = self.get(is_leaf=is_leaf, is_parallel=is_parallel) subtrees: list[Any] = [] count = float("inf") if count is None else count def aggregate_subtrees(node: Any) -> bool: nonlocal subtrees, count if count < 1: # stop traversing the tree # if total number of subtrees is reached return True if id(node) == id(tree): # skip the root node # for example if tree = dict(a=1) and mask is dict(a=True) # then returns [1] and not [dict(a=1)] return False leaves, _ = treelib.flatten(node, is_leaf=lambda x: x is None) # in essence if the subtree does not contain any None leaves # then it is a valid subtree to be plucked # this because `get` sets the non-selected leaves to None if any(leaf is None for leaf in leaves): return False subtrees += [node] count -= 1 return True treelib.flatten(tree, is_leaf=aggregate_subtrees) return subtrees
[docs] @staticmethod def def_rule( matcher_type: type[T], compare: Callable[[T, PathKeyEntry, Leaf], bool], *, broadcastable: bool = False, ) -> None: """Define a rule to match user input to with the corresponding path and leaf entry. Args: matcher_type: the user match object type to match with the path and leaf entry. compare: a function to compare the user matcher object with the path and leaf entry the function accepts the user input, the path entry, and the leaf value and returns a boolean value to mark if the user input matches the path and leaf entry. broadcastable: if the user type match result should be broadcasted to the full subtree. Default to ``False``. Example: Define a type matcher that matches based on the name, dtype, and shape of the leaf and then apply a function to the matched leaf. >>> import sepes as sp >>> import jax >>> import jax.numpy as jnp >>> import dataclasses as dc >>> @dc.dataclass ... class NameDtypeShapeMatcher: ... name: str ... dtype: jnp.dtype ... shape: tuple[int, ...] >>> def compare(matcher: NameDtypeShapeMatcher, key, leaf) -> bool: ... if not isinstance(leaf, jax.Array): ... return False ... if isinstance(key, str): ... key = key ... elif isinstance(key, jax.tree_util.GetAttrKey): ... key = key.name ... elif isinstance(key, jax.tree_util.DictKey): ... key = key.key ... return matcher.name == key and matcher.dtype == leaf.dtype and matcher.shape == leaf.shape >>> tree = dict(weight=jnp.arange(9).reshape(3, 3), bias=jnp.zeros(3)) >>> sp.at.def_rule(NameDtypeShapeMatcher, compare) >>> matcher = NameDtypeShapeMatcher('weight', jnp.int32, (3, 3)) >>> to_symmetric = lambda x: (x + x.T) / 2 >>> sp.at(tree)[matcher].apply(to_symmetric)["weight"] Array([[0., 2., 4.], [2., 4., 6.], [4., 6., 8.]], dtype=float32) """ # remove the BaseKey abstraction from the user-facing function class UserKey(BaseKey): broadcast: bool = broadcastable def __init__(self, input: T): self.input = input def compare(self, key: PathKeyEntry, leaf: Leaf) -> bool: return compare(self.input, key, leaf) at.dispatcher.register(matcher_type, UserKey)
# pass through for boolean pytrees masks and tuple of keys at.dispatcher = ft.singledispatch(lambda x: x) # key rules to match user input to with the path and leaf entry def str_compare(name: str, key: PathKeyEntry, leaf: Leaf) -> bool: """Match a leaf with a given name.""" del leaf if isinstance(key, str): return name == key treelib = sepes._src.backend.treelib if isinstance(key, type(treelib.attribute_key(""))): return name == key.name if isinstance(key, type(treelib.dict_key(""))): return name == key.key return False def int_compare(idx: int, key: PathKeyEntry, leaf: Leaf) -> bool: """Match a leaf with a given index.""" del leaf if isinstance(key, int): return idx == key treelib = sepes._src.backend.treelib if isinstance(key, type(treelib.sequence_key(0))): return idx == key.idx return False def regex_compare(pattern: re.Pattern, key: PathKeyEntry, leaf: Leaf) -> bool: """Match a path with a regex pattern inside 'at' property.""" del leaf if isinstance(key, str): return re.fullmatch(pattern, key) is not None treelib = sepes._src.backend.treelib if isinstance(key, type(treelib.attribute_key(""))): return re.fullmatch(pattern, key.name) is not None if isinstance(key, type(treelib.dict_key(""))): return re.fullmatch(pattern, key.key) is not None return False def ellipsis_compare(_, key: PathKeyEntry, leaf: Leaf) -> bool: del key, leaf return True at.def_rule(str, str_compare, broadcastable=False) at.def_rule(int, int_compare, broadcastable=False) at.def_rule(re.Pattern, regex_compare, broadcastable=False) at.def_rule(type(...), ellipsis_compare, broadcastable=True) class MultiKey(BaseKey): """Match a leaf with multiple keys at the same level.""" def __init__(self, *keys: BaseKey): self.keys = keys def compare(self, entry: PathKeyEntry, leaf: Leaf) -> bool: return any(key.compare(entry, leaf) for key in self.keys) broadcast: bool = False if is_package_avaiable("jax"): import jax.tree_util as jtu def jax_key_compare(input, key: PathKeyEntry, leaf: Leaf) -> bool: """Enable indexing with jax keys directly in `at`.""" del leaf return input == key at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False) at.def_rule(jtu.GetAttrKey, jax_key_compare, broadcastable=False) at.def_rule(jtu.DictKey, jax_key_compare, broadcastable=False)