diff --git a/scm.py b/scm.py index 425ff20b9..c5feeebaa 100644 --- a/scm.py +++ b/scm.py @@ -11,10 +11,11 @@ import pathlib import platform import re import threading +import typing from collections import defaultdict -from typing import Iterable, Literal, Dict, List, Optional, Sequence -from typing import Tuple, Mapping +from typing import Collection, Iterable, Literal, Dict +from typing import Optional, Sequence, Mapping import gclient_utils import git_common @@ -48,7 +49,7 @@ def determine_scm(root): GitConfigScope = Literal['system', 'local', 'worktree'] -GitScopeOrder: List[GitConfigScope] = ['system', 'local', 'worktree'] +GitScopeOrder: list[GitConfigScope] = ['system', 'local', 'worktree'] GitFlatConfigData = Mapping[str, Sequence[str]] @@ -185,11 +186,12 @@ class CachedGitConfigState(object): """ return self.GetConfig(key) == 'true' - def GetConfigList(self, key: str) -> List[str]: + def GetConfigList(self, key: str) -> list[str]: """Returns all values of `key` as a list of strings.""" - return self._maybe_load_config().get(key, []) + return list(self._maybe_load_config().get(key, [])) - def YieldConfigRegexp(self, pattern: Optional[str]) -> Iterable[Tuple[str, str]]: + def YieldConfigRegexp(self, + pattern: Optional[str]) -> Iterable[tuple[str, str]]: """Yields (key, value) pairs for any config keys matching `pattern`. This use re.match, so `pattern` needs to be for the entire config key. @@ -278,7 +280,8 @@ class GitConfigStateReal(GitConfigStateBase): except subprocess2.CalledProcessError: return {} - cfg: Dict[str, List[str]] = defaultdict(list) + assert isinstance(rawConfig, str) + cfg: Dict[str, list[str]] = defaultdict(list) # Splitting by '\x00' gets an additional empty string at the end. for line in rawConfig.split('\x00')[:-1]: @@ -337,14 +340,14 @@ class GitConfigStateTest(GitConfigStateBase): def __init__(self, initial_state: Optional[Dict[GitConfigScope, GitFlatConfigData]] = None): - self.state: Dict[GitConfigScope, Dict[str, List[str]]] = {} + self.state: Dict[GitConfigScope, Dict[str, list[str]]] = {} if initial_state is not None: # We want to copy initial_state to make it mutable inside our class. for scope, data in initial_state.items(): self.state[scope] = {k: list(v) for k, v in data.items()} super().__init__() - def _get_scope(self, scope: GitConfigScope) -> Dict[str, List[str]]: + def _get_scope(self, scope: GitConfigScope) -> Dict[str, list[str]]: ret = self.state.get(scope, None) if ret is None: ret = {} @@ -474,7 +477,8 @@ class GIT(object): with cls._CONFIG_CACHE_LOCK: state = {} for key, val in cls._CONFIG_CACHE.items(): - state[str(key)] = val._maybe_load_config() + if val is not None: + state[str(key)] = val._maybe_load_config() return state @staticmethod @@ -502,11 +506,11 @@ class GIT(object): return git_common.run(*args, **kwargs) @staticmethod - def CaptureStatus(cwd, - upstream_branch, - end_commit=None, - ignore_submodules=True): - # type: (str, str, Optional[str]) -> Sequence[Tuple[str, str]] + def CaptureStatus( + cwd: str, + upstream_branch: str, + end_commit: Optional[str] = None, + ignore_submodules: bool = True) -> Sequence[tuple[str, str]]: """Returns git status. Returns an array of (status, file) tuples.""" @@ -526,6 +530,7 @@ class GIT(object): command.extend(['-r', '%s...%s' % (upstream_branch, end_commit)]) status = GIT.Capture(command, cwd) + assert isinstance(status, str) results = [] if status: for statusline in status.splitlines(): @@ -561,12 +566,14 @@ class GIT(object): return GIT._get_config_state(cwd).GetConfigBool(key) @staticmethod - def GetConfigList(cwd: str, key: str) -> List[str]: + def GetConfigList(cwd: str, key: str) -> list[str]: """Returns all values of `key` as a list of strings.""" return GIT._get_config_state(cwd).GetConfigList(key) @staticmethod - def YieldConfigRegexp(cwd: str, pattern: Optional[str] = None) -> Iterable[Tuple[str, str]]: + def YieldConfigRegexp( + cwd: str, + pattern: Optional[str] = None) -> Iterable[tuple[str, str]]: """Yields (key, value) pairs for any config keys matching `pattern`. This use re.match, so `pattern` needs to be for the entire config key. @@ -646,10 +653,11 @@ class GIT(object): """Returns the full default remote branch reference, e.g. 'refs/remotes/origin/main'.""" if os.path.exists(cwd): + ref = 'refs/remotes/%s/HEAD' % remote try: # Try using local git copy first - ref = 'refs/remotes/%s/HEAD' % remote ref = GIT.Capture(['symbolic-ref', ref], cwd=cwd) + assert isinstance(ref, str) if not ref.endswith('master'): return ref except subprocess2.CalledProcessError: @@ -666,11 +674,14 @@ class GIT(object): try: # Fetch information from git server resp = GIT.Capture(['ls-remote', '--symref', url, 'HEAD']) + assert isinstance(resp, str) regex = r'^ref: (.*)\tHEAD$' for line in resp.split('\n'): m = re.match(regex, line) if m: - return ''.join(GIT.RefToRemoteRef(m.group(1), remote)) + refpair = GIT.RefToRemoteRef(m.group(1), remote) + assert isinstance(refpair, tuple) + return ''.join(refpair) except subprocess2.CalledProcessError: pass # Return default branch @@ -689,7 +700,10 @@ class GIT(object): return GIT.Capture(['branch', '-r'], cwd=cwd).split() @staticmethod - def FetchUpstreamTuple(cwd, branch=None): + def FetchUpstreamTuple( + cwd: str, + branch: Optional[str] = None + ) -> tuple[Optional[str], Optional[str]]: """Returns a tuple containing remote and remote ref, e.g. 'origin', 'refs/heads/main' """ @@ -721,7 +735,7 @@ class GIT(object): return None, None @staticmethod - def RefToRemoteRef(ref, remote): + def RefToRemoteRef(ref, remote) -> Optional[tuple[str, str]]: """Convert a checkout ref to the equivalent remote ref. Returns: @@ -757,7 +771,7 @@ class GIT(object): return None @staticmethod - def GetUpstreamBranch(cwd): + def GetUpstreamBranch(cwd) -> Optional[str]: """Gets the current branch's upstream branch.""" remote, upstream_branch = GIT.FetchUpstreamTuple(cwd) if remote != '.' and upstream_branch: @@ -767,8 +781,9 @@ class GIT(object): return upstream_branch @staticmethod - def IsAncestor(maybe_ancestor, ref, cwd=None): - # type: (string, string, Optional[string]) -> bool + def IsAncestor(maybe_ancestor: str, + ref: str, + cwd: Optional[str] = None) -> bool: """Verifies if |maybe_ancestor| is an ancestor of |ref|.""" try: GIT.Capture(['merge-base', '--is-ancestor', maybe_ancestor, ref], @@ -791,20 +806,27 @@ class GIT(object): return '' @staticmethod - def GenerateDiff(cwd, - branch=None, - branch_head='HEAD', - full_move=False, - files=None): + def GenerateDiff(cwd: str, + branch: Optional[str] = None, + branch_head: str = 'HEAD', + full_move: bool = False, + files: Optional[Iterable[str]] = None) -> str: """Diffs against the upstream branch or optionally another branch. full_move means that move or copy operations should completely recreate the files, usually in the prospect to apply the patch for a try job.""" if not branch: branch = GIT.GetUpstreamBranch(cwd) + assert isinstance(branch, str) command = [ - '-c', 'core.quotePath=false', 'diff', '-p', '--no-color', - '--no-prefix', '--no-ext-diff', branch + "..." + branch_head + '-c', + 'core.quotePath=false', + 'diff', + '-p', + '--no-color', + '--no-prefix', + '--no-ext-diff', + branch + "..." + branch_head, ] if full_move: command.append('--no-renames') @@ -814,7 +836,9 @@ class GIT(object): if files: command.append('--') command.extend(files) - diff = GIT.Capture(command, cwd=cwd, strip_out=False).splitlines(True) + output = GIT.Capture(command, cwd=cwd, strip_out=False) + assert isinstance(output, str) + diff = output.splitlines(True) for i in range(len(diff)): # In the case of added files, replace /dev/null with the path to the # file being added. @@ -829,8 +853,8 @@ class GIT(object): return GIT.Capture(command, cwd=cwd).splitlines(False) @staticmethod - def GetSubmoduleCommits(cwd, submodules): - # type: (string, List[string]) => Mapping[string][string] + def GetSubmoduleCommits(cwd: str, + submodules: list[str]) -> Mapping[str, str]: """Returns a mapping of staged or committed new commits for submodules.""" if not submodules: return {} @@ -849,7 +873,7 @@ class GIT(object): def GetCheckoutRoot(cwd) -> str: """Returns the top level directory of a git checkout as an absolute path. """ - root: str = GIT.Capture(['rev-parse', '--show-cdup'], cwd=cwd) + root = GIT.Capture(['rev-parse', '--show-cdup'], cwd=cwd) assert isinstance(root, str) return os.path.abspath(os.path.join(cwd, root)) @@ -861,10 +885,10 @@ class GIT(object): return False @staticmethod - def IsVersioned(cwd, relative_dir): - # type: (str, str) -> int + def IsVersioned(cwd: str, relative_dir: str) -> int: """Checks whether the given |relative_dir| is part of cwd's repo.""" output = GIT.Capture(['ls-tree', 'HEAD', '--', relative_dir], cwd=cwd) + assert isinstance(output, str) if not output: return VERSIONED_NO if output.startswith('160000'): @@ -872,8 +896,7 @@ class GIT(object): return VERSIONED_DIR @staticmethod - def ListSubmodules(repo_root): - # type: (str) -> Collection[str] + def ListSubmodules(repo_root: str) -> Collection[str]: """Returns the list of submodule paths for the given repo. Path separators will be adjusted for the current OS. @@ -883,6 +906,7 @@ class GIT(object): config_output = GIT.Capture( ['config', '--file', '.gitmodules', '--get-regexp', 'path'], cwd=repo_root) + assert isinstance(config_output, str) return [ line.split()[-1].replace('/', os.path.sep) for line in config_output.splitlines()