diff --git a/python/private/pypi/simpleapi_download.bzl b/python/private/pypi/simpleapi_download.bzl index b633df70d5..ef39fb8723 100644 --- a/python/private/pypi/simpleapi_download.bzl +++ b/python/private/pypi/simpleapi_download.bzl @@ -17,7 +17,7 @@ A file that houses private functions used in the `bzlmod` extension with the sam """ load("@bazel_features//:features.bzl", "bazel_features") -load("//python/private:auth.bzl", "get_auth") +load("//python/private:auth.bzl", _get_auth = "get_auth") load("//python/private:envsubst.bzl", "envsubst") load("//python/private:normalize_name.bzl", "normalize_name") load("//python/private:text_util.bzl", "render") @@ -30,6 +30,7 @@ def simpleapi_download( cache, parallel_download = True, read_simpleapi = None, + get_auth = None, _fail = fail): """Download Simple API HTML. @@ -59,6 +60,7 @@ def simpleapi_download( parallel_download: A boolean to enable usage of bazel 7.1 non-blocking downloads. read_simpleapi: a function for reading and parsing of the SimpleAPI contents. Used in tests. + get_auth: A function to get auth information passed to read_simpleapi. Used in tests. _fail: a function to print a failure. Used in tests. Returns: @@ -98,6 +100,7 @@ def simpleapi_download( ), attr = attr, cache = cache, + get_auth = get_auth, **download_kwargs ) if hasattr(result, "wait"): @@ -144,7 +147,7 @@ def simpleapi_download( return contents -def _read_simpleapi(ctx, url, attr, cache, **download_kwargs): +def _read_simpleapi(ctx, url, attr, cache, get_auth = None, **download_kwargs): """Read SimpleAPI. Args: @@ -157,6 +160,7 @@ def _read_simpleapi(ctx, url, attr, cache, **download_kwargs): * auth_patterns: The auth_patterns parameter for ctx.download, see http_file for docs. cache: A dict for storing the results. + get_auth: A function to get auth information. Used in tests. **download_kwargs: Any extra params to ctx.download. Note that output and auth will be passed for you. @@ -194,13 +198,13 @@ def _read_simpleapi(ctx, url, attr, cache, **download_kwargs): output = ctx.path(output_str.strip("_").lower() + ".html") - _get_auth = ctx.get_auth if hasattr(ctx, "get_auth") else get_auth + get_auth = get_auth or _get_auth # NOTE: this may have block = True or block = False in the download_kwargs download = ctx.download( url = [real_url], output = output, - auth = _get_auth(ctx, [real_url], ctx_attr = attr), + auth = get_auth(ctx, [real_url], ctx_attr = attr), allow_fail = True, **download_kwargs ) diff --git a/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl b/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl index e99662dec0..964d3e25ea 100644 --- a/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl +++ b/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl @@ -22,10 +22,11 @@ _tests = [] def _test_simple(env): calls = [] - def read_simpleapi(ctx, url, attr, cache, block): + def read_simpleapi(ctx, url, attr, cache, get_auth, block): _ = ctx # buildifier: disable=unused-variable _ = attr _ = cache + _ = get_auth env.expect.that_bool(block).equals(False) calls.append(url) if "foo" in url and "main" in url: @@ -73,10 +74,11 @@ def _test_fail(env): calls = [] fails = [] - def read_simpleapi(ctx, url, attr, cache, block): + def read_simpleapi(ctx, url, attr, cache, get_auth, block): _ = ctx # buildifier: disable=unused-variable _ = attr _ = cache + _ = get_auth env.expect.that_bool(block).equals(False) calls.append(url) if "foo" in url: @@ -133,7 +135,6 @@ def _test_download_url(env): download = download, read = lambda i: "contents of " + i, path = lambda i: "path/for/" + i, - get_auth = lambda ctx, urls, ctx_attr: struct(), ), attr = struct( index_url_overrides = {}, @@ -144,6 +145,7 @@ def _test_download_url(env): ), cache = {}, parallel_download = False, + get_auth = lambda ctx, urls, ctx_attr: struct(), ) env.expect.that_dict(downloads).contains_exactly({ @@ -168,7 +170,6 @@ def _test_download_url_parallel(env): download = download, read = lambda i: "contents of " + i, path = lambda i: "path/for/" + i, - get_auth = lambda ctx, urls, ctx_attr: struct(), ), attr = struct( index_url_overrides = {}, @@ -179,6 +180,7 @@ def _test_download_url_parallel(env): ), cache = {}, parallel_download = True, + get_auth = lambda ctx, urls, ctx_attr: struct(), ) env.expect.that_dict(downloads).contains_exactly({ @@ -203,7 +205,6 @@ def _test_download_envsubst_url(env): download = download, read = lambda i: "contents of " + i, path = lambda i: "path/for/" + i, - get_auth = lambda ctx, urls, ctx_attr: struct(), ), attr = struct( index_url_overrides = {}, @@ -214,6 +215,7 @@ def _test_download_envsubst_url(env): ), cache = {}, parallel_download = False, + get_auth = lambda ctx, urls, ctx_attr: struct(), ) env.expect.that_dict(downloads).contains_exactly({