From a95265cfd546dbca96a23813488e728ecf08733e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20S=C3=A9ri=C3=A9?=
Date: Sat, 24 Apr 2021 00:50:19 +0200
Subject: [PATCH 1/4] Introduce wrapper eager_function

Introduce as_tensors, as_tensors_, as_raw_tensor, as_raw_tensors, which
rely on tree_flatten/tree_unflatten for more generic usage.

JAXTensor is no longer registered as a pytree data structure.

Refactor JAXTensor._value_and_grad_fn.
---
 eagerpy/__init__.py          |  1 +
 eagerpy/astensor.py          | 66 ++++++++++++++++++++++++++++-
 eagerpy/tensor/extensions.py |  1 -
 eagerpy/tensor/jax.py        | 75 +++++++++------------------------
 tests/test_main.py           | 82 +++++++++++++++++++++++++++++++++++-
 5 files changed, 165 insertions(+), 60 deletions(-)

diff --git a/eagerpy/__init__.py b/eagerpy/__init__.py
index bf04d8f..31d887a 100644
--- a/eagerpy/__init__.py
+++ b/eagerpy/__init__.py
@@ -33,6 +33,7 @@ def __getitem__(self, index: _T) -> _T:
 from .astensor import astensors  # noqa: F401,E402
 from .astensor import astensor_  # noqa: F401,E402
 from .astensor import astensors_  # noqa: F401,E402
+from .astensor import eager_function  # noqa: F401,E402

 from .modules import torch  # noqa: F401,E402
 from .modules import tensorflow  # noqa: F401,E402

diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py
index f547179..cdab2f0 100644
--- a/eagerpy/astensor.py
+++ b/eagerpy/astensor.py
@@ -1,6 +1,18 @@
-from typing import TYPE_CHECKING, Union, overload, Tuple, TypeVar, Generic, Any
+import functools
+from typing import (
+    TYPE_CHECKING,
+    Union,
+    overload,
+    Tuple,
+    TypeVar,
+    Generic,
+    Any,
+    Callable,
+)
 import sys

+from jax import tree_flatten, tree_unflatten
+
 from .tensor import Tensor
 from .tensor import TensorType
@@ -59,9 +71,28 @@ def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]:  # type:
     return tuple(astensor(x) for x in xs)


+def as_tensors(data: Any) -> Any:
+    leaf_values, tree_def = tree_flatten(data)
+    leaf_values = tuple(astensor(value) for value in leaf_values)
+    return tree_unflatten(tree_def, leaf_values)
+
+
 T = TypeVar("T")


+def as_raw_tensor(x: T) -> Any:
+    if isinstance(x, Tensor):
+        return x.raw
+    else:
+        return x
+
+
+def as_raw_tensors(data: Any) -> Any:
+    leaf_values, tree_def = tree_flatten(data)
+    leaf_values = tuple(as_raw_tensor(value) for value in leaf_values)
+    return tree_unflatten(tree_def, leaf_values)
+
+
 class RestoreTypeFunc(Generic[T]):
     def __init__(self, x: T):
         self.unwrap = not isinstance(x, Tensor)
@@ -84,7 +115,7 @@ def __call__(self, *args: Any) -> Any:
         ...

     def __call__(self, *args):  # type: ignore  # noqa: F811
-        result = tuple(x.raw for x in args) if self.unwrap else args
+        result = tuple(as_raw_tensor(x) for x in args) if self.unwrap else args
         if len(result) == 1:
             (result,) = result
         return result
@@ -96,3 +127,34 @@ def astensor_(x: T) -> Tuple[Tensor, RestoreTypeFunc[T]]:

 def astensors_(x: T, *xs: T) -> Tuple[Tuple[Tensor, ...], RestoreTypeFunc[T]]:
     return astensors(x, *xs), RestoreTypeFunc[T](x)
+
+
+def as_tensors_(data: Any) -> Any:
+    leaf_values, tree_def = tree_flatten(data)
+    leaf_values, restore_type = astensors_(*leaf_values)
+    return tree_unflatten(tree_def, leaf_values), restore_type
+
+
+def eager_function(
+    func: Callable[..., T], skip_argnums: Tuple = tuple()
+) -> Callable[..., T]:
+    @functools.wraps(func)
+    def eager_func(*args: Any, **kwargs: Any) -> Any:
+        sorted_skip_argnums = sorted(skip_argnums)
+        skip_args = [arg for i, arg in enumerate(args) if i in sorted_skip_argnums]
+        kept_args = [arg for i, arg in enumerate(args) if i not in sorted_skip_argnums]
+
+        (kept_args, kwargs), restore_type = as_tensors_((kept_args, kwargs))
+
+        for i, arg in zip(sorted_skip_argnums, skip_args):
+            kept_args.insert(i, arg)
+
+        result = func(*kept_args, **kwargs)
+
+        if restore_type.unwrap:
+            raw_result = as_raw_tensors(result)
+            return raw_result
+        else:
+            return result
+
+    return eager_func

diff --git a/eagerpy/tensor/extensions.py b/eagerpy/tensor/extensions.py
index 1a0423f..f60ec75 100644
--- a/eagerpy/tensor/extensions.py
+++ b/eagerpy/tensor/extensions.py
@@ -6,7 +6,6 @@
 from .tensor import Tensor


-
 T = TypeVar("T")

diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py
index 1acf104..b3e868b 100644
--- a/eagerpy/tensor/jax.py
+++ b/eagerpy/tensor/jax.py
@@ -9,8 +9,8 @@
     Optional,
     overload,
     Callable,
-    Type,
 )
+
 from typing_extensions import Literal
 from importlib import import_module
 import numpy as onp
@@ -62,23 +62,8 @@ class JAXTensor(BaseTensor):
     # more specific types for the extensions
     norms: "NormsMethods[JAXTensor]"

-    _registered = False
     key = None

-    def __new__(cls: Type["JAXTensor"], *args: Any, **kwargs: Any) -> "JAXTensor":
-        if not cls._registered:
-            import jax
-
-            def flatten(t: JAXTensor) -> Tuple[Any, None]:
-                return ((t.raw,), None)
-
-            def unflatten(aux_data: None, children: Tuple) -> JAXTensor:
-                return cls(*children)
-
-            jax.tree_util.register_pytree_node(cls, flatten, unflatten)
-            cls._registered = True
-        return cast(JAXTensor, super().__new__(cls))
-
     def __init__(self, raw: "np.ndarray"):  # type: ignore
         global jax
         global np
@@ -434,46 +419,24 @@ def _value_and_grad_fn(
     def _value_and_grad_fn(  # noqa: F811 (waiting for pyflakes > 2.1.1)
         self: TensorType, f: Callable, has_aux: bool = False
     ) -> Callable[..., Tuple]:
-        # f takes and returns JAXTensor instances
-        # jax.value_and_grad accepts functions that take JAXTensor instances
-        # because we registered JAXTensor as JAX type, but it still requires
-        # the output to be a scalar (that is not not wrapped as a JAXTensor)
-
-        # f_jax is like f but unwraps loss
-        if has_aux:
-
-            def f_jax(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
-                loss, aux = f(*args, **kwargs)
-                return loss.raw, aux
-
-        else:
-
-            def f_jax(*args: Any, **kwargs: Any) -> Any:  # type: ignore
-                loss = f(*args, **kwargs)
-                return loss.raw
-
-        value_and_grad_jax = jax.value_and_grad(f_jax, has_aux=has_aux)
-
-        # value_and_grad is like value_and_grad_jax but wraps loss
-        if has_aux:
-
-            def value_and_grad(
-                x: JAXTensor, *args: Any, **kwargs: Any
-            ) -> Tuple[JAXTensor, Any, JAXTensor]:
-                assert isinstance(x, JAXTensor)
-                (loss, aux), grad = value_and_grad_jax(x, *args, **kwargs)
-                assert grad.shape == x.shape
-                return JAXTensor(loss), aux, grad
-
-        else:
-
-            def value_and_grad(  # type: ignore
-                x: JAXTensor, *args: Any, **kwargs: Any
-            ) -> Tuple[JAXTensor, JAXTensor]:
-                assert isinstance(x, JAXTensor)
-                loss, grad = value_and_grad_jax(x, *args, **kwargs)
-                assert grad.shape == x.shape
-                return JAXTensor(loss), grad
+        from eagerpy.astensor import as_tensors, as_raw_tensors
+
+        def value_and_grad(
+            x: JAXTensor, *args: Any, **kwargs: Any
+        ) -> Union[Tuple[JAXTensor, JAXTensor], Tuple[JAXTensor, Any, JAXTensor]]:
+            assert isinstance(x, JAXTensor)
+            x, args, kwargs = as_raw_tensors((x, args, kwargs))
+
+            loss_aux, grad = jax.value_and_grad(f, has_aux=has_aux)(x, *args, **kwargs)
+            assert grad.shape == x.shape
+            loss_aux, grad = as_tensors((loss_aux, grad))
+
+            if has_aux:
+                loss, aux = loss_aux
+                return loss, aux, grad
+            else:
+                loss = loss_aux
+                return loss, grad

         return value_and_grad

diff --git a/tests/test_main.py b/tests/test_main.py
index 36ca7df..7d0d701 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -4,7 +4,7 @@
 import itertools
 import numpy as np
 import eagerpy as ep
-from eagerpy import Tensor
+from eagerpy import Tensor, eager_function
 from eagerpy.types import Shape, AxisAxes

 # make sure there are no undecorated tests in the "special tests" section below
@@ -147,6 +147,7 @@ def test_value_and_grad_fn(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()

+    @eager_function
     def f(x: ep.Tensor) -> ep.Tensor:
         return x.square().sum()
@@ -161,6 +162,7 @@ def test_value_and_grad_fn_with_aux(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()

+    @eager_function
     def f(x: Tensor) -> Tuple[Tensor, Tensor]:
         x = x.square()
         return x.sum(), x
@@ -177,6 +179,7 @@ def test_value_and_grad(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()

+    @eager_function
     def f(x: Tensor) -> Tensor:
         return x.square().sum()
@@ -190,6 +193,7 @@ def test_value_aux_and_grad(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()

+    @eager_function
     def f(x: Tensor) -> Tuple[Tensor, Tensor]:
         x = x.square()
         return x.sum(), x
@@ -205,6 +209,7 @@ def test_value_aux_and_grad_multiple_aux(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()

+    @eager_function
     def f(x: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         x = x.square()
         return x.sum(), (x, x + 1)
@@ -221,6 +226,7 @@ def test_value_and_grad_multiple_args(dummy: Tensor) -> None:
     if isinstance(dummy, ep.NumPyTensor):
         pytest.skip()

+    @eager_function
     def f(x: Tensor, y: Tensor) -> Tensor:
         return (x * y).sum()
@@ -1581,3 +1587,77 @@ def test_norms_lp(t: Tensor) -> Tensor:
 @compare_all
 def test_norms_cache(t: Tensor) -> Tensor:
     return t.norms.l1() + t.norms.l2()
+
+
+@eager_function
+def my_universal_function(a: Tensor, b: Tensor, c: Tensor) -> Tensor:
+    return (a + b * c).square()
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function(t: Tensor, astensor: bool) -> Tensor:
+    if astensor:
+        a = t
+    else:
+        a = t.raw
+    b = a
+    c = a
+    result = my_universal_function(a, b, c)
+    assert isinstance(result, type(a))
+    return ep.astensor(result)
+
+
+# define a non-registered pytree container.
+class NonRegisteredDataStruct:
+    def __init__(self, res: Any) -> None:
+        self.res = res
+
+
+@eager_function
+def my_universal_function_return_non_registered_datastruct(
+    a: Tensor, b: Tensor, c: Tensor
+) -> Any:
+    res = (a + b * c).square()
+    return NonRegisteredDataStruct(res)
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function_return_non_registered_datastruct(
+    t: Tensor, astensor: bool
+) -> Tensor:
+    if astensor:
+        a = t
+    else:
+        a = t.raw
+    b = a
+    c = a
+    result = my_universal_function_return_non_registered_datastruct(a, b, c)
+
+    # result has not been converted because NonRegisteredDataStruct
+    # is not a registered pytree container
+    assert isinstance(result.res, type(t))
+    return ep.astensor(result.res)
+
+
+# apply eager_function to a method.
+class MyClass:
+    @functools.partial(eager_function, skip_argnums=(0,))
+    def my_universal_method(self, a: Tensor, b: Tensor, c: Tensor) -> Any:
+        res = (a + b * c).square()
+        return res
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function_on_method(t: Tensor, astensor: bool) -> Tensor:
+    if astensor:
+        a = t
+    else:
+        a = t.raw
+    b = a
+    c = a
+    result = MyClass().my_universal_method(a, b, c)
+    assert isinstance(result, type(a))
+    return ep.astensor(result)

From 81bba4d5ccdcbd485dc71d57d393bb93880fa7d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20S=C3=A9ri=C3=A9?=
Date: Sat, 24 Apr 2021 16:52:00 +0200
Subject: [PATCH 2/4] astensor manages non-tensor arguments and behaves like
 the identity if the argument is not a tensor

---
 eagerpy/astensor.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py
index cdab2f0..3df1b68 100644
--- a/eagerpy/astensor.py
+++ b/eagerpy/astensor.py
@@ -48,7 +48,7 @@ def astensor(x: NativeTensor) -> Tensor:  # type: ignore
     ...


-def astensor(x: Union[NativeTensor, Tensor]) -> Tensor:  # type: ignore
+def astensor(x: Union[NativeTensor, Tensor, Any]) -> Union[Tensor, Any]:  # type: ignore
     if isinstance(x, Tensor):
         return x
     # we use the module name instead of isinstance
@@ -64,7 +64,8 @@ def astensor(x: Union[NativeTensor, Tensor]) -> Tensor:  # type: ignore
         return JAXTensor(x)
     if name == "numpy" and isinstance(x, m[name].ndarray):  # type: ignore
         return NumPyTensor(x)
-    raise ValueError(f"Unknown type: {type(x)}")
+    return x
+    # raise ValueError(f"Unknown type: {type(x)}")

From 3de88207e4f8768b7b18c189f56b7b7f669adcde Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20S=C3=A9ri=C3=A9?=
Date: Sat, 24 Apr 2021 17:06:53 +0200
Subject: [PATCH 3/4] new function as_tensors_any

Register all Tensor types as pytrees.
Handle the unwrapping mechanism within the unflatten method.
---
 eagerpy/astensor.py    | 88 ++++++++++++++++++++++++++++------------
 eagerpy/tensor/base.py | 29 +++++++++++++-
 eagerpy/tensor/jax.py  |  1 -
 tests/test_main.py     | 19 +++++++++
 4 files changed, 110 insertions(+), 27 deletions(-)

diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py
index 3df1b68..1900de5 100644
--- a/eagerpy/astensor.py
+++ b/eagerpy/astensor.py
@@ -64,36 +64,18 @@ def astensor(x: Union[NativeTensor, Tensor, Any]) -> Union[Tensor, Any]:  # type
         return JAXTensor(x)
     if name == "numpy" and isinstance(x, m[name].ndarray):  # type: ignore
         return NumPyTensor(x)
+
+    # non Tensor types are returned unmodified
     return x
-    # raise ValueError(f"Unknown type: {type(x)}")


 def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]:  # type: ignore
     return tuple(astensor(x) for x in xs)


-def as_tensors(data: Any) -> Any:
-    leaf_values, tree_def = tree_flatten(data)
-    leaf_values = tuple(astensor(value) for value in leaf_values)
-    return tree_unflatten(tree_def, leaf_values)
-
-
 T = TypeVar("T")


-def as_raw_tensor(x: T) -> Any:
-    if isinstance(x, Tensor):
-        return x.raw
-    else:
-        return x
-
-
-def as_raw_tensors(data: Any) -> Any:
-    leaf_values, tree_def = tree_flatten(data)
-    leaf_values = tuple(as_raw_tensor(value) for value in leaf_values)
-    return tree_unflatten(tree_def, leaf_values)
-
-
 class RestoreTypeFunc(Generic[T]):
     def __init__(self, x: T):
         self.unwrap = not isinstance(x, Tensor)
@@ -130,10 +112,65 @@ def astensors_(x: T, *xs: T) -> Tuple[Tuple[Tensor, ...], RestoreTypeFunc[T]]:
     return astensors(x, *xs), RestoreTypeFunc[T](x)


-def as_tensors_(data: Any) -> Any:
+def as_tensors(data: Any) -> Any:
     leaf_values, tree_def = tree_flatten(data)
-    leaf_values, restore_type = astensors_(*leaf_values)
-    return tree_unflatten(tree_def, leaf_values), restore_type
+    leaf_values = tuple(astensor(value) for value in leaf_values)
+    return tree_unflatten(tree_def, leaf_values)
+
+
+def has_tensor(tree_def: Any) -> bool:
+    return "eagerpy.tensor" in str(tree_def)
+
+
+def as_tensors_any(data: Any) -> Tuple[Any, bool]:
+    """Convert data structure leaves into Tensors and detect if any of the input data contains a Tensor.
+
+    Parameters
+    ----------
+    data
+        data structure.
+
+    Returns
+    -------
+    Any
+        modified data structure.
+    bool
+        True if input data contains a Tensor type.
+    """
+    leaf_values, tree_def = tree_flatten(data)
+    transformed_leaf_values = tuple(astensor(value) for value in leaf_values)
+    return tree_unflatten(tree_def, transformed_leaf_values), has_tensor(tree_def)
+
+
+def as_raw_tensor(x: T) -> Any:
+    if isinstance(x, Tensor):
+        return x.raw
+    else:
+        return x
+
+
+def as_raw_tensors(data: Any) -> Any:
+    leaf_values, tree_def = tree_flatten(data)
+
+    if not has_tensor(tree_def):
+        return data
+
+    leaf_values = tuple(as_raw_tensor(value) for value in leaf_values)
+    unwrap_leaf_values = []
+    for x in leaf_values:
+        name = _get_module_name(x)
+        m = sys.modules
+        if name == "torch" and isinstance(x, m[name].Tensor):  # type: ignore
+            unwrap_leaf_values.append((x, True))
+        elif name == "tensorflow" and isinstance(x, m[name].Tensor):  # type: ignore
+            unwrap_leaf_values.append((x, True))
+        elif (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray):  # type: ignore
+            unwrap_leaf_values.append((x, True))
+        elif name == "numpy" and isinstance(x, m[name].ndarray):  # type: ignore
+            unwrap_leaf_values.append((x, True))
+        else:
+            unwrap_leaf_values.append(x)
+    return tree_unflatten(tree_def, unwrap_leaf_values)


 def eager_function(
@@ -145,14 +182,15 @@ def eager_func(*args: Any, **kwargs: Any) -> Any:
         skip_args = [arg for i, arg in enumerate(args) if i in sorted_skip_argnums]
         kept_args = [arg for i, arg in enumerate(args) if i not in sorted_skip_argnums]

-        (kept_args, kwargs), restore_type = as_tensors_((kept_args, kwargs))
+        (kept_args, kwargs), has_tensor = as_tensors_any((kept_args, kwargs))
+        unwrap = not has_tensor

         for i, arg in zip(sorted_skip_argnums, skip_args):
             kept_args.insert(i, arg)

         result = func(*kept_args, **kwargs)

-        if restore_type.unwrap:
+        if unwrap:
             raw_result = as_raw_tensors(result)
             return raw_result
         else:

diff --git a/eagerpy/tensor/base.py b/eagerpy/tensor/base.py
index 41a7da3..f1e2953 100644
--- a/eagerpy/tensor/base.py
+++ b/eagerpy/tensor/base.py
@@ -1,5 +1,5 @@
 from typing_extensions import final
-from typing import Any, cast
+from typing import Any, Type, Union, cast, Tuple

 from .tensor import Tensor
 from .tensor import TensorType
@@ -17,6 +17,33 @@ def unwrap1(t: Any) -> Any:
 class BaseTensor(Tensor):
     __slots__ = "_raw"

+    _registered = False
+
+    def __new__(cls: Type["BaseTensor"], *args: Any, **kwargs: Any) -> "BaseTensor":
+        if not cls._registered:
+            import jax
+
+            def flatten(t: Tensor) -> Tuple[Any, None]:
+                return ((t.raw,), None)
+
+            def unflatten(aux_data: None, children: Tuple) -> Union[Tensor, Any]:
+                assert len(children) == 1
+                x = children[0]
+                del children
+
+                if isinstance(x, tuple):
+                    x, unwrap = x
+                    if unwrap:
+                        return x
+
+                if isinstance(x, Tensor):
+                    return x
+                return cls(x)
+
+            jax.tree_util.register_pytree_node(cls, flatten, unflatten)
+            cls._registered = True
+        return cast("BaseTensor", super().__new__(cls))
+
     def __init__(self: TensorType, raw: Any):
         assert not isinstance(raw, Tensor)
         self._raw = raw

diff --git a/eagerpy/tensor/jax.py b/eagerpy/tensor/jax.py
index b3e868b..74ae75e 100644
--- a/eagerpy/tensor/jax.py
+++ b/eagerpy/tensor/jax.py
@@ -58,7 +58,6 @@ def getitem_preprocess(x: Any) -> Any:
 class JAXTensor(BaseTensor):
     __slots__ = ()

-
     # more specific types for the extensions
     norms: "NormsMethods[JAXTensor]"

diff --git a/tests/test_main.py b/tests/test_main.py
index 7d0d701..f1c044b 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1661,3 +1661,22 @@ def test_eager_function_on_method(t: Tensor, astensor: bool) -> Tensor:
     result = MyClass().my_universal_method(a, b, c)
     assert isinstance(result, type(a))
     return ep.astensor(result)
+
+
+@eager_function
+def my_universal_function_with_non_tensors(a: int, b: Tensor, c: Tensor) -> Tensor:
+    return (a + b * c).square()
+
+
+@pytest.mark.parametrize("astensor", [False, True])
+@compare_all
+def test_eager_function_with_non_tensors(t: Tensor, astensor: bool) -> Tensor:
+    if astensor:
+        b = t
+    else:
+        b = t.raw
+    a = 3
+    c = b
+    result = my_universal_function_with_non_tensors(a, b, c)
+    assert isinstance(result, type(b))
+    return ep.astensor(result)

From 45832603ea422ef15b23c63b2fd5b24f78fea326 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20S=C3=A9ri=C3=A9?=
Date: Mon, 26 Apr 2021 08:24:52 +0200
Subject: [PATCH 4/4] remove skip_argnums argument since as_tensors_any
 manages non-pytree objects

---
 eagerpy/astensor.py | 17 +++--------------
 tests/test_main.py  |  2 +-
 2 files changed, 4 insertions(+), 15 deletions(-)

diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py
index 1900de5..61eaa21 100644
--- a/eagerpy/astensor.py
+++ b/eagerpy/astensor.py
@@ -173,23 +173,12 @@ def as_raw_tensors(data: Any) -> Any:
     return tree_unflatten(tree_def, unwrap_leaf_values)


-def eager_function(
-    func: Callable[..., T], skip_argnums: Tuple = tuple()
-) -> Callable[..., T]:
+def eager_function(func: Callable[..., T]) -> Callable[..., T]:
     @functools.wraps(func)
     def eager_func(*args: Any, **kwargs: Any) -> Any:
-        sorted_skip_argnums = sorted(skip_argnums)
-        skip_args = [arg for i, arg in enumerate(args) if i in sorted_skip_argnums]
-        kept_args = [arg for i, arg in enumerate(args) if i not in sorted_skip_argnums]
-
-        (kept_args, kwargs), has_tensor = as_tensors_any((kept_args, kwargs))
+        (args, kwargs), has_tensor = as_tensors_any((args, kwargs))
         unwrap = not has_tensor
-
-        for i, arg in zip(sorted_skip_argnums, skip_args):
-            kept_args.insert(i, arg)
-
-        result = func(*kept_args, **kwargs)
-
+        result = func(*args, **kwargs)
         if unwrap:
             raw_result = as_raw_tensors(result)
             return raw_result

diff --git a/tests/test_main.py b/tests/test_main.py
index f1c044b..f62ef76 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1643,7 +1643,7 @@ def test_eager_function_return_non_registered_datastruct(

 # apply eager_function to a method.
 class MyClass:
-    @functools.partial(eager_function, skip_argnums=(0,))
+    @eager_function
     def my_universal_method(self, a: Tensor, b: Tensor, c: Tensor) -> Any:
         res = (a + b * c).square()
         return res
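
A minimal usage sketch of eager_function as it behaves once all four patches
are applied (not part of the patches themselves; it assumes the NumPy backend,
and the function name `loss` is illustrative, not taken from the series):

    import numpy as np
    import eagerpy as ep

    @ep.eager_function
    def loss(x, y):
        # inside the wrapped function, x and y are eagerpy Tensors even
        # when the caller passed raw framework arrays
        return (x - y).square().sum()

    # raw inputs: as_tensors_any finds no eagerpy Tensor among the
    # arguments, so the result is unwrapped back to a raw value
    out = loss(np.arange(4.0), np.ones(4))
    assert not isinstance(out, ep.Tensor)

    # eagerpy Tensor inputs: the result stays wrapped
    out = loss(ep.astensor(np.arange(4.0)), ep.astensor(np.ones(4)))
    assert isinstance(out, ep.Tensor)

This mirrors test_eager_function above: the wrapper restores whichever
representation the caller used.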