From 380df04b62be9242ecb7c8fd3dc1ac09bf9889d0 Mon Sep 17 00:00:00 2001 From: Alexander Cooper Date: Wed, 5 Feb 2025 10:15:43 -0800 Subject: [PATCH] Modify git squash-branch to perform reparenting Bug: 40264739 Change-Id: I4ad7f4f8a670334b32c239458048e56c6af44098 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/6227541 Reviewed-by: Josip Sokcevic Reviewed-by: Yiwei Zhang Commit-Queue: Alexander Cooper --- git_common.py | 54 ++++++++++++ git_squash_branch.py | 69 +++++++++++++++ git_squash_branch_tree.py | 38 +-------- tests/git_common_test.py | 28 ++++++ tests/git_squash_branch_test.py | 146 ++++++++++++++++++++++++++++++++ 5 files changed, 300 insertions(+), 35 deletions(-) create mode 100644 tests/git_squash_branch_test.py diff --git a/git_common.py b/git_common.py index 990a22044..257dca54f 100644 --- a/git_common.py +++ b/git_common.py @@ -625,6 +625,60 @@ def get_branch_tree(use_limit=False): return skipped, branch_tree +def get_diverged_branches(branch_tree=None): + """Gets the branches from the tree that have diverged from their upstream + + Returns the list of branches that have diverged from their respective + upstream branch. + Expects to receive a tree as generated from `get_branch_tree`, which it will + call if not supplied, ignoring branches without upstreams. + """ + if not branch_tree: + _, branch_tree = get_branch_tree() + diverged_branches = [] + for branch, upstream_branch in branch_tree.items(): + # If the merge base of a branch and its upstream is not equal to the + # upstream, then it means that both branch diverged. + upstream_branch_hash = hash_one(upstream_branch) + merge_base_hash = hash_one(get_or_create_merge_base(branch)) + if upstream_branch_hash != merge_base_hash: + diverged_branches.append(branch) + return diverged_branches + + +def get_hashes(branch_tree=None): + """Get the dictionary of {branch: hash} + + Returns a dictionary that contains the hash of every branch. Suitable for + saving hashes before performing destructive operations to perform + appropriate rebases. + Expects to receive a tree as generated from `get_branch_tree`, which it will + call if not supplied, ignoring branches without upstreams. + """ + if not branch_tree: + _, branch_tree = get_branch_tree() + hashes = {} + for branch, upstream_branch in branch_tree.items(): + hashes[branch] = hash_one(branch) + hashes[upstream_branch] = hash_one(upstream_branch) + return hashes + + +def get_downstream_branches(branch_tree=None): + """Get the dictionary of {branch: children} + + Returns a dictionary that contains the list of downstream branches for every + branch. + Expects to receive a tree as generated from `get_branch_tree`, which it will + call if not supplied, ignoring branches without upstreams. + """ + if not branch_tree: + _, branch_tree = get_branch_tree() + downstream_branches = collections.defaultdict(list) + for branch, upstream_branch in branch_tree.items(): + downstream_branches[upstream_branch].append(branch) + return downstream_branches + def get_or_create_merge_base(branch, parent=None) -> Optional[str]: """Finds the configured merge base for branch. diff --git a/git_squash_branch.py b/git_squash_branch.py index 761a71c7d..005707b85 100755 --- a/git_squash_branch.py +++ b/git_squash_branch.py @@ -10,6 +10,48 @@ import gclient_utils import git_common +# Squash a branch, taking care to rebase the branch on top of the new commit +# position of its upstream branch. +def rebase_branch(branch, initial_hashes): + print('Re-parenting branch %s.' % branch) + assert initial_hashes[branch] == git_common.hash_one(branch) + + upstream_branch = git_common.upstream(branch) + old_upstream_branch = initial_hashes[upstream_branch] + + # Because the branch's upstream has potentially changed from squashing it, + # the current branch is rebased on top of the new upstream. + git_common.run('rebase', '--onto', upstream_branch, old_upstream_branch, + branch, '--update-refs') + + +# Squashes all branches that are part of the subtree starting at `branch`. +def rebase_subtree(branch, initial_hashes, downstream_branches): + # Rebase us onto our parent + rebase_branch(branch, initial_hashes) + + # Recurse on downstream branches, if any. + for downstream_branch in downstream_branches[branch]: + rebase_subtree(downstream_branch, initial_hashes, downstream_branches) + + +def children_have_diverged(branch, downstream_branches, diverged_branches): + # If we have no diverged branches, then no children have diverged. + if not diverged_branches: + return False + + # If we have diverged, then our children have diverged. + if branch in diverged_branches: + return True + + # If any of our children have diverged, then we need to return true. + for downstream_branch in downstream_branches[branch]: + if children_have_diverged(downstream_branch, downstream_branches, + diverged_branches): + return True + + return False + def main(args): if gclient_utils.IsEnvCog(): print('squash-branch command is not supported in non-git environment.', @@ -25,7 +67,34 @@ def main(args): opts = parser.parse_args(args) if git_common.is_dirty_git_tree('squash-branch'): return 1 + + # Save off the current branch so we can return to it at the end. + return_branch = git_common.current_branch() + + # Save the hashes before we mutate the tree so that we have all of the + # necessary rebasing information. + _, tree = git_common.get_branch_tree() + initial_hashes = git_common.get_hashes(tree) + downstream_branches = git_common.get_downstream_branches(tree) + diverged_branches = git_common.get_diverged_branches(tree) + + # We won't be rebasing our squashed branch, so only check any potential + # children + for branch in downstream_branches[return_branch]: + if children_have_diverged(branch, downstream_branches, + diverged_branches): + print('Cannot use `git squash-branch` since some children have ' + 'diverged from their upstream and could cause conflicts.') + return 1 + git_common.squash_current_branch(opts.message) + + # Fixup our children with our new state. + for branch in downstream_branches[return_branch]: + rebase_subtree(branch, initial_hashes, downstream_branches) + + git_common.run('checkout', return_branch) + return 0 diff --git a/git_squash_branch_tree.py b/git_squash_branch_tree.py index 724f7ff5b..3dbe1ee80 100755 --- a/git_squash_branch_tree.py +++ b/git_squash_branch_tree.py @@ -13,38 +13,6 @@ import git_common as git import sys -# Returns the list of branches that have diverged from their respective upstream -# branch. -def get_diverged_branches(tree): - diverged_branches = [] - for branch, upstream_branch in tree.items(): - # If the merge base of a branch and its upstream is not equal to the - # upstream, then it means that both branch diverged. - upstream_branch_hash = git.hash_one(upstream_branch) - merge_base_hash = git.hash_one(git.get_or_create_merge_base(branch)) - if upstream_branch_hash != merge_base_hash: - diverged_branches.append(branch) - return diverged_branches - - -# Returns a dictionary that contains the hash of every branch before the -# squashing started. -def get_initial_hashes(tree): - initial_hashes = {} - for branch, upstream_branch in tree.items(): - initial_hashes[branch] = git.hash_one(branch) - initial_hashes[upstream_branch] = git.hash_one(upstream_branch) - return initial_hashes - - -# Returns a dictionary that contains the downstream branches of every branch. -def get_downstream_branches(tree): - downstream_branches = collections.defaultdict(list) - for branch, upstream_branch in tree.items(): - downstream_branches[upstream_branch].append(branch) - return downstream_branches - - # Squash a branch, taking care to rebase the branch on top of the new commit # position of its upstream branch. def squash_branch(branch, initial_hashes): @@ -102,7 +70,7 @@ def main(args=None): print('Use --ignore-no-upstream to ignore this check and proceed.') return 1 - diverged_branches = get_diverged_branches(tree) + diverged_branches = git.get_diverged_branches(tree) if diverged_branches: print('Cannot use `git squash-branch-tree` since the following\n' 'branches have diverged from their upstream and could cause\n' @@ -115,8 +83,8 @@ def main(args=None): # we can go back to it at the end. return_branch = git.current_branch() - initial_hashes = get_initial_hashes(tree) - downstream_branches = get_downstream_branches(tree) + initial_hashes = git.get_hashes(tree) + downstream_branches = git.get_downstream_branches(tree) squash_subtree(opts.branch, initial_hashes, downstream_branches) git.run('checkout', return_branch) diff --git a/tests/git_common_test.py b/tests/git_common_test.py index d7b95fa33..37cbcf51b 100755 --- a/tests/git_common_test.py +++ b/tests/git_common_test.py @@ -751,6 +751,34 @@ class GitMutableStructuredTest(git_test_utils.GitRepoReadWriteTestBase, ('root_A', 'root_X'), ]) + def testGetHashes(self): + hashes = self.repo.run(self.gc.get_hashes) + for branch, branch_hash in hashes.items(): + self.assertEqual(self.repo.run(self.gc.hash_one, branch), + branch_hash) + + def testGetDownstreamBranches(self): + downstream_branches = self.repo.run(self.gc.get_downstream_branches) + self.assertEqual( + downstream_branches, { + 'root_A': ['branch_G'], + 'branch_G': ['branch_K'], + 'branch_K': ['branch_L'], + 'root_X': ['branch_Z', 'root_A'], + }) + + def testGetDivergedBranches(self): + # root_X and root_A don't actually have a common base commit due to the + # test repo's structure, which causes get_diverged_branches to throw + # an error. + self.repo.git('branch', '--unset-upstream', 'root_A') + + # K is setup with G as it's root, but it's branched at B. + # L is setup with K as it's root, but it's branched at J. + diverged_branches = self.repo.run(self.gc.get_diverged_branches) + self.assertEqual(diverged_branches, ['branch_K', 'branch_L']) + + def testIsGitTreeDirty(self): retval = [] self.repo.capture_stdio(lambda: retval.append( diff --git a/tests/git_squash_branch_test.py b/tests/git_squash_branch_test.py new file mode 100644 index 000000000..d8bf14bb3 --- /dev/null +++ b/tests/git_squash_branch_test.py @@ -0,0 +1,146 @@ +#!/usr/bin/env vpython3 +# coding=utf-8 +# Copyright 2024 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. +"""Tests for git_squash_branch.""" + +import os +import sys +import unittest + +DEPOT_TOOLS_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, DEPOT_TOOLS_ROOT) + +from testing_support import git_test_utils + +import git_squash_branch +import git_common + +git_common.TEST_MODE = True + + +class GitSquashBranchTest(git_test_utils.GitRepoReadWriteTestBase): + # Empty repo. + REPO_SCHEMA = """ + """ + + def setUp(self): + super(GitSquashBranchTest, self).setUp() + + # Note: Using the REPO_SCHEMA wouldn't simplify this test so it is not + # used. + # + # Create a repo with the follow schema + # + # main <- branchA <- branchB <- branchC + # ^ + # \ branchD + # + # where each branch has 2 commits. + + # The repo is empty. Add the first commit or else most commands don't + # work, including `git branch`, which doesn't even show the main branch. + self.repo.git('commit', '-m', 'First commit', '--allow-empty') + + # Create the first branch downstream from `main` with 2 commits. + self.repo.git('checkout', '-B', 'branchA', '--track', 'main') + self._createFileAndCommit('fileA1') + self._createFileAndCommit('fileA2') + + # Create a branch downstream from `branchA` with 2 commits. + self.repo.git('checkout', '-B', 'branchB', '--track', 'branchA') + self._createFileAndCommit('fileB1') + self._createFileAndCommit('fileB2') + + # Create another branch downstream from `branchB` with 2 commits. + self.repo.git('checkout', '-B', 'branchC', '--track', 'branchB') + self._createFileAndCommit('fileC1') + self._createFileAndCommit('fileC2') + + # Create another branch downstream from `branchA` with 2 commits. + self.repo.git('checkout', '-B', 'branchD', '--track', 'branchA') + self._createFileAndCommit('fileD1') + self._createFileAndCommit('fileD2') + + def testGitSquashBranchFailsWithDivergedBranch(self): + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + self.repo.git('checkout', 'branchB') + self._createFileAndCommit('fileB3') + self.repo.git('checkout', 'branchA') + + # We have now made a state where branchC has diverged from branchB. + output, _ = self.repo.capture_stdio(git_squash_branch.main, []) + self.assertIn('some children have diverged', output) + + def testGitSquashBranchRootOnly(self): + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + + self.repo.git('checkout', 'branchA') + self.repo.run(git_squash_branch.main, []) + + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 1) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + + def testGitSquashBranchLeaf(self): + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + + self.repo.git('checkout', 'branchD') + self.repo.run(git_squash_branch.main, []) + + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 1) + + def testGitSquashBranchSequential(self): + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + + self.repo.git('checkout', 'branchA') + self.repo.run(git_squash_branch.main, []) + + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 1) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + + self.repo.git('checkout', 'branchB') + self.repo.run(git_squash_branch.main, []) + + self.assertEqual(self._getCountAheadOfUpstream('branchA'), 1) + self.assertEqual(self._getCountAheadOfUpstream('branchB'), 1) + self.assertEqual(self._getCountAheadOfUpstream('branchC'), 2) + self.assertEqual(self._getCountAheadOfUpstream('branchD'), 2) + + # Creates a file with arbitrary contents and commit it to the current + # branch. + def _createFileAndCommit(self, filename): + with self.repo.open(filename, 'w') as f: + f.write('content') + self.repo.git('add', filename) + self.repo.git_commit('Added file ' + filename) + + # Returns the count of how many commits `branch` is ahead of its upstream. + def _getCountAheadOfUpstream(self, branch): + upstream = branch + '@{u}' + output = self.repo.git('rev-list', '--count', + upstream + '..' + branch).stdout + return int(output) + + +if __name__ == '__main__': + unittest.main()