Skip to content

Commit

Permalink
fix: fix find_first_branch_matching_pattern, include remote branches
Browse files Browse the repository at this point in the history
  • Loading branch information
renatav committed Dec 25, 2023
1 parent 4f546aa commit 7f31c32
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions taf/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,7 @@ def find_first_branch_matching_pattern(
traverse_branch_name: str,
pattern_func: Callable[[str], bool],
include_remotes: bool = False,
sort_key_func: Optional[Callable[[str], bool]]=None,
) -> Tuple[Optional[str], List[str]]:
branch_tips = {}
repo = self.pygit_repo
Expand All @@ -958,20 +959,31 @@ def find_first_branch_matching_pattern(
else:
raise GitError(f"Branch {traverse_branch_name} does not exist")

branches = repo.branches if include_remotes else repo.branches.local
branches = self.branches(all=include_remotes)
all_branch_names = []

for branch_name in branches:
if pattern_func(branch_name):
stripped_name = self._remove_remote_prefix(branch_name)
if stripped_name in all_branch_names:
continue
if pattern_func(stripped_name):
branch = repo.lookup_branch(branch_name)
branch_tips[branch_name] = branch.peel().hex
all_branch_names.append(branch_name)

try:
branch_tips[stripped_name] = branch.peel().hex
except Exception:
ref = repo.references[f"refs/remotes/{branch_name}"]
commit = ref.peel(pygit2.Commit)
branch_tips[stripped_name] = commit.hex
all_branch_names.append(stripped_name)

if sort_key_func is not None:
all_branch_names = sorted(all_branch_names, key=sort_key_func, reverse=True)
# Iterate over commits from newest to oldest
if len(branch_tips):
for commit in repo.walk(branch_target, pygit2.GIT_SORT_TIME):
# Check if commit is in any of the pre-filtered branches
for branch_name, tip_hex in branch_tips.items():
if commit.hex == tip_hex or repo.descendant_of(tip_hex, commit.hex):
for commit in repo.walk(branch_target):
for branch_name in all_branch_names:
tip_hex = branch_tips.get(branch_name)
if tip_hex is not None and (commit.hex == tip_hex or repo.descendant_of(tip_hex, commit.hex)):
return branch_name, []
return None, all_branch_names

Expand Down Expand Up @@ -1374,6 +1386,15 @@ def _determine_default_branch(self) -> Optional[str]:
)
return None

def _remove_remote_prefix(self, branch_name):
for remote in self.remotes:
prefix = f"{remote}/"
if branch_name.startswith(prefix):
return branch_name[len(prefix):]
return branch_name
return branch_name


def _validate_repo_name(self, name: str) -> str:
"""Ensure the repo name is not malicious"""
match = _repo_name_re.match(name)
Expand Down

0 comments on commit 7f31c32

Please sign in to comment.