diff --git a/git_auth.py b/git_auth.py index 0b8d46151..c11d12efe 100644 --- a/git_auth.py +++ b/git_auth.py @@ -465,18 +465,27 @@ class UserInterface(object): self._stdout.write(s) +RemoteURLFunc = Callable[[], str] + + class ConfigWizard(object): - """Wizard for setting up user's Git config Gerrit authentication.""" + """Wizard for setting up user's Git config Gerrit authentication. - def __init__(self, ui): + Instances carry internal state, so cannot be reused. + """ + + def __init__(self, *, ui: UserInterface, remote_url_func: RemoteURLFunc): self._ui = ui + self._remote_url_func = remote_url_func + + # Internal state self._user_actions = [] - def run(self, *, remote_url: str, force_global: bool): + def run(self, *, force_global: bool): with self._handle_config_errors(): - self._run(remote_url=remote_url, force_global=force_global) + self._run(force_global=force_global) - def _run(self, *, remote_url: str, force_global: bool): + def _run(self, *, force_global: bool): self._println('This tool will help check your Gerrit authentication.') self._println( '(Report any issues to https://issues.chromium.org/issues/new?component=1456702&template=2076315)' @@ -489,6 +498,7 @@ class ConfigWizard(object): self._println('SSO helper is available.') self._set_config('protocol.sso.allow', 'always', scope='global') self._println() + remote_url = self._remote_url_func() if _is_gerrit_url(remote_url): if force_global: self._println( diff --git a/git_cl.py b/git_cl.py index e7718b004..fcfd6f38c 100755 --- a/git_cl.py +++ b/git_cl.py @@ -3908,14 +3908,18 @@ def CMDcreds_check(parser, args): options, args = parser.parse_args(args) if newauth.SwitchedOn(): - cl = Changelist() - try: - remote_url = cl.GetRemoteUrl() - except subprocess2.CalledProcessError: - remote_url = '' - wizard = git_auth.ConfigWizard( - git_auth.UserInterface(sys.stdin, sys.stdout)) - wizard.run(remote_url=remote_url, force_global=options.force_global) + + def f() -> str: + cl = Changelist() + try: + return cl.GetRemoteUrl() + except subprocess2.CalledProcessError: + return '' + + wizard = git_auth.ConfigWizard(ui=git_auth.UserInterface( + sys.stdin, sys.stdout), + remote_url_func=f) + wizard.run(force_global=options.force_global) return 0 if newauth.ExplicitlyDisabled(): git_auth.ClearRepoConfig(os.getcwd(), Changelist()) diff --git a/tests/git_auth_test.py b/tests/git_auth_test.py index 286e36b00..96e808171 100755 --- a/tests/git_auth_test.py +++ b/tests/git_auth_test.py @@ -258,7 +258,9 @@ class TestConfigWizard(unittest.TestCase): self._global_state_view: Iterable[tuple[str, list[str]]] = scm_mock.GIT(self) self.ui = _FakeUI() - self.wizard = git_auth.ConfigWizard(self.ui) + + self.wizard = git_auth.ConfigWizard( + ui=self.ui, remote_url_func=lambda: 'remote.example.com') @property def global_state(self):