Skip to content

Commit

Permalink
Make get_auth a parameter to read_simpleapi
Browse files Browse the repository at this point in the history
  • Loading branch information
WillMorrison committed Jan 14, 2025
1 parent fdb5a82 commit 0320b60
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
12 changes: 8 additions & 4 deletions python/private/pypi/simpleapi_download.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ A file that houses private functions used in the `bzlmod` extension with the sam
"""

load("@bazel_features//:features.bzl", "bazel_features")
load("//python/private:auth.bzl", "get_auth")
load("//python/private:auth.bzl", _get_auth = "get_auth")
load("//python/private:envsubst.bzl", "envsubst")
load("//python/private:normalize_name.bzl", "normalize_name")
load("//python/private:text_util.bzl", "render")
Expand All @@ -30,6 +30,7 @@ def simpleapi_download(
cache,
parallel_download = True,
read_simpleapi = None,
get_auth = None,
_fail = fail):
"""Download Simple API HTML.
Expand Down Expand Up @@ -59,6 +60,7 @@ def simpleapi_download(
parallel_download: A boolean to enable usage of bazel 7.1 non-blocking downloads.
read_simpleapi: a function for reading and parsing of the SimpleAPI contents.
Used in tests.
get_auth: A function to get auth information passed to read_simpleapi. Used in tests.
_fail: a function to print a failure. Used in tests.
Returns:
Expand Down Expand Up @@ -98,6 +100,7 @@ def simpleapi_download(
),
attr = attr,
cache = cache,
get_auth = get_auth,
**download_kwargs
)
if hasattr(result, "wait"):
Expand Down Expand Up @@ -144,7 +147,7 @@ def simpleapi_download(

return contents

def _read_simpleapi(ctx, url, attr, cache, **download_kwargs):
def _read_simpleapi(ctx, url, attr, cache, get_auth = None, **download_kwargs):
"""Read SimpleAPI.
Args:
Expand All @@ -157,6 +160,7 @@ def _read_simpleapi(ctx, url, attr, cache, **download_kwargs):
* auth_patterns: The auth_patterns parameter for ctx.download, see
http_file for docs.
cache: A dict for storing the results.
get_auth: A function to get auth information. Used in tests.
**download_kwargs: Any extra params to ctx.download.
Note that output and auth will be passed for you.
Expand Down Expand Up @@ -194,13 +198,13 @@ def _read_simpleapi(ctx, url, attr, cache, **download_kwargs):

output = ctx.path(output_str.strip("_").lower() + ".html")

_get_auth = ctx.get_auth if hasattr(ctx, "get_auth") else get_auth
get_auth = get_auth or _get_auth

# NOTE: this may have block = True or block = False in the download_kwargs
download = ctx.download(
url = [real_url],
output = output,
auth = _get_auth(ctx, [real_url], ctx_attr = attr),
auth = get_auth(ctx, [real_url], ctx_attr = attr),
allow_fail = True,
**download_kwargs
)
Expand Down
12 changes: 7 additions & 5 deletions tests/pypi/simpleapi_download/simpleapi_download_tests.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ _tests = []
def _test_simple(env):
calls = []

def read_simpleapi(ctx, url, attr, cache, block):
def read_simpleapi(ctx, url, attr, cache, get_auth, block):
_ = ctx # buildifier: disable=unused-variable
_ = attr
_ = cache
_ = get_auth
env.expect.that_bool(block).equals(False)
calls.append(url)
if "foo" in url and "main" in url:
Expand Down Expand Up @@ -73,10 +74,11 @@ def _test_fail(env):
calls = []
fails = []

def read_simpleapi(ctx, url, attr, cache, block):
def read_simpleapi(ctx, url, attr, cache, get_auth, block):
_ = ctx # buildifier: disable=unused-variable
_ = attr
_ = cache
_ = get_auth
env.expect.that_bool(block).equals(False)
calls.append(url)
if "foo" in url:
Expand Down Expand Up @@ -133,7 +135,6 @@ def _test_download_url(env):
download = download,
read = lambda i: "contents of " + i,
path = lambda i: "path/for/" + i,
get_auth = lambda ctx, urls, ctx_attr: struct(),
),
attr = struct(
index_url_overrides = {},
Expand All @@ -144,6 +145,7 @@ def _test_download_url(env):
),
cache = {},
parallel_download = False,
get_auth = lambda ctx, urls, ctx_attr: struct(),
)

env.expect.that_dict(downloads).contains_exactly({
Expand All @@ -168,7 +170,6 @@ def _test_download_url_parallel(env):
download = download,
read = lambda i: "contents of " + i,
path = lambda i: "path/for/" + i,
get_auth = lambda ctx, urls, ctx_attr: struct(),
),
attr = struct(
index_url_overrides = {},
Expand All @@ -179,6 +180,7 @@ def _test_download_url_parallel(env):
),
cache = {},
parallel_download = True,
get_auth = lambda ctx, urls, ctx_attr: struct(),
)

env.expect.that_dict(downloads).contains_exactly({
Expand All @@ -203,7 +205,6 @@ def _test_download_envsubst_url(env):
download = download,
read = lambda i: "contents of " + i,
path = lambda i: "path/for/" + i,
get_auth = lambda ctx, urls, ctx_attr: struct(),
),
attr = struct(
index_url_overrides = {},
Expand All @@ -214,6 +215,7 @@ def _test_download_envsubst_url(env):
),
cache = {},
parallel_download = False,
get_auth = lambda ctx, urls, ctx_attr: struct(),
)

env.expect.that_dict(downloads).contains_exactly({
Expand Down

0 comments on commit 0320b60

Please sign in to comment.