Skip to content

Commit

Permalink
Avoid computing expensive default values when the value is overridden…
Browse files Browse the repository at this point in the history
… anyways

The most expensive call at the moment is repo.get_changed_files, which
does down the drain if:
- another default_fn overrides the value
- an explicit override is given when creating the `Parameters`

With this change, the default function can return a function as a value,
which is not evaluated unless necessary.

Fixes #616
  • Loading branch information
glandium committed Jan 9, 2025
1 parent 5fb0641 commit 5983602
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/taskgraph/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gzip
import hashlib
import inspect
import json
import os
import time
Expand Down Expand Up @@ -110,7 +111,7 @@ def _get_defaults(repo_root=None):
"do_not_optimize": [],
"enable_always_target": True,
"existing_tasks": {},
"files_changed": repo.get_changed_files("AM"),
"files_changed": lambda: repo.get_changed_files("AM"),
"filters": ["target_tasks_method"],
"head_ref": repo.branch or repo.head_rev,
"head_repository": repo_url,
Expand Down Expand Up @@ -210,7 +211,7 @@ def _fill_defaults(repo_root=None, **kwargs):

for name, default in defaults.items():
if name not in kwargs:
kwargs[name] = default
kwargs[name] = default() if inspect.isfunction(default) else default
return kwargs

def check(self):
Expand Down
9 changes: 7 additions & 2 deletions test/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ def test_extend_parameters_schema(monkeypatch):
}
),
)
monkeypatch.setattr(
parameters,
"defaults_functions",
list(parameters.defaults_functions),
)

with pytest.raises(ParameterMismatch):
Parameters(strict=False).check()
Expand Down Expand Up @@ -437,7 +442,7 @@ def test_extend_parameters_schema(monkeypatch):
),
),
)
def test_get_defaults(
def test_defaults(
monkeypatch, repo_root, is_repo, raises, expected_repo_root, expected
):
def mock_get_repository(repo_root):
Expand Down Expand Up @@ -478,4 +483,4 @@ def mock_parse(url):
monkeypatch.setattr(parameters, "datetime", datetime_mock)
monkeypatch.setattr(parameters, "get_version", lambda *_, **__: "1.0.0")

assert parameters._get_defaults(repo_root) == expected
assert parameters.Parameters(strict=False, repo_root=repo_root) == expected

0 comments on commit 5983602

Please sign in to comment.