diff --git a/gerrit_util.py b/gerrit_util.py index eb0003034..f31a34e3b 100644 --- a/gerrit_util.py +++ b/gerrit_util.py @@ -9,8 +9,7 @@ https://gerrit-review.googlesource.com/Documentation/rest-api.html import base64 import contextlib -from typing import List, Optional, Tuple, Type -import httplib2 +import http.cookiejar import json import logging import os @@ -20,7 +19,12 @@ import socket import tempfile import time import urllib.parse + +from io import StringIO from multiprocessing.pool import ThreadPool +from typing import Any, Container, Dict, List, Optional, Tuple, Type, TypedDict + +import httplib2 import auth import gclient_utils @@ -28,8 +32,6 @@ import metrics import metrics_utils import scm -import http.cookiejar -from io import StringIO # TODO: Should fix these warnings. # pylint: disable=line-too-long @@ -251,42 +253,39 @@ class CookiesAuthenticator(Authenticator): def ensure_authenticated(self, gerrit_host: str, git_host: str) -> Tuple[bool, str]: - """Returns (bypassable, error message). - - If the error message is empty, there is no error to report. - If bypassable is true, the caller will allow the user to continue past the - error. - """ - # Lazy-loader to identify Gerrit and Git hosts. - gerrit_auth = self._get_auth_for_host(gerrit_host) - git_auth = self._get_auth_for_host(git_host) - if gerrit_auth and git_auth: - if gerrit_auth == git_auth: - return True, '' - all_gsrc, _ = self.get_auth_info( - 'd0esN0tEx1st.googlesource.com') - print( - 'WARNING: You have different credentials for Gerrit and git hosts:\n' - ' %s\n' - ' %s\n' - ' Consider running the following command:\n' - ' git cl creds-check\n' - ' %s\n' - ' %s' % - (git_host, gerrit_host, - ('Hint: delete creds for .googlesource.com' if all_gsrc else - ''), self._get_new_password_message(git_host))) - return True, 'If you know what you are doing' - - missing = (([] if gerrit_auth else [gerrit_host]) + - ([] if git_auth else [git_host])) - return False, ( - 'Credentials for the following hosts are required:\n' - ' %s\n' - 'These are read from %s\n' - '%s' % - ('\n '.join(missing), self.get_gitcookies_path(), - self._get_new_password_message(git_host))) + """Returns (bypassable, error message). + + If the error message is empty, there is no error to report. + If bypassable is true, the caller will allow the user to continue past the + error. + """ + # Lazy-loader to identify Gerrit and Git hosts. + gerrit_auth = self._get_auth_for_host(gerrit_host) + git_auth = self._get_auth_for_host(git_host) + if gerrit_auth and git_auth: + if gerrit_auth == git_auth: + return True, '' + all_gsrc, _ = self.get_auth_info('d0esN0tEx1st.googlesource.com') + print( + 'WARNING: You have different credentials for Gerrit and git hosts:\n' + ' %s\n' + ' %s\n' + ' Consider running the following command:\n' + ' git cl creds-check\n' + ' %s\n' + ' %s' % + (git_host, gerrit_host, + ('Hint: delete creds for .googlesource.com' if all_gsrc else + ''), self._get_new_password_message(git_host))) + return True, 'If you know what you are doing' + + missing = (([] if gerrit_auth else [gerrit_host]) + + ([] if git_auth else [git_host])) + return False, ('Credentials for the following hosts are required:\n' + ' %s\n' + 'These are read from %s\n' + '%s' % ('\n '.join(missing), self.get_gitcookies_path(), + self._get_new_password_message(git_host))) # Used to redact the cookies from the gitcookies file. @@ -410,12 +409,42 @@ class LuciContextAuthenticator(Authenticator): return '' +class ReqParams(TypedDict): + uri: str + method: str + headers: Dict[str, str] + body: Optional[str] + + +class HttpConn(httplib2.Http): + """HttpConn is an httplib2.Http with additional request-specific fields.""" + + def __init__(self, *args, req_host: str, req_uri: str, req_method: str, + req_headers: Dict[str, str], req_body: Optional[str], + **kwargs) -> None: + self.req_host = req_host + self.req_uri = req_uri + self.req_method = req_method + self.req_headers = req_headers + self.req_body = req_body + super().__init__(*args, **kwargs) + + @property + def req_params(self) -> ReqParams: + return { + 'uri': self.req_uri, + 'method': self.req_method, + 'headers': self.req_headers, + 'body': self.req_body, + } + + def CreateHttpConn(host, path, reqtype='GET', - headers=None, - body=None, - timeout=300): + headers: Optional[Dict[str, str]] = None, + body: Optional[Dict] = None, + timeout=300) -> HttpConn: """Opens an HTTPS connection to a Gerrit service, and sends a request.""" headers = headers or {} bare_host = host.partition(':')[0] @@ -438,8 +467,9 @@ def CreateHttpConn(host, if auth_header and not url.startswith('/a/'): url = '/a%s' % url + rendered_body: Optional[str] = None if body: - body = json.dumps(body, sort_keys=True) + rendered_body = json.dumps(body, sort_keys=True) headers.setdefault('Content-Type', 'application/json') if LOGGER.isEnabledFor(logging.DEBUG): LOGGER.debug('%s %s://%s%s' % (reqtype, GERRIT_PROTOCOL, host, url)) @@ -447,22 +477,21 @@ def CreateHttpConn(host, if key == 'Authorization': val = 'HIDDEN' LOGGER.debug('%s: %s' % (key, val)) - if body: - LOGGER.debug(body) - conn = httplib2.Http(timeout=timeout, proxy_info=proxy) - # HACK: httplib2.Http has no such attribute; we store req_host here for - # later use in ReadHttpResponse. - conn.req_host = host - conn.req_params = { - 'uri': urllib.parse.urljoin('%s://%s' % (GERRIT_PROTOCOL, host), url), - 'method': reqtype, - 'headers': headers, - 'body': body, - } - return conn + if rendered_body: + LOGGER.debug(rendered_body) + + uri = urllib.parse.urljoin(f'{GERRIT_PROTOCOL}://{host}', url) + return HttpConn(timeout=timeout, + proxy_info=proxy, + req_host=host, + req_uri=uri, + req_method=reqtype, + req_headers=headers, + req_body=rendered_body) -def ReadHttpResponse(conn, accept_statuses=frozenset([200])): +def ReadHttpResponse(conn: HttpConn, + accept_statuses: Container[int] = frozenset([200])): """Reads an HTTP response from a connection into a string buffer. Args: @@ -472,6 +501,7 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])): Returns: A string buffer containing the connection's reply. """ + response = contents = None sleep_time = SLEEP_TIME for idx in range(TRY_LIMIT): before_response = time.time() @@ -521,6 +551,11 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])): sleep_time = log_retry_and_sleep(sleep_time, idx) # end of retries loop + # Help the type checker a bit here - it can't figure out the `except` logic + # in the loop above. + assert response, ( + "Impossible: End of retry loop without response or exception.") + if response.status in accept_statuses: return StringIO(contents) @@ -539,7 +574,8 @@ def ReadHttpResponse(conn, accept_statuses=frozenset([200])): raise GerritError(response.status, reason) -def ReadHttpJsonResponse(conn, accept_statuses=frozenset([200])): +def ReadHttpJsonResponse( + conn, accept_statuses: Container[int] = frozenset([200])) -> Dict: """Parses an https response as json.""" fh = ReadHttpResponse(conn, accept_statuses) # The first line of the response should always be: )]}' @@ -548,7 +584,7 @@ def ReadHttpJsonResponse(conn, accept_statuses=frozenset([200])): raise GerritError(200, 'Unexpected json output: %s' % s[:100]) s = fh.read() if not s: - return None + return {} return json.loads(s) @@ -596,7 +632,7 @@ def QueryChanges(host, path = '%s&n=%d' % (path, limit) if o_params: path = '%s&%s' % (path, '&'.join(['o=%s' % p for p in o_params])) - return ReadHttpJsonResponse(CreateHttpConn(host, path, timeout=30.0)) + return ReadHttpJsonResponse(CreateHttpConn(host, path, timeout=30)) def GenerateAllChanges(host, @@ -981,7 +1017,7 @@ def AddReviewers(host, reviewers=None, ccs=None, notify=True, - accept_statuses=frozenset([200, 400, 422])): + accept_statuses: Container[int] = frozenset([200, 400, 422])): """Add reviewers to a change.""" if not reviewers and not ccs: return None @@ -1032,7 +1068,7 @@ def SetReview(host, change, msg=None, labels=None, notify=None, ready=None): if not msg and not labels: return path = 'changes/%s/revisions/current/review' % change - body = {'drafts': 'KEEP'} + body: Dict[str, Any] = {'drafts': 'KEEP'} if msg: body['message'] = msg if labels: diff --git a/tests/gerrit_util_test.py b/tests/gerrit_util_test.py index e7c407cc2..57f6df687 100755 --- a/tests/gerrit_util_test.py +++ b/tests/gerrit_util_test.py @@ -449,7 +449,7 @@ class GerritUtilTest(unittest.TestCase): @mock.patch('gerrit_util.ReadHttpResponse') def testReadHttpJsonResponse_EmptyValue(self, mockReadHttpResponse): mockReadHttpResponse.return_value = StringIO(')]}\'') - self.assertIsNone(gerrit_util.ReadHttpJsonResponse(None)) + self.assertEqual(gerrit_util.ReadHttpJsonResponse(None), {}) @mock.patch('gerrit_util.ReadHttpResponse') def testReadHttpJsonResponse_JSON(self, mockReadHttpResponse):