# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Access support for pytrees."""

from __future__ import annotations

import dataclasses
import sys
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload
from typing_extensions import Self  # Python 3.11+

import optree._C as _C
from optree._C import PyTreeKind


if TYPE_CHECKING:
    import builtins

    from optree.typing import NamedTuple, StructSequence


__all__ = [
    'PyTreeEntry',
    'GetItemEntry',
    'GetAttrEntry',
    'FlattenedEntry',
    'AutoEntry',
    'SequenceEntry',
    'MappingEntry',
    'NamedTupleEntry',
    'StructSequenceEntry',
    'DataclassEntry',
    'PyTreeAccessor',
]


# ``dataclass(slots=True)`` only exists on Python 3.10+; on older versions no extra keyword
# argument is passed to the decorator below.
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {}  # Python 3.10+


@dataclasses.dataclass(init=True, repr=False, eq=False, frozen=True, **SLOTS)
class PyTreeEntry:
    """Base class for path entries.

    A path entry describes a single step from a node to one of its children. Calling the entry
    with a node returns that child, and :meth:`codify` renders the same step as text.

    Instances are frozen dataclasses. ``repr`` and ``eq`` generation is disabled because
    :meth:`__repr__`, :meth:`__eq__` and :meth:`__hash__` are defined by hand below.
    """

    # The key, index, or attribute name that selects the child.
    entry: Any
    # The type of the node the entry is applied to.
    type: builtins.type
    # The pytree kind of the node; ``LEAF`` and ``NONE`` are rejected in ``__post_init__``.
    kind: PyTreeKind

    def __post_init__(self, /) -> None:
        """Post-initialize the path entry.

        Raises:
            ValueError: If :attr:`kind` is ``PyTreeKind.LEAF`` or ``PyTreeKind.NONE``.
        """
        if self.kind == PyTreeKind.LEAF:
            raise ValueError('Cannot create a leaf path entry.')
        if self.kind == PyTreeKind.NONE:
            raise ValueError('Cannot create a path entry for None.')

    def __call__(self, obj: Any, /) -> Any:
        """Get the child object.

        The base implementation subscripts ``obj`` with :attr:`entry`. Subclasses are expected
        to override it with the access pattern of their node type.

        Raises:
            TypeError: If subscripting ``obj`` raises :class:`TypeError`. The original exception
                is chained as the cause.
        """
        try:
            return obj[self.entry]  # should be overridden
        except TypeError as ex:
            raise TypeError(
                f'{self.__class__!r} cannot access through {obj!r} via entry {self.entry!r}',
            ) from ex

    def __add__(self, other: object, /) -> PyTreeAccessor:
        """Join the path entry with another path entry or accessor.

        Returns:
            A :class:`PyTreeAccessor` that starts with this entry, followed by ``other`` (if it is
            an entry) or by the entries of ``other`` (if it is an accessor). ``NotImplemented``
            for any other operand type.
        """
        if isinstance(other, PyTreeEntry):
            return PyTreeAccessor((self, other))
        if isinstance(other, PyTreeAccessor):
            return PyTreeAccessor((self, *other))
        return NotImplemented

    def __eq__(self, other: object, /) -> bool:
        """Check if the path entries are equal.

        Two entries are equal when ``other`` is a :class:`PyTreeEntry` and the ``entry``, ``type``
        and ``kind`` fields match, and the classes' ``__call__`` and ``codify`` methods have the
        same ``co_code``. The classes themselves are not compared, only the bytecode of these two
        methods.
        """
        return isinstance(other, PyTreeEntry) and (
            (
                self.entry,
                self.type,
                self.kind,
                self.__class__.__call__.__code__.co_code,
                self.__class__.codify.__code__.co_code,
            )
            == (
                other.entry,
                other.type,
                other.kind,
                other.__class__.__call__.__code__.co_code,
                other.__class__.codify.__code__.co_code,
            )
        )

    def __hash__(self, /) -> int:
        """Get the hash of the path entry.

        The hash is taken over the same tuple that :meth:`__eq__` compares.
        """
        return hash(
            (
                self.entry,
                self.type,
                self.kind,
                self.__class__.__call__.__code__.co_code,
                self.__class__.codify.__code__.co_code,
            ),
        )

    def __repr__(self, /) -> str:
        """Get the representation of the path entry.

        The representation shows the class name, :attr:`entry` and :attr:`type`. :attr:`kind` is
        not included.
        """
        return f'{self.__class__.__name__}(entry={self.entry!r}, type={self.type!r})'

    def codify(self, /, node: str = '') -> str:
        """Generate code for accessing the path entry.

        Args:
            node: The text of the expression for the node being accessed.

        Returns:
            ``node`` followed by a ``[<flat index ENTRY>]`` placeholder. The placeholder includes
            the ``repr`` of :attr:`entry` so that different children render differently.
        """
        return f'{node}[<flat index {self.entry!r}>]'  # should be overridden


del SLOTS


_T = TypeVar('_T')
_T_co = TypeVar('_T_co', covariant=True)
_KT_co = TypeVar('_KT_co', covariant=True)
_VT_co = TypeVar('_VT_co', covariant=True)


class AutoEntry(PyTreeEntry):
    """A generic path entry class that determines the entry type on creation automatically."""

    __slots__: ClassVar[tuple[()]] = ()

    def __new__(  # type: ignore[misc]
        cls,
        /,
        entry: Any,
        type: builtins.type,  # pylint: disable=redefined-builtin
        kind: PyTreeKind,
    ) -> PyTreeEntry:
        """Create a new path entry.

        When called on :class:`AutoEntry` itself, the concrete entry class is selected from
        ``type`` in this order: struct sequence, namedtuple, dataclass, :class:`Mapping` subclass,
        :class:`Sequence` subclass, and finally :class:`FlattenedEntry` as the fallback.

        Args:
            entry: The key, index, or attribute name that selects the child.
            type: The node type used for the dispatch.
            kind: The pytree kind of the node. Must be ``PyTreeKind.CUSTOM`` when ``cls`` is
                :class:`AutoEntry`.

        Returns:
            An initialized instance of the dispatched entry class, or an uninitialized instance of
            ``cls`` when ``cls`` is a subclass of :class:`AutoEntry`.

        Raises:
            ValueError: If ``cls`` is :class:`AutoEntry` and ``kind`` is not ``PyTreeKind.CUSTOM``.
        """
        # pylint: disable-next=import-outside-toplevel
        from optree.typing import is_namedtuple_class, is_structseq_class

        if cls is not AutoEntry:
            # Use the subclass type if the type is explicitly specified
            return super().__new__(cls)

        if kind != PyTreeKind.CUSTOM:
            raise ValueError(f'Cannot create an automatic path entry for PyTreeKind {kind!r}.')

        # Dispatch the path entry type based on the node type
        path_entry_type: builtins.type[PyTreeEntry]
        if is_structseq_class(type):
            path_entry_type = StructSequenceEntry
        elif is_namedtuple_class(type):
            path_entry_type = NamedTupleEntry
        elif dataclasses.is_dataclass(type):
            path_entry_type = DataclassEntry
        elif issubclass(type, Mapping):
            path_entry_type = MappingEntry
        elif issubclass(type, Sequence):
            path_entry_type = SequenceEntry
        else:
            path_entry_type = FlattenedEntry

        if not issubclass(path_entry_type, AutoEntry):
            # The __init__() method will not be called if the returned instance is not a subtype of
            # AutoEntry, so return a fully-initialized instance of the dispatched type.
            return path_entry_type(entry, type, kind)

        # The __init__() method will be called if the returned instance is a subtype of AutoEntry.
        # We should return an uninitialized instance. The __init__() method will initialize it.
        # But we will never reach here because the dispatched type is never a subtype of AutoEntry.
        raise NotImplementedError('Unreachable code.')


class GetItemEntry(PyTreeEntry):
    """A generic path entry class for nodes that access their children by :meth:`__getitem__`."""

    __slots__: ClassVar[tuple[()]] = ()

    def __call__(self, obj: Any, /) -> Any:
        """Get the child object by subscripting ``obj`` with :attr:`entry`."""
        return obj[self.entry]

    def codify(self, /, node: str = '') -> str:
        """Generate code for accessing the path entry.

        Returns:
            ``node`` followed by a subscript holding the ``repr`` of :attr:`entry`.
        """
        return f'{node}[{self.entry!r}]'


class GetAttrEntry(PyTreeEntry):
    """A generic path entry class for nodes that access their children by :meth:`__getattr__`."""

    __slots__: ClassVar[tuple[()]] = ()

    # The attribute name that selects the child.
    entry: str

    @property
    def name(self, /) -> str:
        """Get the attribute name."""
        return self.entry

    def __call__(self, obj: Any, /) -> Any:
        """Get the child object by reading the attribute :attr:`name` from ``obj``."""
        return getattr(obj, self.name)

    def codify(self, /, node: str = '') -> str:
        """Generate code for accessing the path entry.

        Returns:
            ``node`` followed by ``.`` and the attribute :attr:`name`.
        """
        return f'{node}.{self.name}'


class FlattenedEntry(PyTreeEntry):  # pylint: disable=too-few-public-methods
    """A fallback path entry class for flattened objects.

    It adds nothing to :class:`PyTreeEntry`, so access and code generation use the base class
    implementations.
    """

    __slots__: ClassVar[tuple[()]] = ()


class SequenceEntry(GetItemEntry, Generic[_T_co]):
    """A path entry class for sequences."""

    __slots__: ClassVar[tuple[()]] = ()

    # The position of the child in the sequence.
    entry: int
    type: builtins.type[Sequence[_T_co]]

    @property
    def index(self, /) -> int:
        """Get the index."""
        return self.entry

    def __call__(self, obj: Sequence[_T_co], /) -> _T_co:
        """Get the child object by subscripting ``obj`` with :attr:`index`."""
        return obj[self.index]

    def __repr__(self, /) -> str:
        """Get the representation of the path entry, labelling the entry as ``index``."""
        return f'{self.__class__.__name__}(index={self.index!r}, type={self.type!r})'


class MappingEntry(GetItemEntry, Generic[_KT_co, _VT_co]):
    """A path entry class for mappings."""

    __slots__: ClassVar[tuple[()]] = ()

    # The key of the child in the mapping.
    entry: _KT_co
    type: builtins.type[Mapping[_KT_co, _VT_co]]

    @property
    def key(self, /) -> _KT_co:
        """Get the key."""
        return self.entry

    def __call__(self, obj: Mapping[_KT_co, _VT_co], /) -> _VT_co:
        """Get the child object by subscripting ``obj`` with :attr:`key`."""
        return obj[self.key]

    def __repr__(self, /) -> str:
        """Get the representation of the path entry, labelling the entry as ``key``."""
        return f'{self.__class__.__name__}(key={self.key!r}, type={self.type!r})'


class NamedTupleEntry(SequenceEntry[_T]):
    """A path entry class for namedtuple objects.

    The child is fetched by position (inherited from :class:`SequenceEntry`), while the
    representation and the generated code use the field name.
    """

    __slots__: ClassVar[tuple[()]] = ()

    # The position of the field in the namedtuple.
    entry: int
    type: builtins.type[NamedTuple[_T]]  # type: ignore[type-arg]
    kind: Literal[PyTreeKind.NAMEDTUPLE]

    @property
    def fields(self, /) -> tuple[str, ...]:
        """Get the field names."""
        from optree.typing import namedtuple_fields  # pylint: disable=import-outside-toplevel

        return namedtuple_fields(self.type)

    @property
    def field(self, /) -> str:
        """Get the field name, i.e. the item of :attr:`fields` at position :attr:`entry`."""
        return self.fields[self.entry]

    def __repr__(self, /) -> str:
        """Get the representation of the path entry, labelling the entry by its field name."""
        return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'

    def codify(self, /, node: str = '') -> str:
        """Generate code for accessing the path entry.

        Returns:
            ``node`` followed by ``.`` and the :attr:`field` name.
        """
        return f'{node}.{self.field}'


class StructSequenceEntry(SequenceEntry[_T]):
    """A path entry class for PyStructSequence objects.

    The child is fetched by position (inherited from :class:`SequenceEntry`), while the
    representation and the generated code use the field name.
    """

    __slots__: ClassVar[tuple[()]] = ()

    # The position of the field in the struct sequence.
    entry: int
    type: builtins.type[StructSequence[_T]]
    kind: Literal[PyTreeKind.STRUCTSEQUENCE]

    @property
    def fields(self, /) -> tuple[str, ...]:
        """Get the field names."""
        from optree.typing import structseq_fields  # pylint: disable=import-outside-toplevel

        return structseq_fields(self.type)

    @property
    def field(self, /) -> str:
        """Get the field name, i.e. the item of :attr:`fields` at position :attr:`entry`."""
        return self.fields[self.entry]

    def __repr__(self, /) -> str:
        """Get the representation of the path entry, labelling the entry by its field name."""
        return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'

    def codify(self, /, node: str = '') -> str:
        """Generate code for accessing the path entry.

        Returns:
            ``node`` followed by ``.`` and the :attr:`field` name.
        """
        return f'{node}.{self.field}'


class DataclassEntry(GetAttrEntry):
    """A path entry class for dataclasses.

    The entry is either a field name or an integer position into the ``init=True`` fields.
    """

    __slots__: ClassVar[tuple[()]] = ()

    # A field name, or an integer index into ``init_fields``.
    entry: str | int  # type: ignore[assignment]

    @property
    def fields(self, /) -> tuple[str, ...]:  # pragma: no cover
        """Get all field names."""
        return tuple(f.name for f in dataclasses.fields(self.type))

    @property
    def init_fields(self, /) -> tuple[str, ...]:
        """Get the init field names, i.e. the names of the fields declared with ``init=True``."""
        return tuple(f.name for f in dataclasses.fields(self.type) if f.init)

    @property
    def field(self, /) -> str:
        """Get the field name.

        An integer :attr:`entry` is resolved through :attr:`init_fields`; any other entry is
        returned unchanged.
        """
        if isinstance(self.entry, int):
            return self.init_fields[self.entry]
        return self.entry

    @property
    def name(self, /) -> str:
        """Get the attribute name, which is the resolved :attr:`field` name."""
        return self.field

    def __repr__(self, /) -> str:
        """Get the representation of the path entry, labelling the entry by its field name."""
        return f'{self.__class__.__name__}(field={self.field!r}, type={self.type!r})'


class PyTreeAccessor(tuple[PyTreeEntry, ...]):
    """A path class for PyTrees.

    An accessor is a tuple of :class:`PyTreeEntry` objects. Calling it with an object applies the
    entries one after another, starting from that object.
    """

    __slots__: ClassVar[tuple[()]] = ()

    @property
    def path(self, /) -> tuple[Any, ...]:
        """Get the path of the accessor, i.e. the ``entry`` value of every path entry."""
        return tuple(e.entry for e in self)

    def __new__(cls, /, path: Iterable[PyTreeEntry] = ()) -> Self:
        """Create a new accessor instance.

        Args:
            path: The path entries. Anything that is not a list or tuple is first consumed into a
                tuple so that it can be validated and then stored.

        Raises:
            TypeError: If any item of ``path`` is not a :class:`PyTreeEntry`.
        """
        if not isinstance(path, (list, tuple)):
            path = tuple(path)
        if not all(isinstance(p, PyTreeEntry) for p in path):
            raise TypeError(f'Expected a path of PyTreeEntry, got {path!r}.')
        return super().__new__(cls, path)

    def __call__(self, obj: Any, /) -> Any:
        """Get the child object.

        Each entry is applied to the result of the previous one. An empty accessor returns ``obj``
        itself.
        """
        for entry in self:
            obj = entry(obj)
        return obj

    @overload  # type: ignore[override]
    def __getitem__(self, index: int, /) -> PyTreeEntry: ...

    @overload
    def __getitem__(self, index: slice, /) -> Self: ...

    def __getitem__(self, index: int | slice, /) -> PyTreeEntry | Self:
        """Get the child path entry or an accessor for a subpath.

        A slice is wrapped in a new instance of the same class; any other index is passed on to
        :class:`tuple`.
        """
        if isinstance(index, slice):
            return self.__class__(super().__getitem__(index))
        return super().__getitem__(index)

    def __add__(self, other: object, /) -> Self:
        """Join the accessor with another path entry or accessor.

        Returns:
            A new instance of the same class with ``other`` (an entry) or the entries of ``other``
            (an accessor) appended. ``NotImplemented`` for any other operand type.
        """
        if isinstance(other, PyTreeEntry):
            return self.__class__((*self, other))
        if isinstance(other, PyTreeAccessor):
            return self.__class__((*self, *other))
        return NotImplemented

    def __mul__(self, value: int, /) -> Self:  # type: ignore[override]
        """Repeat the accessor, returning an instance of the same class."""
        return self.__class__(super().__mul__(value))

    def __rmul__(self, value: int, /) -> Self:  # type: ignore[override]
        """Repeat the accessor, returning an instance of the same class."""
        return self.__class__(super().__rmul__(value))

    def __eq__(self, other: object, /) -> bool:
        """Check if the accessors are equal.

        A plain tuple never compares equal to an accessor, even one with the same entries.
        """
        return isinstance(other, PyTreeAccessor) and super().__eq__(other)

    def __hash__(self, /) -> int:
        """Get the hash of the accessor, which is the hash of the underlying tuple."""
        return super().__hash__()

    def __repr__(self, /) -> str:
        """Get the representation of the accessor.

        The representation contains the :meth:`codify` output followed by the tuple of entries.
        """
        return f'{self.__class__.__name__}({self.codify()}, {super().__repr__()})'

    def codify(self, /, root: str = '*') -> str:
        """Generate code for accessing the path.

        Args:
            root: The text that stands for the object the accessor is applied to.

        Returns:
            ``root`` extended by the ``codify`` output of each entry in order.
        """
        string = root
        for entry in self:
            string = entry.codify(string)
        return string


# These classes are used internally in the C++ side for accessor APIs
_name, _cls = '', object
for _name in __all__:
    _cls = globals()[_name]
    if not isinstance(_cls, type):  # pragma: no cover
        raise TypeError(f'Expected a class, got {_cls!r}.')
    # Report the classes as members of the top-level package and attach them to the C module.
    _cls.__module__ = 'optree'
    setattr(_C, _name, _cls)
del _name, _cls