[scm] Fix type errors

Bug: b/351071334
Change-Id: I43855632ac36a06569047d688fe710b83df4e707
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/5751005
Commit-Queue: Allen Li <ayatane@chromium.org>
Reviewed-by: Yiwei Zhang <yiwzhang@google.com>
changes/05/5751005/3
Allen Li 10 months ago committed by LUCI CQ
parent 0c86eac988
commit a6cd7589e9

102
scm.py

@ -11,10 +11,11 @@ import pathlib
import platform
import re
import threading
import typing
from collections import defaultdict
from typing import Iterable, Literal, Dict, List, Optional, Sequence
from typing import Tuple, Mapping
from typing import Collection, Iterable, Literal, Dict
from typing import Optional, Sequence, Mapping
import gclient_utils
import git_common
@ -48,7 +49,7 @@ def determine_scm(root):
GitConfigScope = Literal['system', 'local', 'worktree']
GitScopeOrder: List[GitConfigScope] = ['system', 'local', 'worktree']
GitScopeOrder: list[GitConfigScope] = ['system', 'local', 'worktree']
GitFlatConfigData = Mapping[str, Sequence[str]]
@ -185,11 +186,12 @@ class CachedGitConfigState(object):
"""
return self.GetConfig(key) == 'true'
def GetConfigList(self, key: str) -> List[str]:
def GetConfigList(self, key: str) -> list[str]:
"""Returns all values of `key` as a list of strings."""
return self._maybe_load_config().get(key, [])
return list(self._maybe_load_config().get(key, []))
def YieldConfigRegexp(self, pattern: Optional[str]) -> Iterable[Tuple[str, str]]:
def YieldConfigRegexp(self,
pattern: Optional[str]) -> Iterable[tuple[str, str]]:
"""Yields (key, value) pairs for any config keys matching `pattern`.
This use re.match, so `pattern` needs to be for the entire config key.
@ -278,7 +280,8 @@ class GitConfigStateReal(GitConfigStateBase):
except subprocess2.CalledProcessError:
return {}
cfg: Dict[str, List[str]] = defaultdict(list)
assert isinstance(rawConfig, str)
cfg: Dict[str, list[str]] = defaultdict(list)
# Splitting by '\x00' gets an additional empty string at the end.
for line in rawConfig.split('\x00')[:-1]:
@ -337,14 +340,14 @@ class GitConfigStateTest(GitConfigStateBase):
def __init__(self,
initial_state: Optional[Dict[GitConfigScope,
GitFlatConfigData]] = None):
self.state: Dict[GitConfigScope, Dict[str, List[str]]] = {}
self.state: Dict[GitConfigScope, Dict[str, list[str]]] = {}
if initial_state is not None:
# We want to copy initial_state to make it mutable inside our class.
for scope, data in initial_state.items():
self.state[scope] = {k: list(v) for k, v in data.items()}
super().__init__()
def _get_scope(self, scope: GitConfigScope) -> Dict[str, List[str]]:
def _get_scope(self, scope: GitConfigScope) -> Dict[str, list[str]]:
ret = self.state.get(scope, None)
if ret is None:
ret = {}
@ -474,7 +477,8 @@ class GIT(object):
with cls._CONFIG_CACHE_LOCK:
state = {}
for key, val in cls._CONFIG_CACHE.items():
state[str(key)] = val._maybe_load_config()
if val is not None:
state[str(key)] = val._maybe_load_config()
return state
@staticmethod
@ -502,11 +506,11 @@ class GIT(object):
return git_common.run(*args, **kwargs)
@staticmethod
def CaptureStatus(cwd,
upstream_branch,
end_commit=None,
ignore_submodules=True):
# type: (str, str, Optional[str]) -> Sequence[Tuple[str, str]]
def CaptureStatus(
cwd: str,
upstream_branch: str,
end_commit: Optional[str] = None,
ignore_submodules: bool = True) -> Sequence[tuple[str, str]]:
"""Returns git status.
Returns an array of (status, file) tuples."""
@ -526,6 +530,7 @@ class GIT(object):
command.extend(['-r', '%s...%s' % (upstream_branch, end_commit)])
status = GIT.Capture(command, cwd)
assert isinstance(status, str)
results = []
if status:
for statusline in status.splitlines():
@ -561,12 +566,14 @@ class GIT(object):
return GIT._get_config_state(cwd).GetConfigBool(key)
@staticmethod
def GetConfigList(cwd: str, key: str) -> List[str]:
def GetConfigList(cwd: str, key: str) -> list[str]:
"""Returns all values of `key` as a list of strings."""
return GIT._get_config_state(cwd).GetConfigList(key)
@staticmethod
def YieldConfigRegexp(cwd: str, pattern: Optional[str] = None) -> Iterable[Tuple[str, str]]:
def YieldConfigRegexp(
cwd: str,
pattern: Optional[str] = None) -> Iterable[tuple[str, str]]:
"""Yields (key, value) pairs for any config keys matching `pattern`.
This use re.match, so `pattern` needs to be for the entire config key.
@ -646,10 +653,11 @@ class GIT(object):
"""Returns the full default remote branch reference, e.g.
'refs/remotes/origin/main'."""
if os.path.exists(cwd):
ref = 'refs/remotes/%s/HEAD' % remote
try:
# Try using local git copy first
ref = 'refs/remotes/%s/HEAD' % remote
ref = GIT.Capture(['symbolic-ref', ref], cwd=cwd)
assert isinstance(ref, str)
if not ref.endswith('master'):
return ref
except subprocess2.CalledProcessError:
@ -666,11 +674,14 @@ class GIT(object):
try:
# Fetch information from git server
resp = GIT.Capture(['ls-remote', '--symref', url, 'HEAD'])
assert isinstance(resp, str)
regex = r'^ref: (.*)\tHEAD$'
for line in resp.split('\n'):
m = re.match(regex, line)
if m:
return ''.join(GIT.RefToRemoteRef(m.group(1), remote))
refpair = GIT.RefToRemoteRef(m.group(1), remote)
assert isinstance(refpair, tuple)
return ''.join(refpair)
except subprocess2.CalledProcessError:
pass
# Return default branch
@ -689,7 +700,10 @@ class GIT(object):
return GIT.Capture(['branch', '-r'], cwd=cwd).split()
@staticmethod
def FetchUpstreamTuple(cwd, branch=None):
def FetchUpstreamTuple(
cwd: str,
branch: Optional[str] = None
) -> tuple[Optional[str], Optional[str]]:
"""Returns a tuple containing remote and remote ref,
e.g. 'origin', 'refs/heads/main'
"""
@ -721,7 +735,7 @@ class GIT(object):
return None, None
@staticmethod
def RefToRemoteRef(ref, remote):
def RefToRemoteRef(ref, remote) -> Optional[tuple[str, str]]:
"""Convert a checkout ref to the equivalent remote ref.
Returns:
@ -757,7 +771,7 @@ class GIT(object):
return None
@staticmethod
def GetUpstreamBranch(cwd):
def GetUpstreamBranch(cwd) -> Optional[str]:
"""Gets the current branch's upstream branch."""
remote, upstream_branch = GIT.FetchUpstreamTuple(cwd)
if remote != '.' and upstream_branch:
@ -767,8 +781,9 @@ class GIT(object):
return upstream_branch
@staticmethod
def IsAncestor(maybe_ancestor, ref, cwd=None):
# type: (string, string, Optional[string]) -> bool
def IsAncestor(maybe_ancestor: str,
ref: str,
cwd: Optional[str] = None) -> bool:
"""Verifies if |maybe_ancestor| is an ancestor of |ref|."""
try:
GIT.Capture(['merge-base', '--is-ancestor', maybe_ancestor, ref],
@ -791,20 +806,27 @@ class GIT(object):
return ''
@staticmethod
def GenerateDiff(cwd,
branch=None,
branch_head='HEAD',
full_move=False,
files=None):
def GenerateDiff(cwd: str,
branch: Optional[str] = None,
branch_head: str = 'HEAD',
full_move: bool = False,
files: Optional[Iterable[str]] = None) -> str:
"""Diffs against the upstream branch or optionally another branch.
full_move means that move or copy operations should completely recreate the
files, usually in the prospect to apply the patch for a try job."""
if not branch:
branch = GIT.GetUpstreamBranch(cwd)
assert isinstance(branch, str)
command = [
'-c', 'core.quotePath=false', 'diff', '-p', '--no-color',
'--no-prefix', '--no-ext-diff', branch + "..." + branch_head
'-c',
'core.quotePath=false',
'diff',
'-p',
'--no-color',
'--no-prefix',
'--no-ext-diff',
branch + "..." + branch_head,
]
if full_move:
command.append('--no-renames')
@ -814,7 +836,9 @@ class GIT(object):
if files:
command.append('--')
command.extend(files)
diff = GIT.Capture(command, cwd=cwd, strip_out=False).splitlines(True)
output = GIT.Capture(command, cwd=cwd, strip_out=False)
assert isinstance(output, str)
diff = output.splitlines(True)
for i in range(len(diff)):
# In the case of added files, replace /dev/null with the path to the
# file being added.
@ -829,8 +853,8 @@ class GIT(object):
return GIT.Capture(command, cwd=cwd).splitlines(False)
@staticmethod
def GetSubmoduleCommits(cwd, submodules):
# type: (string, List[string]) => Mapping[string][string]
def GetSubmoduleCommits(cwd: str,
submodules: list[str]) -> Mapping[str, str]:
"""Returns a mapping of staged or committed new commits for submodules."""
if not submodules:
return {}
@ -849,7 +873,7 @@ class GIT(object):
def GetCheckoutRoot(cwd) -> str:
"""Returns the top level directory of a git checkout as an absolute path.
"""
root: str = GIT.Capture(['rev-parse', '--show-cdup'], cwd=cwd)
root = GIT.Capture(['rev-parse', '--show-cdup'], cwd=cwd)
assert isinstance(root, str)
return os.path.abspath(os.path.join(cwd, root))
@ -861,10 +885,10 @@ class GIT(object):
return False
@staticmethod
def IsVersioned(cwd, relative_dir):
# type: (str, str) -> int
def IsVersioned(cwd: str, relative_dir: str) -> int:
"""Checks whether the given |relative_dir| is part of cwd's repo."""
output = GIT.Capture(['ls-tree', 'HEAD', '--', relative_dir], cwd=cwd)
assert isinstance(output, str)
if not output:
return VERSIONED_NO
if output.startswith('160000'):
@ -872,8 +896,7 @@ class GIT(object):
return VERSIONED_DIR
@staticmethod
def ListSubmodules(repo_root):
# type: (str) -> Collection[str]
def ListSubmodules(repo_root: str) -> Collection[str]:
"""Returns the list of submodule paths for the given repo.
Path separators will be adjusted for the current OS.
@ -883,6 +906,7 @@ class GIT(object):
config_output = GIT.Capture(
['config', '--file', '.gitmodules', '--get-regexp', 'path'],
cwd=repo_root)
assert isinstance(config_output, str)
return [
line.split()[-1].replace('/', os.path.sep)
for line in config_output.splitlines()

Loading…
Cancel
Save