diff --git a/taf/git.py b/taf/git.py index bfa92c08..091f4bb4 100644 --- a/taf/git.py +++ b/taf/git.py @@ -946,6 +946,7 @@ def find_first_branch_matching_pattern( traverse_branch_name: str, pattern_func: Callable[[str], bool], include_remotes: bool = False, + sort_key_func: Optional[Callable[[str], bool]]=None, ) -> Tuple[Optional[str], List[str]]: branch_tips = {} repo = self.pygit_repo @@ -958,20 +959,31 @@ def find_first_branch_matching_pattern( else: raise GitError(f"Branch {traverse_branch_name} does not exist") - branches = repo.branches if include_remotes else repo.branches.local + branches = self.branches(all=include_remotes) all_branch_names = [] + for branch_name in branches: - if pattern_func(branch_name): + stripped_name = self._remove_remote_prefix(branch_name) + if stripped_name in all_branch_names: + continue + if pattern_func(stripped_name): branch = repo.lookup_branch(branch_name) - branch_tips[branch_name] = branch.peel().hex - all_branch_names.append(branch_name) - + try: + branch_tips[stripped_name] = branch.peel().hex + except Exception: + ref = repo.references[f"refs/remotes/{branch_name}"] + commit = ref.peel(pygit2.Commit) + branch_tips[stripped_name] = commit.hex + all_branch_names.append(stripped_name) + + if sort_key_func is not None: + all_branch_names = sorted(all_branch_names, key=sort_key_func, reverse=True) # Iterate over commits from newest to oldest if len(branch_tips): - for commit in repo.walk(branch_target, pygit2.GIT_SORT_TIME): - # Check if commit is in any of the pre-filtered branches - for branch_name, tip_hex in branch_tips.items(): - if commit.hex == tip_hex or repo.descendant_of(tip_hex, commit.hex): + for commit in repo.walk(branch_target): + for branch_name in all_branch_names: + tip_hex = branch_tips.get(branch_name) + if tip_hex is not None and (commit.hex == tip_hex or repo.descendant_of(tip_hex, commit.hex)): return branch_name, [] return None, all_branch_names @@ -1374,6 +1386,15 @@ def _determine_default_branch(self) -> Optional[str]: ) return None + def _remove_remote_prefix(self, branch_name): + for remote in self.remotes: + prefix = f"{remote}/" + if branch_name.startswith(prefix): + return branch_name[len(prefix):] + return branch_name + return branch_name + + def _validate_repo_name(self, name: str) -> str: """Ensure the repo name is not malicious""" match = _repo_name_re.match(name)