# Source code for sepes._src.backend

# Copyright 2023 sepes authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools as ft
import os
from importlib.util import find_spec
from typing import Literal, Callable
import logging
from contextlib import contextmanager


@ft.lru_cache(maxsize=None)
def is_package_avaiable(backend: str) -> bool:
    """Return ``True`` if the module named ``backend`` can be found.

    The check uses :func:`importlib.util.find_spec`. A top-level name is
    located without being imported. For a dotted name, ``find_spec`` imports
    the parent package first. Results are cached for the life of the process.

    Args:
        backend: module name to look up, e.g. ``"jax"`` or ``"pkg.sub"``.

    Returns:
        ``True`` when a module spec is found. ``False`` otherwise, including
        when the parent of a dotted name is missing or when the module sits in
        ``sys.modules`` without a ``__spec__``.
    """
    try:
        return find_spec(backend) is not None
    except (ImportError, ValueError):
        # ``find_spec`` raises ``ModuleNotFoundError`` (an ``ImportError``)
        # when the parent package of a dotted name is missing. It raises
        # ``ValueError`` when ``sys.modules[name].__spec__`` is ``None``.
        # Either way the package is not usable, so report it as unavailable.
        return False


# by importing the backend modules here, we register the backend implementations
# with the arraylib
if is_package_avaiable("torch"):
    import sepes._src.backend.arraylib.torch
if is_package_avaiable("jax"):
    import sepes._src.backend.arraylib.jax
if is_package_avaiable("numpy"):
    import sepes._src.backend.arraylib.numpy


def optree_backend():
    """Construct and return the ``optree``-based tree library.

    Raises:
        ImportError: when the ``optree`` package cannot be found.
    """
    if is_package_avaiable("optree"):
        # imported lazily so that merely loading this module never needs optree
        from sepes._src.backend.treelib.optree import OpTreeTreeLib

        return OpTreeTreeLib()
    # no backend is available
    raise ImportError("No backend is available. Please install `optree`.")


def jax_backend():
    """Construct and return the ``jax``-based tree library.

    Raises:
        ImportError: when the ``jax`` package cannot be found.
    """
    if is_package_avaiable("jax"):
        # imported lazily so that merely loading this module never needs jax
        from sepes._src.backend.treelib.jax import JaxTreeLib

        return JaxTreeLib()
    raise ImportError("`jax` backend requires `jax` to be installed.")


BackendLiteral = Literal["optree", "jax"]  # tree backend
# Requested tree backend, read from the environment. ``"default"`` means
# "pick the first installed backend"; any other value must be a key of
# ``backends_map``.
backend: BackendLiteral = os.environ.get("SEPES_BACKEND", "default").lower()
# Insertion order matters: it is the promotion order used for ``"default"``.
backends_map: dict[BackendLiteral, Callable] = {}
backends_map["jax"] = jax_backend
backends_map["optree"] = optree_backend

if backend == "default":
    # backend promotion in essence is a search for the first available backend
    # in the following order: jax, optree
    for backend_name in backends_map:
        if is_package_avaiable(backend_name):
            treelib = backends_map[backend_name]()
            backend = backend_name
            logging.info(f"Successfully set backend to `{backend_name}`")
            break
    else:
        # No backend is installed: fall back to the default ``optree`` backend.
        # Its constructor raises an informative ``ImportError`` here. Without
        # this branch the loop fell through silently and left ``treelib``
        # undefined, which surfaced later as a confusing error.
        treelib = optree_backend()
        backend = "optree"
elif backend in backends_map:
    # explicitly requested backend; its constructor raises ImportError if the
    # underlying package is missing
    treelib = backends_map[backend]()
    logging.info(f"Successfully set backend to `{backend}`")
else:
    raise ValueError(f"Unknown backend: {backend!r}. available {backends_map.keys()=}")


@contextmanager
def backend_context(backend_name: BackendLiteral):
    """Context manager for switching the tree backend within a context.

    Inside the context, both the module-level ``treelib`` instance and the
    ``backend`` name refer to ``backend_name``. Both are restored on exit,
    even if the body raises.

    Args:
        backend_name: The name of the backend to switch to. available backends
            are ``optree`` and ``jax``.

    Example:
        Registering a custom tree class with optree backend:

        >>> import sepes as sp
        >>> import optree
        >>> with sp.backend_context("optree"):
        ...     class Tree(sp.TreeClass):
        ...         def __init__(self, a, b):
        ...             self.a = a
        ...             self.b = b
        ...     tree = Tree(1, 2)
        >>> optree.tree_flatten(tree, namespace="sepes")
        ([1, 2], PyTreeSpec(CustomTreeNode(Tree[('a', 'b')], [*, *]), namespace='sepes'))
    """
    global treelib, backend
    old_treelib = treelib
    old_backend = backend
    try:
        treelib = backends_map[backend_name]()
        # keep the reported backend name in sync with the active treelib;
        # previously ``backend`` was restored on exit but never switched
        backend = backend_name
        yield
    finally:
        treelib = old_treelib
        backend = old_backend