Modify git squash-branch to perform reparenting

Bug: 40264739
Change-Id: I4ad7f4f8a670334b32c239458048e56c6af44098
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/6227541
Reviewed-by: Josip Sokcevic <sokcevic@chromium.org>
Reviewed-by: Yiwei Zhang <yiwzhang@google.com>
Commit-Queue: Alexander Cooper <alcooper@chromium.org>
changes/41/6227541/6
Alexander Cooper committed by LUCI CQ
parent 569d698b0b
commit 380df04b62

@ -625,6 +625,60 @@ def get_branch_tree(use_limit=False):
return skipped, branch_tree
def get_diverged_branches(branch_tree=None):
    """Gets the branches from the tree that have diverged from their upstream.

    A branch is reported as diverged when the hash of its upstream differs
    from the hash of its merge base (see `get_or_create_merge_base`).

    Args:
        branch_tree: A {branch: upstream_branch} mapping as generated by
            `get_branch_tree`. When None, `get_branch_tree` is called to
            obtain one. An explicitly supplied empty mapping is used as-is
            rather than triggering another query.

    Returns:
        The list of branches that have diverged from their respective
        upstream branch, in the iteration order of `branch_tree`.
    """
    # Compare against None (not truthiness) so that a caller-provided empty
    # tree is honored instead of being silently replaced.
    if branch_tree is None:
        _, branch_tree = get_branch_tree()

    diverged_branches = []
    for branch, upstream_branch in branch_tree.items():
        # If the merge base of a branch and its upstream is not equal to the
        # upstream, then the two branches have diverged.
        upstream_branch_hash = hash_one(upstream_branch)
        merge_base_hash = hash_one(get_or_create_merge_base(branch))
        if upstream_branch_hash != merge_base_hash:
            diverged_branches.append(branch)
    return diverged_branches
def get_hashes(branch_tree=None):
    """Gets the dictionary of {branch: hash}.

    Suitable for saving hashes before performing destructive operations, so
    that the appropriate rebases can be performed afterwards.

    Args:
        branch_tree: A {branch: upstream_branch} mapping as generated by
            `get_branch_tree`. When None, `get_branch_tree` is called to
            obtain one. An explicitly supplied empty mapping is used as-is
            rather than triggering another query.

    Returns:
        A dictionary holding the hash of every branch and every upstream
        branch that appears in `branch_tree`.
    """
    # Compare against None (not truthiness) so that a caller-provided empty
    # tree is honored instead of being silently replaced.
    if branch_tree is None:
        _, branch_tree = get_branch_tree()

    hashes = {}
    for branch, upstream_branch in branch_tree.items():
        for ref in (branch, upstream_branch):
            # A ref can show up both as a branch and as the upstream of one
            # or more other branches; resolve each ref only once.
            if ref not in hashes:
                hashes[ref] = hash_one(ref)
    return hashes
def get_downstream_branches(branch_tree=None):
    """Gets the dictionary of {branch: children}.

    Args:
        branch_tree: A {branch: upstream_branch} mapping as generated by
            `get_branch_tree`. When None, `get_branch_tree` is called to
            obtain one. An explicitly supplied empty mapping is used as-is
            rather than triggering another query.

    Returns:
        A `collections.defaultdict(list)` mapping every upstream branch to
        the list of branches that are directly downstream of it. Branches
        with no downstream branches have no entry until they are looked up.
    """
    # Compare against None (not truthiness) so that a caller-provided empty
    # tree is honored instead of being silently replaced.
    if branch_tree is None:
        _, branch_tree = get_branch_tree()

    downstream_branches = collections.defaultdict(list)
    for branch, upstream_branch in branch_tree.items():
        downstream_branches[upstream_branch].append(branch)
    return downstream_branches
def get_or_create_merge_base(branch, parent=None) -> Optional[str]:
"""Finds the configured merge base for branch.

@ -10,6 +10,48 @@ import gclient_utils
import git_common
def rebase_branch(branch, initial_hashes):
    """Re-parents `branch` on top of the current position of its upstream.

    Args:
        branch: Name of the branch to rebase.
        initial_hashes: Mapping of branch name to the hash it had before any
            history was rewritten. Must contain both `branch` and its
            upstream.
    """
    print(f'Re-parenting branch {branch}.')
    assert initial_hashes[branch] == git_common.hash_one(branch)

    new_base = git_common.upstream(branch)
    old_base_hash = initial_hashes[new_base]

    # The upstream may have moved (e.g. because it was squashed), so replay
    # only the commits that came after the upstream's old position on top of
    # wherever the upstream points now.
    rebase_args = ('rebase', '--onto', new_base, old_base_hash, branch,
                   '--update-refs')
    git_common.run(*rebase_args)
def rebase_subtree(branch, initial_hashes, downstream_branches):
    """Rebases every branch in the subtree rooted at `branch`.

    Each branch is rebased onto its upstream before any of its own
    downstream branches are visited (pre-order, depth first).

    Args:
        branch: Root of the subtree to rebase.
        initial_hashes: Mapping of branch name to the hash it had before any
            history was rewritten.
        downstream_branches: Mapping of branch name to the list of branches
            directly downstream of it.
    """
    pending = [branch]
    while pending:
        current = pending.pop()
        # Rebase this branch onto its parent.
        rebase_branch(current, initial_hashes)
        # Queue the children; reversed so that they are popped, and therefore
        # processed, in their listed order.
        pending.extend(reversed(downstream_branches[current]))
def children_have_diverged(branch, downstream_branches, diverged_branches):
    """Reports whether `branch` or any branch below it has diverged.

    Args:
        branch: Branch at the root of the subtree to inspect.
        downstream_branches: Mapping of branch name to the list of branches
            directly downstream of it.
        diverged_branches: Collection of branches known to have diverged
            from their upstream.

    Returns:
        True if `branch` itself, or any branch reachable from it through
        `downstream_branches`, is in `diverged_branches`; False otherwise.
    """
    # Nothing has diverged at all, so there is no need to walk the subtree.
    if not diverged_branches:
        return False

    # A diverged branch implies that everything stacked on it is affected.
    if branch in diverged_branches:
        return True

    # Otherwise the answer depends entirely on the subtrees below us.
    return any(
        children_have_diverged(child, downstream_branches, diverged_branches)
        for child in downstream_branches[branch])
def main(args):
if gclient_utils.IsEnvCog():
print('squash-branch command is not supported in non-git environment.',
@ -25,7 +67,34 @@ def main(args):
opts = parser.parse_args(args)
if git_common.is_dirty_git_tree('squash-branch'):
return 1
# Save off the current branch so we can return to it at the end.
return_branch = git_common.current_branch()
# Save the hashes before we mutate the tree so that we have all of the
# necessary rebasing information.
_, tree = git_common.get_branch_tree()
initial_hashes = git_common.get_hashes(tree)
downstream_branches = git_common.get_downstream_branches(tree)
diverged_branches = git_common.get_diverged_branches(tree)
# We won't be rebasing our squashed branch, so only check any potential
# children
for branch in downstream_branches[return_branch]:
if children_have_diverged(branch, downstream_branches,
diverged_branches):
print('Cannot use `git squash-branch` since some children have '
'diverged from their upstream and could cause conflicts.')
return 1
git_common.squash_current_branch(opts.message)
# Fixup our children with our new state.
for branch in downstream_branches[return_branch]:
rebase_subtree(branch, initial_hashes, downstream_branches)
git_common.run('checkout', return_branch)
return 0

@ -13,38 +13,6 @@ import git_common as git
import sys
def get_diverged_branches(tree):
    """Returns the branches that have diverged from their upstream branch.

    A branch counts as diverged when the hash of its upstream is not equal
    to the hash of its merge base.

    Args:
        tree: A {branch: upstream_branch} mapping.
    """
    return [
        branch for branch, upstream_branch in tree.items()
        if git.hash_one(upstream_branch) != git.hash_one(
            git.get_or_create_merge_base(branch))
    ]
def get_initial_hashes(tree):
    """Returns the hash of every branch before the squashing started.

    Args:
        tree: A {branch: upstream_branch} mapping.

    Returns:
        A dictionary mapping every branch and upstream branch found in
        `tree` to its current hash.
    """
    snapshot = {}
    for child, parent in tree.items():
        for ref in (child, parent):
            snapshot[ref] = git.hash_one(ref)
    return snapshot
def get_downstream_branches(tree):
    """Returns the downstream branches of every branch.

    Args:
        tree: A {branch: upstream_branch} mapping.

    Returns:
        A `collections.defaultdict(list)` mapping each upstream branch to
        the list of branches directly downstream of it.
    """
    children_by_parent = collections.defaultdict(list)
    for child, parent in tree.items():
        children_by_parent[parent].append(child)
    return children_by_parent
# Squash a branch, taking care to rebase the branch on top of the new commit
# position of its upstream branch.
def squash_branch(branch, initial_hashes):
@ -102,7 +70,7 @@ def main(args=None):
print('Use --ignore-no-upstream to ignore this check and proceed.')
return 1
diverged_branches = get_diverged_branches(tree)
diverged_branches = git.get_diverged_branches(tree)
if diverged_branches:
print('Cannot use `git squash-branch-tree` since the following\n'
'branches have diverged from their upstream and could cause\n'
@ -115,8 +83,8 @@ def main(args=None):
# we can go back to it at the end.
return_branch = git.current_branch()
initial_hashes = get_initial_hashes(tree)
downstream_branches = get_downstream_branches(tree)
initial_hashes = git.get_hashes(tree)
downstream_branches = git.get_downstream_branches(tree)
squash_subtree(opts.branch, initial_hashes, downstream_branches)
git.run('checkout', return_branch)

@ -751,6 +751,34 @@ class GitMutableStructuredTest(git_test_utils.GitRepoReadWriteTestBase,
('root_A', 'root_X'),
])
def testGetHashes(self):
    """Checks that get_hashes reports the current hash of every branch."""
    hashes = self.repo.run(self.gc.get_hashes)
    # Without this, an empty result would make the loop below pass without
    # checking anything.
    self.assertTrue(hashes)
    for branch, branch_hash in hashes.items():
        self.assertEqual(self.repo.run(self.gc.hash_one, branch),
                         branch_hash)
def testGetDownstreamBranches(self):
    """Checks the {branch: children} mapping built for the test repo."""
    expected = {
        'root_A': ['branch_G'],
        'branch_G': ['branch_K'],
        'branch_K': ['branch_L'],
        'root_X': ['branch_Z', 'root_A'],
    }
    actual = self.repo.run(self.gc.get_downstream_branches)
    self.assertEqual(actual, expected)
def testGetDivergedBranches(self):
    """Checks that branches not based on their upstream's tip are found."""
    # root_X and root_A don't actually have a common base commit due to the
    # test repo's structure, which causes get_diverged_branches to throw
    # an error, so detach root_A from its upstream first.
    self.repo.git('branch', '--unset-upstream', 'root_A')

    # K is set up with G as its root, but it's branched at B.
    # L is set up with K as its root, but it's branched at J.
    expected = ['branch_K', 'branch_L']
    actual = self.repo.run(self.gc.get_diverged_branches)
    self.assertEqual(actual, expected)
def testIsGitTreeDirty(self):
retval = []
self.repo.capture_stdio(lambda: retval.append(

@ -0,0 +1,146 @@
#!/usr/bin/env vpython3
# coding=utf-8
# Copyright 2024 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Tests for git_squash_branch."""
import os
import sys
import unittest
DEPOT_TOOLS_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, DEPOT_TOOLS_ROOT)
from testing_support import git_test_utils
import git_squash_branch
import git_common
git_common.TEST_MODE = True
class GitSquashBranchTest(git_test_utils.GitRepoReadWriteTestBase):
    """Runs `git_squash_branch.main` against a small stack of branches."""

    # Empty repo.
    REPO_SCHEMA = """
"""

    def setUp(self):
        super().setUp()
        # Note: The branch stack is built by hand because expressing it
        # through REPO_SCHEMA wouldn't simplify this test.
        #
        # Create a repo with the following schema
        #
        # main <- branchA <- branchB <- branchC
        #            ^
        #            \ branchD
        #
        # where each branch has 2 commits.

        # The repo is empty. Add the first commit or else most commands don't
        # work, including `git branch`, which doesn't even show the main
        # branch.
        self.repo.git('commit', '-m', 'First commit', '--allow-empty')

        # (branch, upstream it tracks, files committed on it), in creation
        # order.
        stack = (
            ('branchA', 'main', ('fileA1', 'fileA2')),
            ('branchB', 'branchA', ('fileB1', 'fileB2')),
            ('branchC', 'branchB', ('fileC1', 'fileC2')),
            ('branchD', 'branchA', ('fileD1', 'fileD2')),
        )
        for branch, upstream, filenames in stack:
            self.repo.git('checkout', '-B', branch, '--track', upstream)
            for filename in filenames:
                self._createFileAndCommit(filename)

    def testGitSquashBranchFailsWithDivergedBranch(self):
        self._assertCountsAhead(2, 2, 2, 2)
        self.repo.git('checkout', 'branchB')
        self._createFileAndCommit('fileB3')
        self.repo.git('checkout', 'branchA')

        # We have now made a state where branchC has diverged from branchB.
        output, _ = self.repo.capture_stdio(git_squash_branch.main, [])
        self.assertIn('some children have diverged', output)

    def testGitSquashBranchRootOnly(self):
        self._assertCountsAhead(2, 2, 2, 2)

        self.repo.git('checkout', 'branchA')
        self.repo.run(git_squash_branch.main, [])

        self._assertCountsAhead(1, 2, 2, 2)

    def testGitSquashBranchLeaf(self):
        self._assertCountsAhead(2, 2, 2, 2)

        self.repo.git('checkout', 'branchD')
        self.repo.run(git_squash_branch.main, [])

        self._assertCountsAhead(2, 2, 2, 1)

    def testGitSquashBranchSequential(self):
        self._assertCountsAhead(2, 2, 2, 2)

        self.repo.git('checkout', 'branchA')
        self.repo.run(git_squash_branch.main, [])

        self._assertCountsAhead(1, 2, 2, 2)

        self.repo.git('checkout', 'branchB')
        self.repo.run(git_squash_branch.main, [])

        self._assertCountsAhead(1, 1, 2, 2)

    # Asserts how many commits branchA, branchB, branchC and branchD are
    # ahead of their respective upstreams, checked in that order.
    def _assertCountsAhead(self, count_a, count_b, count_c, count_d):
        expectations = (
            ('branchA', count_a),
            ('branchB', count_b),
            ('branchC', count_c),
            ('branchD', count_d),
        )
        for branch, expected_count in expectations:
            self.assertEqual(self._getCountAheadOfUpstream(branch),
                             expected_count)

    # Creates a file with arbitrary contents and commits it to the current
    # branch.
    def _createFileAndCommit(self, filename):
        with self.repo.open(filename, 'w') as new_file:
            new_file.write('content')
        self.repo.git('add', filename)
        self.repo.git_commit('Added file ' + filename)

    # Returns the count of how many commits `branch` is ahead of its upstream.
    def _getCountAheadOfUpstream(self, branch):
        rev_range = '..'.join((branch + '@{u}', branch))
        result = self.repo.git('rev-list', '--count', rev_range)
        return int(result.stdout)
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save