# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for custom pytree node types."""

# pylint: disable=too-many-lines

from __future__ import annotations

import contextlib
import dataclasses
import inspect
import sys
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import itemgetter, methodcaller
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, overload

import optree._C as _C
from optree.accessors import (
    AutoEntry,
    MappingEntry,
    NamedTupleEntry,
    PyTreeEntry,
    SequenceEntry,
    StructSequenceEntry,
)
from optree.typing import PyTreeKind, StructSequence, T, is_namedtuple_class, is_structseq_class
from optree.utils import safe_zip, total_order_sorted, unzip2


if TYPE_CHECKING:
    import builtins
    from collections.abc import Collection, Generator, Iterable

    from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, UnflattenFunc

    # pylint: disable-next=invalid-name
    CustomTreeNodeType = TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])


__all__ = [
    'register_pytree_node',
    'register_pytree_node_class',
    'unregister_pytree_node',
    'dict_insertion_ordered',
]


# ``dataclass(slots=...)`` is only accepted on Python 3.10+, so the keyword is passed conditionally.
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {}  # Python 3.10+


@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS)
class PyTreeNodeRegistryEntry(Generic[T]):
    """A dataclass that stores the information of a pytree node type.

    Instances are immutable (``frozen=True``) and are what :func:`register_pytree_node.get`
    returns for a registered node type.
    """

    # The Python type treated as an internal pytree node.
    type: builtins.type[Collection[T]]
    # Callable that splits a node into its children (plus metadata and optional path entries).
    flatten_func: FlattenFunc[T]
    # Callable that rebuilds a node from its metadata and children.
    unflatten_func: UnflattenFunc[T]

    if sys.version_info >= (3, 10):  # pragma: >=3.10 cover
        # Fields below this marker are keyword-only on Python 3.10+.
        _: dataclasses.KW_ONLY  # Python 3.10+

    # Path entry class used when building tree paths for the children.
    path_entry_type: builtins.type[PyTreeEntry] = AutoEntry
    # Node kind; user registrations keep the default ``PyTreeKind.CUSTOM``.
    kind: PyTreeKind = PyTreeKind.CUSTOM
    # Namespace the entry was registered in; ``''`` denotes the global namespace.
    namespace: str = ''


del SLOTS


# pylint: disable-next=missing-class-docstring,too-few-public-methods
class GlobalNamespace:  # pragma: no cover
    __slots__: ClassVar[tuple[()]] = ()

    def __repr__(self, /) -> str:
        # A visible marker: an empty repr made the sentinel invisible in debug output and
        # indistinguishable at a glance from the (rejected) empty-string namespace.
        return '<GLOBAL NAMESPACE>'


# Private sentinel (typed as ``str`` for the checker, but not a ``str`` at runtime) that lets
# internal callers register types in the global namespace; public callers must pass a
# non-empty string instead.
__GLOBAL_NAMESPACE: str = GlobalNamespace()  # type: ignore[assignment]
# Guards updates to the C++ registry together with the Python-side ``_NODETYPE_REGISTRY`` mirror.
__REGISTRY_LOCK: Lock = Lock()
del GlobalNamespace


if TYPE_CHECKING:
    from typing_extensions import ParamSpec  # Python 3.10+

    _P = ParamSpec('_P')
    _T = TypeVar('_T')
    _GetP = ParamSpec('_GetP')
    _GetT = TypeVar('_GetT')

    class _CallableWithGet(Generic[_P, _T, _GetP, _GetT]):
        def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
            raise NotImplementedError

        # pylint: disable-next=missing-function-docstring
        def get(self, /, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT:
            raise NotImplementedError


def _add_get(
    get: Callable[_GetP, _GetT],
    /,
) -> Callable[
    [Callable[_P, _T]],
    _CallableWithGet[_P, _T, _GetP, _GetT],
]:
    """Return a decorator that attaches ``get`` as the ``.get`` attribute of the decorated function.

    The decorated function itself is returned unchanged (no wrapper is created).
    """

    def decorator(func: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]:
        func.get = get  # type: ignore[attr-defined]
        return func  # type: ignore[return-value]

    return decorator


@overload
def pytree_node_registry_get(
    cls: type,
    /,
    *,
    namespace: str = '',
) -> PyTreeNodeRegistryEntry | None: ...


@overload
def pytree_node_registry_get(
    cls: None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry]: ...
# pylint: disable-next=too-many-return-statements,too-many-branches
def pytree_node_registry_get(  # noqa: C901
    cls: type | None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | None:
    """Lookup the pytree node registry.

    >>> register_pytree_node.get()  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    {
        <class 'NoneType'>: PyTreeNodeRegistryEntry(
            type=<class 'NoneType'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class '...PyTreeEntry'>,
            kind=<PyTreeKind.NONE: ...>,
            namespace=''
        ),
        <class 'tuple'>: PyTreeNodeRegistryEntry(
            type=<class 'tuple'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class '...SequenceEntry'>,
            kind=<PyTreeKind.TUPLE: ...>,
            namespace=''
        ),
        <class 'list'>: PyTreeNodeRegistryEntry(
            type=<class 'list'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class '...SequenceEntry'>,
            kind=<PyTreeKind.LIST: ...>,
            namespace=''
        ),
        ...
    }
    >>> register_pytree_node.get(defaultdict)  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    PyTreeNodeRegistryEntry(
        type=<class 'collections.defaultdict'>,
        flatten_func=<function ...>,
        unflatten_func=<function ...>,
        path_entry_type=<class '...MappingEntry'>,
        kind=<PyTreeKind.DEFAULTDICT: ...>,
        namespace=''
    )
    >>> print(register_pytree_node.get(frozenset))  # frozenset is considered as a leaf node
    None

    Args:
        cls (type or None, optional): The class of the pytree node to retrieve. If not provided, all
            the registered pytree nodes in the namespace are returned.
        namespace (str, optional): The namespace of the registry to retrieve. If not provided, the
            global namespace is used.

    Returns:
        If the ``cls`` is not provided, a dictionary of all the registered pytree nodes in the
        namespace is returned. Entries registered in ``namespace`` take precedence over global
        entries for the same type. If the ``cls`` is provided, the corresponding registry entry is
        returned if the ``cls`` is registered as a pytree node. Otherwise, :data:`None` is
        returned, i.e., the ``cls`` is represented as a leaf node.

    Raises:
        TypeError: If ``cls`` is neither a class, :func:`collections.namedtuple`, nor :data:`None`.
        TypeError: If the namespace is not a string.
    """
    if namespace is __GLOBAL_NAMESPACE:
        namespace = ''

    if (
        cls is not None
        and cls is not namedtuple  # noqa: PYI024
        and not inspect.isclass(cls)
    ):
        raise TypeError(f'Expected a class or None, got {cls!r}.')  # pragma: !=3.9 cover
    if not isinstance(namespace, str):
        raise TypeError(  # pragma: !=3.9 cover
            f'The namespace must be a string, got {namespace!r}.',
        )

    if cls is None:
        with __REGISTRY_LOCK:
            entries = list(_NODETYPE_REGISTRY.values())
        registry: dict[type, PyTreeNodeRegistryEntry] = {}
        for entry in entries:
            # An entry from the requested namespace always wins; a global entry is used only when
            # the namespace has no override for that type. This matches the single-class lookup
            # below regardless of the order in which the two registrations were made.
            if entry.namespace == namespace or (
                entry.namespace == '' and entry.type not in registry
            ):
                registry[entry.type] = entry
        if _C.is_dict_insertion_ordered(namespace):
            registry[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
            registry[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
        return registry

    # 1. An explicit registration in the requested namespace.
    if namespace != '':
        handler = _NODETYPE_REGISTRY.get((namespace, cls))
        if handler is not None:
            return handler

    # 2. Insertion-ordered dict variants when that mode is enabled for the namespace.
    if _C.is_dict_insertion_ordered(namespace):
        if cls is dict:
            return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
        if cls is defaultdict:
            return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY

    # 3. A global registration, then the catch-all entries for structseq / namedtuple classes.
    handler = _NODETYPE_REGISTRY.get(cls)
    if handler is not None:
        return handler
    if is_structseq_class(cls):
        return _NODETYPE_REGISTRY.get(StructSequence)
    if is_namedtuple_class(cls):
        return _NODETYPE_REGISTRY.get(namedtuple)  # type: ignore[call-overload]  # noqa: PYI024
    return None


@_add_get(pytree_node_registry_get)
def register_pytree_node(
    cls: type[Collection[T]],
    /,
    flatten_func: FlattenFunc[T],
    unflatten_func: UnflattenFunc[T],
    *,
    path_entry_type: type[PyTreeEntry] = AutoEntry,
    namespace: str,
) -> type[Collection[T]]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_func (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a triple or optionally a pair, with (1) an iterable for the
            children to be flattened recursively, and (2) some hashable metadata to be stored in the
            treespec and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for
            the tree path entries to the corresponding children. If the entries are not provided or
            given by :data:`None`, then ``range(len(children))`` will be used.
        unflatten_func (callable): A function taking two arguments: the metadata that was returned
            by ``flatten_func`` and stored in the treespec, and the unflattened children. The
            function should return an instance of ``cls``.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: :class:`AutoEntry`)
        namespace (str): A non-empty string that uniquely identifies the namespace of the type
            registry. This is used to isolate the registry from other modules that might register a
            different custom behavior for the same type.

    Returns:
        The same type as the input ``cls``.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='set',
        ... )
        <class 'set'>

        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )
        <class 'torch.Tensor'>

        >>> # doctest: +SKIP
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
        >>> def flatparam2tensor(metadata, children):
        ...     return children[0].reshape(metadata)
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=tensor2flatparam,
        ...     unflatten_func=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )
        <class 'torch.Tensor'>

        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    """  # pylint: disable=line-too-long
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Global registrations are keyed by the bare type; namespaced ones by ``(namespace, type)``.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    with __REGISTRY_LOCK:
        # The C++ registry is updated first; if it raises (presumably for a duplicate
        # registration, per the Raises section), the Python-side mirror is left untouched.
        _C.register_node(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type,
            namespace,
        )
        _NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type=path_entry_type,
            namespace=namespace,
        )
    return cls


del pytree_node_registry_get, _add_get


@overload
def register_pytree_node_class(
    cls: str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> Callable[[CustomTreeNodeType], CustomTreeNodeType]: ...


@overload
def register_pytree_node_class(
    cls: CustomTreeNodeType,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str,
) -> CustomTreeNodeType: ...
def register_pytree_node_class( # noqa: C901 cls: CustomTreeNodeType | str | None = None, /, *, path_entry_type: type[PyTreeEntry] | None = None, namespace: str | None = None, ) -> CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]: """Extend the set of types that are considered internal nodes in pytrees. See also :func:`register_pytree_node` and :func:`unregister_pytree_node`. The ``namespace`` argument is used to avoid collisions that occur when different libraries register the same Python type with different behaviors. It is recommended to add a unique prefix to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify the same class in different namespaces for different use cases. .. warning:: For safety reasons, a ``namespace`` must be specified while registering a custom type. It is used to isolate the behavior of flattening and unflattening a pytree node type. This is to prevent accidental collisions between different libraries that may register the same type. Args: cls (type, optional): A Python type to treat as an internal pytree node. path_entry_type (type, optional): The type of the path entry to be used in the treespec. (default: :class:`AutoEntry`) namespace (str, optional): A non-empty string that uniquely identifies the namespace of the type registry. This is used to isolate the registry from other modules that might register a different custom behavior for the same type. Returns: The same type as the input ``cls`` if the argument presents. Otherwise, return a decorator function that registers the class as a pytree node. Raises: TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`. TypeError: If the namespace is not a string. ValueError: If the namespace is an empty string. ValueError: If the type is already registered in the registry. 
This function is a thin wrapper around :func:`register_pytree_node`, and provides a class-oriented interface:: @register_pytree_node_class(namespace='foo') class Special: TREE_PATH_ENTRY_TYPE = GetAttrEntry def __init__(self, x, y): self.x = x self.y = y def tree_flatten(self): return ((self.x, self.y), None, ('x', 'y')) @classmethod def tree_unflatten(cls, metadata, children): return cls(*children) @register_pytree_node_class('mylist') class MyList(UserList): TREE_PATH_ENTRY_TYPE = SequenceEntry def tree_flatten(self): return self.data, None, None @classmethod def tree_unflatten(cls, metadata, children): return cls(*children) """ if cls is __GLOBAL_NAMESPACE or isinstance(cls, str): if namespace is not None: raise ValueError('Cannot specify `namespace` when the first argument is a string.') if cls == '': raise ValueError('The namespace cannot be an empty string.') cls, namespace = None, cls if namespace is None: raise ValueError('Must specify `namespace` when the first argument is a class.') if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): raise TypeError(f'The namespace must be a string, got {namespace!r}') if namespace == '': raise ValueError('The namespace cannot be an empty string.') if cls is None: def decorator(cls: CustomTreeNodeType, /) -> CustomTreeNodeType: return register_pytree_node_class( cls, path_entry_type=path_entry_type, namespace=namespace, ) return decorator if not inspect.isclass(cls): raise TypeError(f'Expected a class, got {cls!r}.') if path_entry_type is None: path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry) if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)): raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.') register_pytree_node( cls, methodcaller('tree_flatten'), cls.tree_unflatten, path_entry_type=path_entry_type, namespace=namespace, ) return cls def unregister_pytree_node(cls: type, /, *, namespace: str) -> 
PyTreeNodeRegistryEntry: """Remove a type from the pytree node registry. See also :func:`register_pytree_node` and :func:`register_pytree_node_class`. This function is the inverse operation of function :func:`register_pytree_node`. Args: cls (type): A Python type to remove from the pytree node registry. namespace (str): The namespace of the pytree node registry to remove the type from. Returns: The removed registry entry. Raises: TypeError: If the input type is not a class. TypeError: If the namespace is not a string. ValueError: If the namespace is an empty string. ValueError: If the type is a built-in type that cannot be unregistered. ValueError: If the type is not found in the registry. Examples: >>> # Register a Python type with lambda functions >>> register_pytree_node( ... set, ... lambda s: (sorted(s), None, None), ... lambda _, children: set(children), ... namespace='temp', ... ) >>> # Unregister the Python type >>> unregister_pytree_node(set, namespace='temp') """ if not inspect.isclass(cls): raise TypeError(f'Expected a class, got {cls!r}.') if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): raise TypeError(f'The namespace must be a string, got {namespace!r}.') if namespace == '': raise ValueError('The namespace cannot be an empty string.') registration_key: type | tuple[str, type] if namespace is __GLOBAL_NAMESPACE: registration_key = cls namespace = '' else: registration_key = (namespace, cls) with __REGISTRY_LOCK: _C.unregister_node(cls, namespace) return _NODETYPE_REGISTRY.pop(registration_key) @contextlib.contextmanager def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]: """Context manager to temporarily set the dictionary sorting mode. This context manager is used to temporarily set the dictionary sorting mode for a specific namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary should be sorted or keeping the insertion order when flattening a pytree. 
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE ( [1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}) ) >>> with dict_insertion_ordered(True, namespace='some-namespace'): # doctest: +IGNORE_WHITESPACE ... tree_flatten(tree, namespace='some-namespace') ( [2, 3, 4, 1, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace') ) .. warning:: The dictionary sorting mode is a global setting and is **not thread-safe**. It is recommended to use this context manager in a single-threaded environment. Args: mode (bool): The dictionary sorting mode to set. namespace (str): The namespace to set the dictionary sorting mode for. """ if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): raise TypeError(f'The namespace must be a string, got {namespace!r}.') if namespace == '': raise ValueError('The namespace cannot be an empty string.') if namespace is __GLOBAL_NAMESPACE: namespace = '' with __REGISTRY_LOCK: prev = _C.is_dict_insertion_ordered(namespace, inherit_global_namespace=False) _C.set_dict_insertion_ordered(bool(mode), namespace) try: yield finally: with __REGISTRY_LOCK: _C.set_dict_insertion_ordered(prev, namespace) def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]: return total_order_sorted(items, key=itemgetter(0)) def _none_flatten(_: None, /) -> tuple[tuple[()], None]: return (), None def _none_unflatten(_: None, children: Iterable[Any], /) -> None: sentinel = object() if next(iter(children), sentinel) is not sentinel: raise ValueError('Expected no children.') def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]: return tup, None def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]: return tuple(children) def _list_flatten(lst: list[T], /) -> tuple[list[T], None]: return lst, None def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]: return list(children) 
def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
    """Flatten a :class:`dict` with its items sorted by key.

    Returns the values as children, the keys as metadata (a list), and the keys as path entries.
    """
    keys, values = unzip2(_sorted_items(dct.items()))
    return values, list(keys), keys


def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a :class:`dict` by pairing ``keys`` with ``values`` (lengths must match)."""
    return dict(safe_zip(keys, values))


def _dict_insertion_ordered_flatten(
    dct: dict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten a :class:`dict` keeping its insertion order (no key sorting)."""
    keys, values = unzip2(dct.items())
    return values, list(keys), keys


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a :class:`dict` by pairing ``keys`` with ``values`` in the given order."""
    return dict(safe_zip(keys, values))


def _ordereddict_flatten(
    dct: OrderedDict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten an :class:`OrderedDict` in its own (insertion) order."""
    keys, values = unzip2(dct.items())
    return values, list(keys), keys


def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]:
    """Rebuild an :class:`OrderedDict` by pairing ``keys`` with ``values``."""
    return OrderedDict(safe_zip(keys, values))


def _defaultdict_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a :class:`defaultdict` with sorted keys; metadata also keeps ``default_factory``."""
    values, keys, entries = _dict_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a :class:`defaultdict` from ``(default_factory, keys)`` metadata and values.

    ``default_factory`` may be :data:`None`, exactly as produced by :func:`_defaultdict_flatten`.
    """
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_unflatten(keys, values))


def _defaultdict_insertion_ordered_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a :class:`defaultdict` in insertion order; metadata keeps ``default_factory``."""
    values, keys, entries = _dict_insertion_ordered_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_insertion_ordered_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a :class:`defaultdict` (insertion-ordered variant); ``default_factory`` may be None."""
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))


def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]:
    """Flatten a :class:`deque`; ``maxlen`` is kept as metadata."""
    return deq, deq.maxlen


def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]:
    """Rebuild a :class:`deque` with the recorded ``maxlen``."""
    return deque(children, maxlen=maxlen)


def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]:  # type: ignore[type-arg]
    """Flatten a namedtuple; its concrete class is kept as metadata."""
    return tup, type(tup)


# pylint: disable-next=line-too-long
def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]:  # type: ignore[type-arg]
    """Rebuild a namedtuple by passing the children positionally to ``cls``."""
    return cls(*children)  # type: ignore[call-overload]


def _structseq_flatten(seq: StructSequence[T], /) -> tuple[tuple[T, ...], type[StructSequence[T]]]:
    """Flatten a struct sequence; its concrete class is kept as metadata."""
    return seq, type(seq)


def _structseq_unflatten(
    cls: type[StructSequence[T]],
    children: Iterable[T],
    /,
) -> StructSequence[T]:
    """Rebuild a struct sequence by passing the children as a single sequence argument."""
    return cls(children)


# Python-side mirror of the node registry. Built-in node types live in the global namespace and
# are keyed by the bare type; ``register_pytree_node`` adds keys of the form ``(namespace, type)``.
# The ``namedtuple`` and ``StructSequence`` entries act as catch-alls for every namedtuple class and
# every struct sequence class, respectively. Insertion order is significant: it determines the
# order of entries listed by ``register_pytree_node.get()``.
_NODETYPE_REGISTRY: dict[type | tuple[str, type], PyTreeNodeRegistryEntry] = {
    type(None): PyTreeNodeRegistryEntry(
        type(None),  # type: ignore[arg-type]
        _none_flatten,
        _none_unflatten,
        path_entry_type=PyTreeEntry,
        kind=PyTreeKind.NONE,
    ),
    tuple: PyTreeNodeRegistryEntry(
        tuple,
        _tuple_flatten,
        _tuple_unflatten,
        path_entry_type=SequenceEntry,
        kind=PyTreeKind.TUPLE,
    ),
    list: PyTreeNodeRegistryEntry(
        list,
        _list_flatten,
        _list_unflatten,
        path_entry_type=SequenceEntry,
        kind=PyTreeKind.LIST,
    ),
    dict: PyTreeNodeRegistryEntry(
        dict,
        _dict_flatten,
        _dict_unflatten,
        path_entry_type=MappingEntry,
        kind=PyTreeKind.DICT,
    ),
    namedtuple: PyTreeNodeRegistryEntry(  # type: ignore[dict-item]  # noqa: PYI024
        namedtuple,  # type: ignore[arg-type]  # noqa: PYI024
        _namedtuple_flatten,
        _namedtuple_unflatten,
        path_entry_type=NamedTupleEntry,
        kind=PyTreeKind.NAMEDTUPLE,
    ),
    OrderedDict: PyTreeNodeRegistryEntry(
        OrderedDict,
        _ordereddict_flatten,
        _ordereddict_unflatten,
        path_entry_type=MappingEntry,
        kind=PyTreeKind.ORDEREDDICT,
    ),
    defaultdict: PyTreeNodeRegistryEntry(
        defaultdict,
        _defaultdict_flatten,
        _defaultdict_unflatten,
        path_entry_type=MappingEntry,
        kind=PyTreeKind.DEFAULTDICT,
    ),
    deque: PyTreeNodeRegistryEntry(
        deque,
        _deque_flatten,
        _deque_unflatten,
        path_entry_type=SequenceEntry,
        kind=PyTreeKind.DEQUE,
    ),
    StructSequence: PyTreeNodeRegistryEntry(
        StructSequence,
        _structseq_flatten,
        _structseq_unflatten,
        path_entry_type=StructSequenceEntry,
        kind=PyTreeKind.STRUCTSEQUENCE,
    ),
}

# Replacement entries for ``dict`` / ``defaultdict`` that ``pytree_node_registry_get`` returns when
# dict insertion order is enabled for the queried namespace (keys are not sorted).
_DICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    dict,
    _dict_insertion_ordered_flatten,
    _dict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DICT,
)
_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    defaultdict,
    _defaultdict_insertion_ordered_flatten,
    _defaultdict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DEFAULTDICT,
)