Skip to content

Commit

Permalink
Fix UCObjectStore.list_objects (mosaicml#3284)
Browse files: browse the repository at this point in the history
  • Loading branch information
dakinggg authored May 13, 2024
1 parent 134ae12 commit 01eec3a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 85 deletions.
39 changes: 17 additions & 22 deletions composer/utils/object_store/uc_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import json
import logging
import os
import pathlib
Expand Down Expand Up @@ -96,16 +95,20 @@ def validate_path(path: str) -> str:
# The first 4 dirs form the prefix
return os.path.join(*dirs[:4])

def _get_object_path(self, object_name: str) -> str:
def _get_object_path(self, object_name: Optional[str] = None) -> str:
"""Return the absolute Single Path Namespace for the given object_name.
Args:
object_name (str): Absolute or relative path of the object w.r.t. the
UC Volumes root.
object_name (optional, str): Absolute or relative path of the object w.r.t. the
UC Volumes root. If None, the prefix path is returned.
"""
# convert object name to relative path if prefix is included
if os.path.commonprefix([object_name, self.prefix]) == self.prefix:
if object_name is not None and os.path.commonprefix([object_name, self.prefix]) == self.prefix:
object_name = os.path.relpath(object_name, start=self.prefix)

if object_name is None:
return os.path.join('/', self.prefix)

return os.path.join('/', self.prefix, object_name)

def get_uri(self, object_name: str) -> str:
Expand Down Expand Up @@ -241,34 +244,26 @@ def list_objects(self, prefix: Optional[str]) -> List[str]:

from databricks.sdk.core import DatabricksError
try:
# NOTE: This API is in preview and should not be directly used outside of this instance
logging.warn('UCObjectStore.list_objects is experimental.')

# Iteratively get all UC Volume files with `prefix`.
stack = [prefix]
all_files = []

while len(stack) > 0:
current_path = stack.pop()

# Note: Databricks SDK handles HTTP errors and retries.
# See https://github.com/databricks/databricks-sdk-py/blob/v0.18.0/databricks/sdk/core.py#L125 and
# https://github.com/databricks/databricks-sdk-py/blob/v0.18.0/databricks/sdk/retries.py#L33 .
resp = self.client.api_client.do(
method='GET',
path=self._UC_VOLUME_LIST_API_ENDPOINT,
data=json.dumps({'path': self._get_object_path(current_path)}),
headers={'Source': 'mosaicml/composer'},
ls_results = self.client.files.list_directory_contents(
directory_path=self._get_object_path(current_path),
)

assert isinstance(resp, dict), 'Response is not a dictionary'
for dir_entry in ls_results:
path = dir_entry.path
is_directory = dir_entry.is_directory
assert isinstance(path, str)

for f in resp.get('files', []):
fpath = f['path']
if f['is_dir']:
stack.append(fpath)
if is_directory:
stack.append(path)
else:
all_files.append(fpath)
all_files.append(path)

return all_files

Expand Down
116 changes: 53 additions & 63 deletions tests/utils/object_store/test_uc_object_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, MagicMock
Expand Down Expand Up @@ -179,108 +180,97 @@ def generate_dummy_file(_):


def test_list_objects_nested_folders(ws_client, uc_object_store):
from databricks.sdk.service.files import DirectoryEntry

expected_files = [
'/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
'/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
'/Volumes/catalog/volume/schema/path/to/folder/subdir/file1.txt',
'/Volumes/catalog/volume/schema/path/to/folder/subdir/file2.txt',
]
uc_list_api_responses = [
{
'files': [
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
'is_dir': False,
},
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
'is_dir': False,
},
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/subdir',
'is_dir': True,
},
],
},
{
'files': [
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/subdir/file1.txt',
'is_dir': False,
},
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/subdir/file2.txt',
'is_dir': False,
},
],
},
[
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
is_directory=False,
),
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
is_directory=False,
),
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/subdir',
is_directory=True,
),
],
[
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/subdir/file1.txt',
is_directory=False,
),
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/subdir/file2.txt',
is_directory=False,
),
],
]

prefix = 'Volumes/catalog/schema/volume/path/to/folder'

ws_client.api_client.do = MagicMock(side_effect=[uc_list_api_responses[0], uc_list_api_responses[1]])
ws_client.files.list_directory_contents = MagicMock(
side_effect=[uc_list_api_responses[0], uc_list_api_responses[1]],
)
actual_files = uc_object_store.list_objects(prefix=prefix)

assert actual_files == expected_files

ws_client.api_client.do.assert_called_with(
method='GET',
path=uc_object_store._UC_VOLUME_LIST_API_ENDPOINT,
data='{"path": "/Volumes/catalog/volume/schema/path/to/folder/subdir"}',
headers={'Source': 'mosaicml/composer'},
ws_client.files.list_directory_contents.assert_called_with(
directory_path='/Volumes/catalog/volume/schema/path/to/folder/subdir',
)

assert ws_client.api_client.do.call_count == 2
assert ws_client.files.list_directory_contents.call_count == 2


@pytest.mark.parametrize('result', ['success', 'prefix_none', 'not_found', 'error'])
def test_list_objects(ws_client, uc_object_store, result):
from databricks.sdk.service.files import DirectoryEntry

expected_files = [
'/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
'/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
]
uc_list_api_response = {
'files': [
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
'is_dir': False,
},
{
'path': '/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
'is_dir': False,
},
],
}
uc_list_api_response = [
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
is_directory=False,
),
DirectoryEntry(
path='/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
is_directory=False,
),
]

prefix = 'Volumes/catalog/schema/volume/path/to/folder'

if result == 'success':
ws_client.api_client.do.return_value = uc_list_api_response
ws_client.files.list_directory_contents.return_value = uc_list_api_response
actual_files = uc_object_store.list_objects(prefix=prefix)

assert actual_files == expected_files
ws_client.api_client.do.assert_called_once_with(
method='GET',
path=uc_object_store._UC_VOLUME_LIST_API_ENDPOINT,
data='{"path": "/Volumes/catalog/schema/volume/path/to/folder"}',
headers={'Source': 'mosaicml/composer'},
)
expected_call_prefix = os.path.join('/', prefix)
ws_client.files.list_directory_contents.assert_called_once_with(directory_path=expected_call_prefix,)

elif result == 'prefix_none':
ws_client.api_client.do.return_value = uc_list_api_response
ws_client.files.list_directory_contents.return_value = uc_list_api_response
actual_files = uc_object_store.list_objects(prefix=None)

assert actual_files == expected_files
ws_client.api_client.do.assert_called_once_with(
method='GET',
path=uc_object_store._UC_VOLUME_LIST_API_ENDPOINT,
data='{"path": "/Volumes/catalog/schema/volume/."}',
headers={'Source': 'mosaicml/composer'},
)
expected_call_prefix = '/Volumes/catalog/schema/volume/.'
ws_client.files.list_directory_contents.assert_called_once_with(directory_path=expected_call_prefix,)

elif result == 'not_found':
db_core = pytest.importorskip('databricks.sdk.core', reason='requires databricks')
ws_client.api_client.do.side_effect = db_core.DatabricksError(
ws_client.files.list_directory_contents.side_effect = db_core.DatabricksError(
'The path you provided does not exist or is not a directory.',
error_code='NOT_FOUND',
)
Expand All @@ -289,7 +279,7 @@ def test_list_objects(ws_client, uc_object_store, result):

elif result == 'error':
db_core = pytest.importorskip('databricks.sdk.core', reason='requires databricks')
ws_client.api_client.do.side_effect = db_core.DatabricksError
ws_client.files.list_directory_contents.side_effect = db_core.DatabricksError

with pytest.raises(ObjectStoreTransientError):
uc_object_store.list_objects(prefix=prefix)
Expand Down

0 comments on commit 01eec3a

Please sign in to comment.