# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper that automatically registers dataclass as a Jax PyTree."""
import copy
import dataclasses
import hashlib
import types
import typing
from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union

import jax
import numpy as np
# Generic type variable used by `dataclass` so that it is typed as returning
# the same class type it was given.
_T = TypeVar('_T')
class _NumPyArrayHashWrapper:
  """A wrapper for NumPy arrays to make them hashable based on content.

  This class is used to allow NumPy arrays to be part of the metadata in a Jax
  PyTree registration, as metadata must be hashable. The hash is based on the
  array's content, dtype, and shape.

  The digest is computed once, at construction time, and the array is stored
  by reference (not copied); mutating the array in place afterwards is not
  reflected in the hash or in equality comparisons.

  Attributes:
    array: the wrapped array, exactly as passed to the constructor.
  """

  __slots__ = ('_hash_key', 'array')

  def __init__(self, arr: np.ndarray):
    """Computes the content digest of `arr` and stores the array.

    Args:
      arr: the array to wrap. It may be non-contiguous, empty, or use a dtype
        whose buffer format is not a native single character (non-native byte
        order, datetime64, structured dtypes without object fields).
    """
    if arr.size == 0:
      h = hashlib.sha256(b'').hexdigest()
    else:
      # Hash the raw C-ordered bytes through a flat uint8 view. Going through
      # `memoryview.cast('B')` instead would fail for every dtype whose buffer
      # format is not a native single character, while the uint8 view yields
      # the same bytes for the dtypes that cast did accept.
      raw = np.ascontiguousarray(arr).reshape(-1).view(np.uint8)
      h = hashlib.sha256(raw.data).hexdigest()
    # dtype and shape are part of the key so that arrays sharing the same raw
    # bytes but differing in interpretation do not compare equal.
    self._hash_key = (h, arr.dtype, arr.shape)
    self.array = arr

  def __hash__(self):
    return hash(self._hash_key)

  def __eq__(self, other):
    if not isinstance(other, _NumPyArrayHashWrapper):
      return NotImplemented
    return self._hash_key == other._hash_key
def _jax_in_args(typ) -> bool:
  """Returns True if a type annotation contains `jax.Array` anywhere in it.

  The annotation is searched recursively: through the fields of dataclass
  types, and through the type arguments of tuple, list, dict, set and union
  generics (both `typing.Union`/`Optional` and the `X | Y` syntax).

  Args:
    typ: a type annotation, e.g. the `type` of a `dataclasses.Field`.

  Returns:
    True if `typ` is `jax.Array` or contains it as described above, False
    otherwise (including for annotations of any other kind).
  """
  if typ is jax.Array:
    return True
  if dataclasses.is_dataclass(typ):
    return any(_jax_in_args(f.type) for f in dataclasses.fields(typ))
  # PEP 604 unions (`X | Y`, Python 3.10+) report `types.UnionType` as their
  # origin instead of `typing.Union`. Fall back to `Union` on interpreters
  # that do not have `types.UnionType`.
  pep604_union = getattr(types, 'UnionType', Union)
  if typing.get_origin(typ) in (tuple, list, dict, Union, pep604_union, set):
    return any(_jax_in_args(t) for t in typing.get_args(typ))
  return False
def dataclass(clz: _T, register_as_pytree: bool) -> _T:
  """Wraps a dataclass with metadata for which fields are pytrees.

  This is based off flax.struct.dataclass, but instead of using field
  descriptors to specify which fields are pytrees, we follow a simple rule:
  a leaf field is a pytree node if and only if it's a jax.Array

  The class is converted into a frozen dataclass and gets a `replace` method
  (`dataclasses.replace`). When registered as a pytree, fields whose annotation
  contains `jax.Array` become children keyed by attribute name; all other
  fields become static metadata. NumPy arrays in the metadata are wrapped so
  that they are hashable, and unwrapped again when the object is rebuilt.

  Args:
    clz: the class to register as a dataclass
    register_as_pytree: whether to register the resulting dataclass with
      `jax.tree_util`. If False, the class is only converted to a frozen
      dataclass with a `replace` method.

  Returns:
    the resulting dataclass, registered with Jax
  """
  data_clz = dataclasses.dataclass(frozen=True)(clz)
  data_clz.replace = dataclasses.replace
  if not register_as_pytree:
    return data_clz

  meta_fields, data_fields = [], []
  for field in dataclasses.fields(data_clz):
    if _jax_in_args(field.type):
      data_fields.append(field)
    else:
      meta_fields.append(field)

  def is_ndarray_tuple(typ) -> bool:
    """Returns True if `typ` is the annotation `Tuple[np.ndarray, ...]`."""
    if typing.get_origin(typ) != tuple:
      return False
    type_args = typing.get_args(typ)
    return (
        len(type_args) == 2
        and type_args[0] == np.ndarray
        and type_args[1] == ...
    )

  def to_meta(field, obj):
    """Returns the hashable metadata entry for one static field of `obj`."""
    val = getattr(obj, field.name)
    if isinstance(val, np.ndarray):
      return _NumPyArrayHashWrapper(val)
    if is_ndarray_tuple(field.type):
      return tuple(_NumPyArrayHashWrapper(v) for v in val)
    return val

  def from_meta(field, meta):
    """Inverts `to_meta`, returning the original field value."""
    # Unwrap based on what `to_meta` actually produced rather than on the
    # annotation: `to_meta` wraps any ndarray value regardless of the declared
    # type (e.g. Optional[np.ndarray]) and leaves non-array values (e.g. None
    # in an np.ndarray field) untouched.
    if isinstance(meta, _NumPyArrayHashWrapper):
      return meta.array
    if is_ndarray_tuple(field.type):
      return tuple(m.array for m in meta)
    return meta

  def iterate_clz_with_keys(x):
    """Flattens `x` into ((key, child), ...) pairs and static metadata."""
    data = tuple(
        (jax.tree_util.GetAttrKey(f.name), getattr(x, f.name))
        for f in data_fields
    )
    meta = tuple(to_meta(f, x) for f in meta_fields)
    return data, meta

  def clz_from_iterable(meta, data):
    """Rebuilds an instance from static metadata and child values."""
    kwargs = {f.name: from_meta(f, m) for f, m in zip(meta_fields, meta)}
    kwargs.update((f.name, d) for f, d in zip(data_fields, data))
    return data_clz(**kwargs)

  jax.tree_util.register_pytree_with_keys(
      data_clz, iterate_clz_with_keys, clz_from_iterable
  )
  return data_clz
# Type variable that lets `PyTreeNode.replace` be typed as returning the same
# subclass it is called on.
TNode = TypeVar('TNode', bound='PyTreeNode')


class PyTreeNode:
  """Base class for dataclasses that should act like a JAX pytree node.

  This base class additionally avoids type checking errors when using PyType.

  Every subclass is passed through `dataclass` when it is created, which turns
  it into a frozen dataclass and installs a working `replace`; the `__init__`
  and `replace` defined here are only stubs for the type checker.
  """

  def __init_subclass__(cls, register_as_pytree: bool = True, **kwargs):
    """Converts each new subclass into a dataclass.

    Args:
      register_as_pytree: forwarded to `dataclass`; whether the subclass is
        also registered as a JAX pytree node.
      **kwargs: forwarded to the parent `__init_subclass__`.
    """
    super().__init_subclass__(**kwargs)
    dataclass(cls, register_as_pytree=register_as_pytree)

  def __init__(self, *args, **kwargs):
    # stub for pytype
    raise NotImplementedError

  def replace(self: TNode, **overrides) -> TNode:
    # stub for pytype
    raise NotImplementedError

  @classmethod
  def fields(cls) -> Tuple[dataclasses.Field, ...]:  # pylint: disable=g-bare-generic
    """Returns the dataclass fields of this class."""
    return dataclasses.fields(cls)

  def tree_replace(
      self, params: Dict[str, Optional[jax.typing.ArrayLike]]
  ) -> 'PyTreeNode':
    """Returns a copy of this node with (possibly nested) attributes replaced.

    Args:
      params: maps dot-separated attribute paths (e.g. 'a.b.c') to the new
        value for that attribute. Entries are applied in iteration order.

    Returns:
      the updated node; `self` itself if `params` is empty.
    """
    new = self
    for k, v in params.items():
      new = _tree_replace(new, k.split('.'), v)
    return new
def _tree_replace(
    base: PyTreeNode,
    attr: Sequence[str],
    val: Optional[jax.typing.ArrayLike],
) -> PyTreeNode:
  """Replaces the attribute at path `attr` inside `base` with `val`.

  Args:
    base: the node to update; it is not modified, an updated copy is returned.
    attr: the attribute path, one name per nesting level.
    val: the value to store at the end of the path. When the path goes through
      a list attribute and `val` is iterable, element `i` of the list receives
      `val[i]`; otherwise every element receives `val`.

  Returns:
    the updated node, or `base` itself when `attr` is empty.
  """
  if not attr:
    return base

  head, rest = attr[0], attr[1:]

  # Last path component: set the value directly on this node.
  if not rest:
    return base.replace(**{head: val})

  child = getattr(base, head)

  # Any other attribute: recurse into it and store the updated child.
  if not isinstance(child, list):
    return base.replace(**{head: _tree_replace(child, rest, val)})

  # List attribute: apply the remaining path to every element that has the
  # next attribute, leaving the other elements as (deep-copied) originals.
  items = copy.deepcopy(child)
  per_item = hasattr(val, '__iter__')
  for idx, item in enumerate(items):
    if hasattr(item, rest[0]):
      items[idx] = _tree_replace(item, rest, val[idx] if per_item else val)
  return base.replace(**{head: items})