From 43b055ca7b3602f46693e950a6df0b96faed6467 Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Thu, 9 Jan 2025 15:28:26 +0900 Subject: [PATCH] Avoid computing expensive default values when the value is overridden anyways The most expensive call at the moment is repo.get_changed_files, which does down the drain if: - another default_fn overrides the value - an explicit override is given when creating the `Parameters` With this change, the default function can return a function as a value, which is not evaluated unless necessary. Fixes #616 --- src/taskgraph/parameters.py | 4 ++-- test/test_parameters.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/taskgraph/parameters.py b/src/taskgraph/parameters.py index 886db3f21..80fd40475 100644 --- a/src/taskgraph/parameters.py +++ b/src/taskgraph/parameters.py @@ -110,7 +110,7 @@ def _get_defaults(repo_root=None): "do_not_optimize": [], "enable_always_target": True, "existing_tasks": {}, - "files_changed": repo.get_changed_files("AM"), + "files_changed": lambda: repo.get_changed_files("AM"), "filters": ["target_tasks_method"], "head_ref": repo.branch or repo.head_rev, "head_repository": repo_url, @@ -210,7 +210,7 @@ def _fill_defaults(repo_root=None, **kwargs): for name, default in defaults.items(): if name not in kwargs: - kwargs[name] = default + kwargs[name] = default() if callable(default) else default return kwargs def check(self): diff --git a/test/test_parameters.py b/test/test_parameters.py index f5e03c271..c2a40bd86 100644 --- a/test/test_parameters.py +++ b/test/test_parameters.py @@ -264,6 +264,11 @@ def test_extend_parameters_schema(monkeypatch): } ), ) + monkeypatch.setattr( + parameters, + "defaults_functions", + list(parameters.defaults_functions), + ) with pytest.raises(ParameterMismatch): Parameters(strict=False).check() @@ -437,7 +442,7 @@ def test_extend_parameters_schema(monkeypatch): ), ), ) -def test_get_defaults( +def test_defaults( monkeypatch, repo_root, is_repo, raises, expected_repo_root, expected ): def mock_get_repository(repo_root): @@ -478,4 +483,4 @@ def mock_parse(url): monkeypatch.setattr(parameters, "datetime", datetime_mock) monkeypatch.setattr(parameters, "get_version", lambda *_, **__: "1.0.0") - assert parameters._get_defaults(repo_root) == expected + assert parameters.Parameters(strict=False, repo_root=repo_root) == expected