Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom PyTree metadata. Standardize naming of the "custom metadata" field (user-supplied metadata) as custom_metadata. #1461

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
make the change unnoticeable to most users, but also has additional accessible
properties not included in any tree mapping operations.
- `Checkpointer.save()`, `AsyncCheckpointer.save()` also saves `StepMetadata`.
- Added github actions CI testing using Python versions 3.10-3.13
- Added GitHub Actions CI testing using Python versions 3.10-3.13.
- Standardize naming of the "custom metadata" field (user-supplied metadata) as
`custom_metadata`.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between Checkpointing layers to enable async
directory creation.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
Checkpointing layers to enable async directory creation.
- User-provided custom PyTree metadata.

### Fixed
- Fix a bug where snapshots are not released by `wait_for_new_checkpoint`
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ py_library(
name = "tree",
srcs = ["tree.py"],
deps = [
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint/_src/tree:types",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _save_step_metadata(
):
"""Saves StepMetadata to the checkpoint directory."""
update_dict = {
'custom': custom_metadata,
'custom_metadata': custom_metadata,
}
if isinstance(
self._handler, composite_checkpoint_handler.CompositeCheckpointHandler
Expand Down
3 changes: 3 additions & 0 deletions checkpoint/orbax/checkpoint/_src/handlers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
],
Expand All @@ -91,6 +92,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
Expand Down Expand Up @@ -170,6 +172,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -444,6 +445,7 @@ async def async_save(
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
custom_metadata = args.custom_metadata

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
Expand Down Expand Up @@ -491,6 +493,7 @@ async def async_save(
checkpoint_dir=directory,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=self._use_zarr3,
)
)
Expand Down Expand Up @@ -799,8 +802,10 @@ def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo:
def _write_metadata_file(
self,
directory: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: tree_types.JsonType | None = None,
use_zarr3: bool = False,
) -> future.Future:
def _save_fn(param_infos):
Expand All @@ -811,6 +816,7 @@ def _save_fn(param_infos):
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
custom_metadata=custom_metadata,
pytree_metadata_options=self._pytree_metadata_options,
)
logging.vlog(
Expand All @@ -832,8 +838,10 @@ def _write_metadata_after_commits(
self,
commit_futures: List[future.Future],
checkpoint_dir: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: tree_types.JsonType | None = None,
use_zarr3: bool,
) -> None:
if not utils.is_primary_host(self._primary_host):
Expand All @@ -853,7 +861,11 @@ def _write_metadata_after_commits(
param_infos, checkpoint_dir, self._array_metadata_store
)
self._write_metadata_file(
checkpoint_dir, param_infos, save_args, use_zarr3
checkpoint_dir,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=use_zarr3,
).result()

def _read_metadata_file(
Expand Down Expand Up @@ -915,12 +927,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
tree containing metadata.
"""
is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory)
internal_tree_metadata = self._read_metadata_file(directory)
return tree_metadata.build_default_tree_metadata(
self._read_metadata_file(directory).as_user_metadata(
internal_tree_metadata.as_custom_metadata(
directory,
self._type_handler_registry,
use_ocdbt=is_ocdbt_checkpoint,
),
custom_metadata=internal_tree_metadata.custom_metadata,
)

def finalize(self, directory: epath.Path) -> None:
Expand Down Expand Up @@ -972,12 +986,16 @@ class BasePyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
custom_metadata: tree_types.JsonType | None = None


@register_with_handler(BasePyTreeCheckpointHandler, for_restore=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def test_metadata_no_save(self, use_handler_registry):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.custom_metadata)

def test_metadata_handler_registry(self):
registry = handler_registration.DefaultCheckpointHandlerRegistry()
Expand Down Expand Up @@ -779,7 +779,7 @@ def test_metadata_handler_registry(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.custom_metadata)

def test_metadata_after_step_metadata_write(self):
handler = CompositeCheckpointHandler(
Expand All @@ -795,7 +795,7 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.custom_metadata)

metadata_to_write = checkpoint.StepMetadata(
item_handlers={
Expand All @@ -813,7 +813,7 @@ def test_metadata_after_step_metadata_write(self):
),
init_timestamp_nsecs=1000,
commit_timestamp_nsecs=2000,
custom={
custom_metadata={
'custom_key': 'custom_value',
},
)
Expand All @@ -837,7 +837,9 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertEqual(step_metadata.init_timestamp_nsecs, 1000)
self.assertEqual(step_metadata.commit_timestamp_nsecs, 2000)
self.assertEqual(step_metadata.custom, {'custom_key': 'custom_value'})
self.assertEqual(
step_metadata.custom_metadata, {'custom_key': 'custom_value'}
)

def test_metadata_existing_items_updates_step_metadata(self):
handler = CompositeCheckpointHandler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -431,6 +432,7 @@ def _get_impl_save_args(
save_args=args.save_args,
ocdbt_target_data_file_size=args.ocdbt_target_data_file_size,
enable_pinned_host_transfer=args.enable_pinned_host_transfer,
custom_metadata=args.custom_metadata,
)


Expand Down Expand Up @@ -1052,12 +1054,16 @@ class PyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
custom_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -144,15 +145,19 @@ async def async_save(
'Make sure to specify kwarg name `args=` when providing'
' `StandardSaveArgs`.'
)
custom_metadata = None
if args is not None:
item = args.item
save_args = args.save_args
custom_metadata = args.custom_metadata

self._validate_save_state(item, save_args=save_args)
return await self._impl.async_save(
directory,
args=pytree_checkpoint_handler.PyTreeSaveArgs(
item=item, save_args=save_args
item=item,
save_args=save_args,
custom_metadata=custom_metadata,
),
)

Expand Down Expand Up @@ -266,10 +271,14 @@ class StandardSaveArgs(CheckpointArgs):
save_args: a PyTree with the same structure as `item`, which consists of
`ocp.SaveArgs` objects as values. `None` can be used for values where no
`SaveArgs` are specified.
custom_metadata: User-provided custom metadata. An arbitrary
JSON-serializable dictionary the user can use to store additional
information. The field is treated as opaque by Orbax.
"""

item: PyTree
save_args: Optional[PyTree] = None
custom_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test for standard_checkpoint_handler.py."""

# pylint: disable=protected-access, missing-function-docstring

import functools
from typing import Any

Expand Down Expand Up @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self):
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_shape_dtype_struct(self):
"""Test case."""
self.handler.save(
self.directory, args=self.save_args_cls(self.mixed_pytree)
)
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_custom_layout(self):
custom_layout = Layout(
device_local_layout=DLL(
major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error
),
sharding=arr.sharding,
)
Expand Down Expand Up @@ -210,7 +211,6 @@ def test_custom_layout(self):

@parameterized.parameters((True,), (False,))
def test_change_shape(self, strict: bool):
"""Test case."""
if not hasattr(self.restore_args_cls, 'strict'):
self.skipTest('strict option not supported for this handler')
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
Expand Down Expand Up @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self):
self.handler.restore(self.directory, args=self.restore_args_cls(pytree))

def test_cast(self):
"""Test case."""
# TODO(dicentra): casting from int dtypes currently doesn't work
# in the model surgery context.
save_args = jax.tree.map(
Expand Down Expand Up @@ -289,7 +288,6 @@ def check_dtype(x, dtype):
jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored)

def test_flax_model(self):
"""Test case."""

@flax.struct.dataclass
class Params(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -318,12 +316,10 @@ def make_params():
test_utils.assert_tree_equal(self, params, restored)

def test_empty_error(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -332,7 +328,6 @@ def test_empty_dict_node(self):
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -341,7 +336,6 @@ def test_empty_none_node(self):
self.assertDictEqual(restored, item)

def test_none_node_in_restore_args(self):
"""Test case."""
devices = np.asarray(jax.devices())
mesh = jax.sharding.Mesh(devices, ('x',))
mesh_axes = jax.sharding.PartitionSpec(
Expand All @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self):
test_utils.assert_tree_equal(self, restored, {'b': None})

def test_masked_shape_dtype_struct(self):
"""Test case."""

def _should_mask(keypath):
return keypath[0].key == 'a' or (
Expand Down Expand Up @@ -398,3 +391,12 @@ def _none(keypath, x):
# Restore it without any item.
restored = self.handler.restore(self.directory)
test_utils.assert_tree_equal(self, expected, restored)

def test_custom_metadata(self):
custom_metadata = {'foo': 1}
self.handler.save(
self.directory,
args=self.save_args_cls(self.pytree, custom_metadata=custom_metadata),
)
metadata = self.handler.metadata(self.directory)
self.assertEqual(metadata.custom_metadata, custom_metadata)
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
],
)
Expand Down
Loading
Loading