From 459e1a54a563a942ce881f2f26651867fb7a0255 Mon Sep 17 00:00:00 2001 From: Allen Li Date: Wed, 26 Jun 2024 22:35:50 +0000 Subject: [PATCH] [gerrit_util] Add ChainedAuthenticator Because we need to dynamically determine whether to use SSO. Bug: b/348024314 Change-Id: I5ac768f1e0c20254b4cfd4815270ee4e2b9a5544 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/5660884 Reviewed-by: Yiwei Zhang Commit-Queue: Allen Li Reviewed-by: Robbie Iannucci --- gerrit_util.py | 80 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 22 deletions(-) diff --git a/gerrit_util.py b/gerrit_util.py index a0f907f6c3..d39033a1ad 100644 --- a/gerrit_util.py +++ b/gerrit_util.py @@ -204,7 +204,7 @@ class Authenticator(object): raise NotImplementedError() @classmethod - def is_applicable(cls) -> bool: + def is_applicable(cls, *, conn: Optional[HttpConn] = None) -> bool: """Must return True if this Authenticator is available in the current environment.""" raise NotImplementedError() @@ -237,22 +237,32 @@ class Authenticator(object): skip_sso = newauth.SkipSSO() if use_new_auth: - LOGGER.debug('Authenticator.get: using new auth stack.') - authenticators = [ - SSOAuthenticator, - LuciContextAuthenticator, - GceAuthenticator, - LuciAuthAuthenticator, - ] - if skip_sso: - LOGGER.debug('Authenticator.get: skipping SSOAuthenticator.') - authenticators = authenticators[1:] - else: - authenticators = [ - LuciContextAuthenticator, - GceAuthenticator, - CookiesAuthenticator, - ] + LOGGER.debug('Authenticator.get: using new auth stack') + if LuciContextAuthenticator.is_applicable(): + LOGGER.debug( + 'Authenticator.get: using LUCI context authenticator') + ret = LuciContextAuthenticator() + else: + LOGGER.debug( + 'Authenticator.get: using chained authenticator') + a = [ + SSOAuthenticator(), + # GCE detection can't distinguish cloud workstations. + GceAuthenticator(), + LuciAuthAuthenticator(), + ] + if skip_sso: + LOGGER.debug( + 'Authenticator.get: skipping SSOAuthenticator.') + authenticators = authenticators[1:] + ret = ChainedAuthenticator(a) + cls._resolved = ret + return ret + authenticators = [ + LuciContextAuthenticator, + GceAuthenticator, + CookiesAuthenticator, + ] for candidate in authenticators: if candidate.is_applicable(): @@ -314,7 +324,7 @@ class SSOAuthenticator(Authenticator): ) @classmethod - def is_applicable(cls) -> bool: + def is_applicable(cls, *, conn: Optional[HttpConn] = None) -> bool: if not cls._resolve_sso_cmd(): return False email = scm.GIT.GetConfig(os.getcwd(), 'user.email', default='') @@ -497,7 +507,7 @@ class CookiesAuthenticator(Authenticator): self._gitcookies = self._EMPTY @classmethod - def is_applicable(cls) -> bool: + def is_applicable(cls, *, conn: Optional[HttpConn] = None) -> bool: # We consider CookiesAuthenticator always applicable for now. return True @@ -656,7 +666,7 @@ class GceAuthenticator(Authenticator): _token_expiration = None @classmethod - def is_applicable(cls): + def is_applicable(cls, *, conn: Optional[HttpConn] = None): if os.getenv('SKIP_GCE_AUTH_FOR_GIT'): return False if cls._cache_is_gce is None: @@ -726,7 +736,7 @@ class LuciContextAuthenticator(Authenticator): """Authenticator implementation that uses LUCI_CONTEXT ambient local auth. """ @staticmethod - def is_applicable(): + def is_applicable(*, conn: Optional[HttpConn] = None): return auth.has_luci_context_local_auth() def __init__(self): @@ -750,10 +760,36 @@ class LuciAuthAuthenticator(LuciContextAuthenticator): """ @staticmethod - def is_applicable(): + def is_applicable(*, conn: Optional[HttpConn] = None): return True +class ChainedAuthenticator(Authenticator): + """Authenticator that delegates to others in sequence. + + Authenticators should implement the method `is_applicable_for`. + """ + + def __init__(self, authenticators: List[Authenticator]): + self.authenticators = list(authenticators) + + def is_applicable(self, *, conn: Optional[HttpConn] = None) -> bool: + return bool(any( + a.is_applicable(conn=conn) for a in self.authenticators)) + + def authenticate(self, conn: HttpConn): + for a in self.authenticators: + if a.is_applicable(conn=conn): + a.authenticate(conn) + break + else: + raise ValueError( + f'{self!r} has no applicable authenticator for {conn!r}') + + def debug_summary_state(self) -> str: + return '' + + class ReqParams(TypedDict): uri: str method: str