diff --git a/Dockerfile.dev b/Dockerfile.dev index 1a839104e4..5d2b622f01 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -40,6 +40,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \ -e /flytekit \ -e /flytekit/plugins/flytekit-deck-standard \ -e /flytekit/plugins/flytekit-flyteinteractive \ + obstore==0.3.0b9 \ markdown \ pandas \ pillow \ diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 0640bc2eb5..6adc18755e 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -25,18 +25,20 @@ import tempfile import typing from time import sleep -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, Optional, Tuple, Union, cast from uuid import UUID import fsspec from decorator import decorator from fsspec.asyn import AsyncFileSystem from fsspec.utils import get_protocol +from obstore.store import AzureStore, GCSStore, S3Store from typing_extensions import Unpack from flytekit import configuration from flytekit.configuration import DataConfig from flytekit.core.local_fsspec import FlyteLocalFileSystem +from flytekit.core.obstore_filesystem import ObstoreAzureBlobFileSystem, ObstoreGCSFileSystem, ObstoreS3FileSystem from flytekit.core.utils import timeit from flytekit.exceptions.system import FlyteDownloadDataException, FlyteUploadDataException from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException @@ -44,49 +46,128 @@ from flytekit.loggers import logger from flytekit.utils.asyn import loop_manager -# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 +# Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11 # for key and secret -_FSSPEC_S3_KEY_ID = "key" -_FSSPEC_S3_SECRET = "secret" -_ANON = "anon" +_FSSPEC_S3_KEY_ID = "access_key_id" +_FSSPEC_S3_SECRET = "secret_access_key" +_ANON = "skip_signature" Uploadable = typing.Union[str, os.PathLike, pathlib.Path, bytes, io.BufferedReader, io.BytesIO, io.StringIO] -def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]: - kwargs: Dict[str, Any] = { - "cache_regions": True, - } +def s3_setup_args(s3_cfg: configuration.S3Config, bucket: str = "", anonymous: bool = False) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + store_kwargs: Dict[str, Any] = {} + if s3_cfg.access_key_id: - kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id + store_kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id if s3_cfg.secret_access_key: - kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key + store_kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key # S3fs takes this as a special arg if s3_cfg.endpoint is not None: - kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} + store_kwargs["endpoint_url"] = s3_cfg.endpoint + if anonymous: + store_kwargs[_ANON] = "true" + + store = S3Store.from_env( + bucket, + config={ + **store_kwargs, + "aws_allow_http": "true", # Allow HTTP connections + "aws_virtual_hosted_style_request": "false", # Use path-style addressing + }, + ) + + kwargs["retries"] = s3_cfg.retries + + kwargs["store"] = store + + return kwargs + + +def gs_setup_args(gcs_cfg: configuration.GCSConfig, bucket: str = "", anonymous: bool = False) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + + store = GCSStore.from_env( + bucket, + ) if anonymous: - kwargs[_ANON] = True + kwargs["token"] = _ANON + + kwargs["store"] = store return kwargs -def 
azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]: +def split_path(path: str) -> Tuple[str, str]: + """ + Split bucket and file path + + Parameters + ---------- + path : string + Input path, like `s3://mybucket/path/to/file` + + Examples + -------- + >>> split_path("s3://mybucket/path/to/file") + ('mybucket', 'path/to/file') + """ + support_types = ["s3", "gs", "abfs"] + protocol = get_protocol(path) + if protocol not in support_types: + # no bucket for file + return "", path + + if path.startswith(protocol + "://"): + path = path[len(protocol) + 3 :] + elif path.startswith(protocol + "::"): + path = path[len(protocol) + 2 :] + path = path.strip("/") + + if "/" not in path: + return path, "" + else: + path_li = path.split("/") + bucket = path_li[0] + # obstore is only used for s3, gs, and abfs for now, so other storage protocols + # are returned above without splitting the bucket out of the path + file_path = "/".join(path_li[1:]) + return (bucket, file_path) + + +def azure_setup_args( + azure_cfg: configuration.AzureBlobStorageConfig, container: str = "", anonymous: bool = False +) -> Dict[str, Any]: kwargs: Dict[str, Any] = {} + store_kwargs: Dict[str, Any] = {} if azure_cfg.account_name: - kwargs["account_name"] = azure_cfg.account_name + store_kwargs["account_name"] = azure_cfg.account_name if azure_cfg.account_key: - kwargs["account_key"] = azure_cfg.account_key + store_kwargs["account_key"] = azure_cfg.account_key if azure_cfg.client_id: - kwargs["client_id"] = azure_cfg.client_id + store_kwargs["client_id"] = azure_cfg.client_id if azure_cfg.client_secret: - kwargs["client_secret"] = azure_cfg.client_secret + store_kwargs["client_secret"] = azure_cfg.client_secret if azure_cfg.tenant_id: - kwargs["tenant_id"] = azure_cfg.tenant_id - kwargs[_ANON] = anonymous + store_kwargs["tenant_id"] = azure_cfg.tenant_id + if anonymous: + kwargs[_ANON] = "true" + + store = AzureStore.from_env( + container, + config={ + **store_kwargs, + }, + ) + + kwargs["store"] = store + + return kwargs @@ -189,21 +270,27 @@ def get_filesystem( protocol: typing.Optional[str] = None, anonymous: bool = False, path: typing.Optional[str] = None, + bucket: str = "", **kwargs, ) -> fsspec.AbstractFileSystem: + # TODO: add bucket to adlfs if not protocol: return self._default_remote if protocol == "file": kwargs["auto_mkdir"] = True return FlyteLocalFileSystem(**kwargs) elif protocol == "s3": - s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + s3kwargs = s3_setup_args(self._data_config.s3, bucket, anonymous=anonymous) s3kwargs.update(kwargs) return fsspec.filesystem(protocol, **s3kwargs) # type: ignore elif protocol == "gs": - if anonymous: - kwargs["token"] = _ANON - return fsspec.filesystem(protocol, **kwargs) # type: ignore + gskwargs = gs_setup_args(self._data_config.gcs, bucket, anonymous=anonymous) + gskwargs.update(kwargs) + return fsspec.filesystem(protocol, **gskwargs) # type: ignore + elif protocol == "abfs": + azkwargs = azure_setup_args(self._data_config.azure, bucket, anonymous=anonymous) + azkwargs.update(kwargs) + return fsspec.filesystem(protocol, **azkwargs) # type: ignore elif protocol == "ftp": kwargs.update(fsspec.implementations.ftp.FTPFileSystem._get_kwargs_from_urls(path)) return fsspec.filesystem(protocol, **kwargs) @@ -216,16 +303,20 @@ def get_filesystem( return fsspec.filesystem(protocol, **kwargs) async def get_async_filesystem_for_path( - self, path: str = "", anonymous: bool = False, **kwargs + self, path: str = "", bucket: str = "", anonymous: bool = False,
**kwargs ) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]: protocol = get_protocol(path) loop = asyncio.get_running_loop() - return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs) + return self.get_filesystem( + protocol, anonymous=anonymous, path=path, bucket=bucket, asynchronous=True, loop=loop, **kwargs + ) - def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: + def get_filesystem_for_path( + self, path: str = "", bucket: str = "", anonymous: bool = False, **kwargs + ) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) - return self.get_filesystem(protocol, anonymous=anonymous, path=path, **kwargs) + return self.get_filesystem(protocol, anonymous=anonymous, path=path, bucket=bucket, **kwargs) @staticmethod def is_remote(path: Union[str, os.PathLike]) -> bool: @@ -295,7 +386,8 @@ def exists(self, path: str) -> bool: @retry_request async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): - file_system = await self.get_async_filesystem_for_path(from_path) + bucket, from_path_file_only = split_path(from_path) + file_system = await self.get_async_filesystem_for_path(from_path, bucket) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) try: @@ -307,7 +399,7 @@ async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwa ) logger.info(f"Getting {from_path} to {to_path}") if isinstance(file_system, AsyncFileSystem): - dst = await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + dst = await file_system._get(from_path_file_only, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 else: dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): @@ -336,7 +428,8 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw More of an internal function to be called by put_data and put_raw_data This does not need a separate sync function. """ - file_system = await self.get_async_filesystem_for_path(to_path) + bucket, to_path_file_only = split_path(to_path) + file_system = await self.get_async_filesystem_for_path(to_path, bucket) from_path = self.strip_file_header(from_path) if recursive: # Only check this for the local filesystem @@ -354,7 +447,7 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw kwargs["metadata"] = {} kwargs["metadata"].update(self._execution_metadata) if isinstance(file_system, AsyncFileSystem): - dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + dst = await file_system._put(from_path, to_path_file_only, recursive=recursive, **kwargs) # pylint: disable=W0212 else: dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): @@ -423,11 +516,13 @@ async def async_put_raw_data( r = await self._put(from_path, to_path, **kwargs) return r or to_path + bucket, _ = split_path(to_path) + # See https://github.com/fsspec/s3fs/issues/871 for more background and pending work on the fsspec side to # support effectively async open(). For now these use-cases below will revert to sync calls. 
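For orientation, a minimal sketch of how the bucket-aware lookup above is intended to be used; the bucket and paths are illustrative placeholders, and the call pattern mirrors what get() and _put() do internally:

    from flytekit.core.data_persistence import FileAccessProvider, split_path

    fa = FileAccessProvider(local_sandbox_dir="/tmp/sandbox", raw_output_prefix="s3://my-s3-bucket/data/")
    # split_path() strips the bucket so the obstore-backed filesystem receives bucket-relative keys
    bucket, key = split_path("s3://my-s3-bucket/data/input.csv")  # ("my-s3-bucket", "data/input.csv")
    fs = fa.get_filesystem_for_path("s3://my-s3-bucket/data/input.csv", bucket)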
# raw bytes if isinstance(lpath, bytes): - fs = self.get_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path, bucket) with fs.open(to_path, "wb", **kwargs) as s: s.write(lpath) return to_path @@ -436,7 +531,7 @@ async def async_put_raw_data( if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path, bucket) lpath.seek(0) with fs.open(to_path, "wb", **kwargs) as s: while data := lpath.read(read_chunk_size_bytes): @@ -446,7 +541,7 @@ async def async_put_raw_data( if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path, bucket) lpath.seek(0) with fs.open(to_path, "wb", **kwargs) as s: while data_str := lpath.read(read_chunk_size_bytes): @@ -635,6 +730,10 @@ async def async_put_data( put_data = loop_manager.synced(async_put_data) +fsspec.register_implementation("s3", ObstoreS3FileSystem) +fsspec.register_implementation("gs", ObstoreGCSFileSystem) +fsspec.register_implementation("abfs", ObstoreAzureBlobFileSystem) + flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( local_sandbox_dir=os.path.join(flyte_tmp_dir, "sandbox"), diff --git a/flytekit/core/obstore_filesystem.py b/flytekit/core/obstore_filesystem.py new file mode 100644 index 0000000000..5d5fb8d77a --- /dev/null +++ b/flytekit/core/obstore_filesystem.py @@ -0,0 +1,56 @@ +""" +Classes that override AsyncFsspecStore to specify filesystem-specific parameters +""" + +from typing import Optional + +from obstore.fsspec import AsyncFsspecStore + +DEFAULT_BLOCK_SIZE = 5 * 2**20 + + +class ObstoreS3FileSystem(AsyncFsspecStore): + """ + Add the following properties used in S3FileSystem + """ + + root_marker = "" + connect_timeout = 5 + retries = 5 + read_timeout = 15 + default_block_size = DEFAULT_BLOCK_SIZE + protocol = ("s3", "s3a") + _extra_tokenize_attributes = ("default_block_size",) + + def __init__(self, retries: Optional[int] = None, **kwargs): + """ + Initialize the ObstoreS3FileSystem with optional retries.
+ + Args: + retries (int): Number of retries for requests + **kwargs: Other keyword arguments passed to the parent class + """ + if retries is not None: + self.retries = retries + + super().__init__(**kwargs) + + +class ObstoreGCSFileSystem(AsyncFsspecStore): + """ + Add the following properties used in GCSFileSystem + """ + + scopes = {"read_only", "read_write", "full_control"} + retries = 6 # number of retries on http failure + default_block_size = DEFAULT_BLOCK_SIZE + protocol = "gcs", "gs" + async_impl = True + + +class ObstoreAzureBlobFileSystem(AsyncFsspecStore): + """ + Add the following properties used in AzureBlobFileSystem + """ + + protocol = "abfs" diff --git a/pyproject.toml b/pyproject.toml index 3dc782c507..731de3d39d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "marshmallow-jsonschema>=0.12.0", "mashumaro>=3.15", "msgpack>=1.1.0", + "obstore==0.3.0b10", "protobuf!=4.25.0", "pygments", "python-json-logger>=2.0.0", diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 42e74f453c..0d16330e3d 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -5,14 +5,20 @@ from uuid import UUID import typing import asyncio +from botocore.parsers import base64 import fsspec import mock +from obstore.store import S3Store import pytest from s3fs import S3FileSystem from flytekit.configuration import Config, DataConfig, S3Config from flytekit.core.context_manager import FlyteContextManager, FlyteContext -from flytekit.core.data_persistence import FileAccessProvider, get_fsspec_storage_options, s3_setup_args +from flytekit.core.data_persistence import ( + FileAccessProvider, + get_fsspec_storage_options, + s3_setup_args, +) from flytekit.core.type_engine import TypeEngine from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FlyteFile @@ -32,15 +38,21 @@ def test_path_getting(mock_uuid_class, mock_gcs): # Testing with raw output prefix pointing to a local path loc_sandbox = os.path.join(root, "tmp", "unittest") loc_data = os.path.join(root, "tmp", "unittestdata") - local_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data) + local_raw_fp = FileAccessProvider( + local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data + ) r = local_raw_fp.get_random_string() rr = local_raw_fp.join(local_raw_fp.raw_output_prefix, r) assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123") - rr = local_raw_fp.join(local_raw_fp.raw_output_prefix, r, local_raw_fp.get_file_tail("/fsa/blah.csv")) + rr = local_raw_fp.join( + local_raw_fp.raw_output_prefix, r, local_raw_fp.get_file_tail("/fsa/blah.csv") + ) assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123", "blah.csv") # Test local path and directory - assert local_raw_fp.get_random_local_path() == os.path.join(root, "tmp", "unittest", "local_flytekit", "abcdef123") + assert local_raw_fp.get_random_local_path() == os.path.join( + root, "tmp", "unittest", "local_flytekit", "abcdef123" + ) assert local_raw_fp.get_random_local_path("xjiosa/blah.txt") == os.path.join( root, "tmp", "unittest", "local_flytekit", "abcdef123", "blah.txt" ) @@ -49,20 +61,28 @@ def test_path_getting(mock_uuid_class, mock_gcs): ) # Recursive paths - assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + assert ( + "file:///abc/happy/" + ), "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" ) -
assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + assert ( + "file:///abc/happy/" + ), "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( "file:///abc/happy", "s3://my-s3-bucket/bucket1" ) # Test with remote pointed to s3. - s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket") + s3_fa = FileAccessProvider( + local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket" + ) r = s3_fa.get_random_string() rr = s3_fa.join(s3_fa.raw_output_prefix, r) assert rr == "s3://my-s3-bucket/abcdef123" # trailing slash should make no difference - s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/") + s3_fa = FileAccessProvider( + local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/" + ) r = s3_fa.get_random_string() rr = s3_fa.join(s3_fa.raw_output_prefix, r) assert rr == "s3://my-s3-bucket/abcdef123" @@ -70,17 +90,23 @@ def test_path_getting(mock_uuid_class, mock_gcs): # Testing with raw output prefix pointing to file:// # Skip tests for windows if os.name != "nt": - file_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata") + file_raw_fp = FileAccessProvider( + local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata" + ) r = file_raw_fp.get_random_string() rr = file_raw_fp.join(file_raw_fp.raw_output_prefix, r) rr = file_raw_fp.strip_file_header(rr) assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123") r = file_raw_fp.get_random_string() - rr = file_raw_fp.join(file_raw_fp.raw_output_prefix, r, file_raw_fp.get_file_tail("/fsa/blah.csv")) + rr = file_raw_fp.join( + file_raw_fp.raw_output_prefix, r, file_raw_fp.get_file_tail("/fsa/blah.csv") + ) rr = file_raw_fp.strip_file_header(rr) assert rr == os.path.join(root, "tmp", "unittestdata", "abcdef123", "blah.csv") - g_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/") + g_fa = FileAccessProvider( + local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/" + ) r = g_fa.get_random_string() rr = g_fa.join(g_fa.raw_output_prefix, r) assert rr == "gs://my-s3-bucket/abcdef123" @@ -119,7 +145,11 @@ async def test_local_provider(source_folder): # dest folder exists. 
dc = Config.for_sandbox().data_config with tempfile.TemporaryDirectory() as dest_tmpdir: - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", + raw_output_prefix=dest_tmpdir, + data_config=dc, + ) r = provider.get_random_string() doesnotexist = provider.join(provider.raw_output_prefix, r) await provider.async_put_data(source_folder, doesnotexist, is_multipart=True) @@ -176,9 +206,13 @@ def test_s3_provider(source_folder): # Running mkdir on s3 filesystem doesn't do anything so leaving out for now dc = Config.for_sandbox().data_config provider = FileAccessProvider( - local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + local_sandbox_dir="/tmp/unittest", + raw_output_prefix="s3://my-s3-bucket/testdata/", + data_config=dc, + ) + doesnotexist = provider.join( + provider.raw_output_prefix, provider.get_random_string() ) - doesnotexist = provider.join(provider.raw_output_prefix, provider.get_random_string()) provider.put_data(source_folder, doesnotexist, is_multipart=True) fs = provider.get_filesystem_for_path(doesnotexist) files = fs.find(doesnotexist) @@ -190,7 +224,9 @@ def test_local_provider_get_empty(): with tempfile.TemporaryDirectory() as empty_source: with tempfile.TemporaryDirectory() as dest_folder: provider = FileAccessProvider( - local_sandbox_dir="/tmp/unittest", raw_output_prefix=empty_source, data_config=dc + local_sandbox_dir="/tmp/unittest", + raw_output_prefix=empty_source, + data_config=dc, ) provider.get_data(empty_source, dest_folder, is_multipart=True) loc = provider.get_filesystem_for_path(dest_folder) @@ -202,18 +238,29 @@ def test_local_provider_get_empty(): @mock.patch("flytekit.configuration.get_config_file") @mock.patch("os.environ") -def test_s3_setup_args_env_empty(mock_os, mock_get_config_file): +@mock.patch("obstore.store.S3Store.from_env") +def test_s3_setup_args_env_empty(mock_from_env, mock_os, mock_get_config_file): mock_get_config_file.return_value = None mock_os.get.return_value = None s3c = S3Config.auto() kwargs = s3_setup_args(s3c) - assert kwargs == {"cache_regions": True} + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "aws_allow_http": "true", # Allow HTTP connections + "aws_virtual_hosted_style_request": "false", # Use path-style addressing + }, + ) @mock.patch("flytekit.configuration.get_config_file") @mock.patch("os.environ") -def test_s3_setup_args_env_both(mock_os, mock_get_config_file): +@mock.patch("obstore.store.S3Store.from_env") +def test_s3_setup_args_env_both(mock_from_env, mock_os, mock_get_config_file): mock_get_config_file.return_value = None + ee = { "AWS_ACCESS_KEY_ID": "ignore-user", "AWS_SECRET_ACCESS_KEY": "ignore-secret", @@ -222,12 +269,23 @@ def test_s3_setup_args_env_both(mock_os, mock_get_config_file): } mock_os.get.side_effect = lambda x, y: ee.get(x) kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True} + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "access_key_id": "flyte", + "secret_access_key": "flyte-secret", + "aws_allow_http": "true", # Allow HTTP connections + "aws_virtual_hosted_style_request": "false", # Use path-style addressing + }, + ) @mock.patch("flytekit.configuration.get_config_file") @mock.patch("os.environ") -def test_s3_setup_args_env_flyte(mock_os, 
mock_get_config_file): +@mock.patch("obstore.store.S3Store.from_env") +def test_s3_setup_args_env_flyte(mock_from_env, mock_os, mock_get_config_file): mock_get_config_file.return_value = None ee = { "FLYTE_AWS_ACCESS_KEY_ID": "flyte", @@ -235,12 +293,23 @@ def test_s3_setup_args_env_flyte(mock_os, mock_get_config_file): } mock_os.get.side_effect = lambda x, y: ee.get(x) kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True} + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "access_key_id": "flyte", + "secret_access_key": "flyte-secret", + "aws_allow_http": "true", # Allow HTTP connections + "aws_virtual_hosted_style_request": "false", # Use path-style addressing + }, + ) @mock.patch("flytekit.configuration.get_config_file") @mock.patch("os.environ") -def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): +@mock.patch("obstore.store.S3Store.from_env") +def test_s3_setup_args_env_aws(mock_from_env, mock_os, mock_get_config_file): mock_get_config_file.return_value = None ee = { "AWS_ACCESS_KEY_ID": "ignore-user", @@ -248,8 +317,15 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): } mock_os.get.side_effect = lambda x, y: ee.get(x) kwargs = s3_setup_args(S3Config.auto()) - # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default - assert kwargs == {"cache_regions": True} + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "aws_allow_http": "true", # Allow HTTP connections + "aws_virtual_hosted_style_request": "false", # Use path-style addressing + }, + ) @mock.patch("flytekit.configuration.get_config_file") @@ -272,31 +348,42 @@ def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_ "FLYTE_GCP_GSUTIL_PARALLELISM": "False", } mock_os.get.side_effect = lambda x, y: ee.get(x) - storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value") + storage_options = get_fsspec_storage_options( + "gs", DataConfig.auto(), anonymous=True, other_argument="value" + ) assert storage_options == {"token": "anon", "other_argument": "value"} @mock.patch("flytekit.configuration.get_config_file") @mock.patch("os.environ") -def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file): +@mock.patch("obstore.store.AzureStore.from_env") +def test_get_fsspec_storage_options_azure(mock_from_env, mock_os, mock_get_config_file): mock_get_config_file.return_value = None + account_key = "accountkey" + + account_key_base64 = base64.b64encode(account_key.encode()).decode() + ee = { "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", - "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", + "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": account_key_base64, "FLYTE_AZURE_TENANT_ID": "tenantid", "FLYTE_AZURE_CLIENT_ID": "clientid", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", } mock_os.get.side_effect = lambda x, y: ee.get(x) storage_options = get_fsspec_storage_options("abfs", DataConfig.auto()) - assert storage_options == { - "account_name": "accountname", - "account_key": "accountkey", - "client_id": "clientid", - "client_secret": "clientsecret", - "tenant_id": "tenantid", - "anon": False, - } + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "account_name": "accountname", + "account_key": account_key_base64, + "client_id": "clientid", + "client_secret": "clientsecret", + "tenant_id": "tenantid", + }, 
+ ) @mock.patch("flytekit.configuration.get_config_file") @@ -352,8 +439,14 @@ def test_crawl_local_non_nt(source_folder): res = fd.crawl() split = [(x, y) for x, y in res] files = [os.path.join(x, y) for x, y in split] - assert set(split) == {(source_folder, "original.txt"), (source_folder, os.path.join("nested", "more.txt"))} - expected = {os.path.join(source_folder, "original.txt"), os.path.join(source_folder, "nested", "more.txt")} + assert set(split) == { + (source_folder, "original.txt"), + (source_folder, os.path.join("nested", "more.txt")), + } + expected = { + os.path.join(source_folder, "original.txt"), + os.path.join(source_folder, "nested", "more.txt"), + } assert set(files) == expected # Test crawling a directory without trailing / or \ @@ -379,12 +472,19 @@ def test_crawl_s3(source_folder): # Running mkdir on s3 filesystem doesn't do anything so leaving out for now dc = Config.for_sandbox().data_config provider = FileAccessProvider( - local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + local_sandbox_dir="/tmp/unittest", + raw_output_prefix="s3://my-s3-bucket/testdata/", + data_config=dc, + ) + s3_random_target = provider.join( + provider.raw_output_prefix, provider.get_random_string() ) - s3_random_target = provider.join(provider.raw_output_prefix, provider.get_random_string()) provider.put_data(source_folder, s3_random_target, is_multipart=True) ctx = FlyteContextManager.current_context() - expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} + expected = { + f"{s3_random_target}/original.txt", + f"{s3_random_target}/nested/more.txt", + } with FlyteContextManager.with_context(ctx.with_file_access(provider)): fd = FlyteDirectory(path=s3_random_target) @@ -392,7 +492,10 @@ def test_crawl_s3(source_folder): res = [(x, y) for x, y in res] files = [os.path.join(x, y) for x, y in res] assert set(files) == expected - assert set(res) == {(s3_random_target, "original.txt"), (s3_random_target, os.path.join("nested", "more.txt"))} + assert set(res) == { + (s3_random_target, "original.txt"), + (s3_random_target, os.path.join("nested", "more.txt")), + } fd_file = FlyteDirectory(path=f"{s3_random_target}/original.txt") res = fd_file.crawl() @@ -405,7 +508,11 @@ def test_walk_local_copy_to_s3(source_folder): dc = Config.for_sandbox().data_config explicit_empty_folder = UUID(int=random.getrandbits(128)).hex raw_output_path = f"s3://my-s3-bucket/testdata/{explicit_empty_folder}" - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output_path, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", + raw_output_prefix=raw_output_path, + data_config=dc, + ) ctx = FlyteContextManager.current_context() local_fd = FlyteDirectory(path=source_folder) @@ -433,7 +540,9 @@ def test_s3_metadata(): dc = Config.for_sandbox().data_config random_folder = UUID(int=random.getrandbits(64)).hex raw_output = f"s3://my-s3-bucket/testing/metadata_test/{random_folder}" - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc + ) _, local_zip = tempfile.mkstemp(suffix=".gz") with open(local_zip, "w") as f: f.write("hello world") @@ -454,7 +563,9 @@ def test_s3_metadata(): assert len(files) == 2 -async def dummy_output_to_literal_map(ctx: FlyteContext, ff: typing.List[FlyteFile]) -> Literal: +async 
def dummy_output_to_literal_map( + ctx: FlyteContext, ff: typing.List[FlyteFile] +) -> Literal: lt = TypeEngine.to_literal_type(typing.List[FlyteFile]) lit = await TypeEngine.async_to_literal(ctx, ff, typing.List[FlyteFile], lt) return lit @@ -479,7 +590,9 @@ def test_async_local_copy_to_s3(): random_folder = UUID(int=random.getrandbits(64)).hex raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}" print(f"Uploading to {raw_output}") - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc + ) start_time = datetime.datetime.now(datetime.timezone.utc) start_wall_time = time.perf_counter() @@ -522,10 +635,17 @@ def test_async_download_from_s3(): random_folder = UUID(int=random.getrandbits(64)).hex raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}" print(f"Uploading to {raw_output}") - provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc + ) with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: - lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile])) + lit = TypeEngine.to_literal( + ctx, + ff, + typing.List[FlyteFile], + TypeEngine.to_literal_type(typing.List[FlyteFile]), + ) print(f"Literal is {lit}") python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile]) @@ -545,10 +665,17 @@ def test_async_download_from_s3(): print(f"Time taken (serial download): {end_time - start_time}") print(f"Wall time taken (serial download): {end_wall_time - start_wall_time}") - print(f"Process time taken (serial download): {end_process_time - start_process_time}") + print( + f"Process time taken (serial download): {end_process_time - start_process_time}" + ) with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: - lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile])) + lit = TypeEngine.to_literal( + ctx, + ff, + typing.List[FlyteFile], + TypeEngine.to_literal_type(typing.List[FlyteFile]), + ) print(f"Literal is {lit}") python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile]) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 116717b92d..776c2df01a 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -10,6 +10,7 @@ import mock import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential +from botocore.parsers import base64 from flytekit.configuration import Config from flytekit.core.data_persistence import FileAccessProvider @@ -153,18 +154,29 @@ def test_generate_new_custom_path(): assert np == "s3://foo-bucket/my-default-prefix/bar.txt" -def test_initialise_azure_file_provider_with_account_key(): +@mock.patch("obstore.store.AzureStore.from_env") +def test_initialise_azure_file_provider_with_account_key(mock_from_env): + account_key = "accountkey" + account_key_base64 = base64.b64encode(account_key.encode()).decode() + with mock.patch.dict( os.environ, - {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"}, + 
{"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": account_key_base64}, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") - assert fp.get_filesystem().account_name == "accountname" - assert fp.get_filesystem().account_key == "accountkey" - assert fp.get_filesystem().sync_credential is None + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "account_name": "accountname", + "account_key": account_key_base64, + }, + ) -def test_initialise_azure_file_provider_with_service_principal(): + +@mock.patch("obstore.store.AzureStore.from_env") +def test_initialise_azure_file_provider_with_service_principal(mock_from_env): with mock.patch.dict( os.environ, { @@ -175,14 +187,21 @@ def test_initialise_azure_file_provider_with_service_principal(): }, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") - assert fp.get_filesystem().account_name == "accountname" - assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential) - assert fp.get_filesystem().client_secret == "clientsecret" - assert fp.get_filesystem().client_id == "clientid" - assert fp.get_filesystem().tenant_id == "tenantid" + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "account_name": "accountname", + "client_secret": "clientsecret", + "client_id": "clientid", + "tenant_id": "tenantid", + }, + ) -def test_initialise_azure_file_provider_with_default_credential(): +@mock.patch("obstore.store.AzureStore.from_env") +def test_initialise_azure_file_provider_with_default_credential(mock_from_env): with mock.patch.dict( os.environ, { @@ -191,8 +210,14 @@ def test_initialise_azure_file_provider_with_default_credential(): }, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") - assert fp.get_filesystem().account_name == "accountname" - assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) + + mock_from_env.return_value = mock.Mock() + mock_from_env.assert_called_with( + "", + config={ + "account_name": "accountname", + }, + ) def test_get_file_system(): diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index be61388fa5..b72ec1aed9 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -318,11 +318,11 @@ def test_directory_guess(): assert fft.extension() == "" -@mock.patch("s3fs.core.S3FileSystem._lsdir") +@mock.patch("flytekit.core.obstore_filesystem.ObstoreS3FileSystem._ls") @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") -def test_list_dir(mock_get_data, mock_lsdir): +def test_list_dir(mock_get_data, mock_ls): remote_dir = "s3://test-flytedir" - mock_lsdir.return_value = [ + mock_ls.return_value = [ {"name": os.path.join(remote_dir, "file1.txt"), "type": "file"}, {"name": os.path.join(remote_dir, "file2.txt"), "type": "file"}, {"name": os.path.join(remote_dir, "subdir"), "type": "directory"},