diff --git a/gerrit_util.py b/gerrit_util.py index 62fc1b9ab5..691daf9182 100755 --- a/gerrit_util.py +++ b/gerrit_util.py @@ -15,36 +15,19 @@ import logging import netrc import os import re +import socket import stat import sys import time import urllib +import urlparse from cStringIO import StringIO -_netrc_file = '_netrc' if sys.platform.startswith('win') else '.netrc' -_netrc_file = os.path.join(os.environ['HOME'], _netrc_file) -try: - NETRC = netrc.netrc(_netrc_file) -except IOError: - print >> sys.stderr, 'WARNING: Could not read netrc file %s' % _netrc_file - NETRC = netrc.netrc(os.devnull) -except netrc.NetrcParseError as e: - _netrc_stat = os.stat(e.filename) - if _netrc_stat.st_mode & (stat.S_IRWXG | stat.S_IRWXO): - print >> sys.stderr, ( - 'WARNING: netrc file %s cannot be used because its file permissions ' - 'are insecure. netrc file permissions should be 600.' % _netrc_file) - else: - print >> sys.stderr, ('ERROR: Cannot use netrc file %s due to a parsing ' - 'error.' % _netrc_file) - raise - del _netrc_stat - NETRC = netrc.netrc(os.devnull) -del _netrc_file LOGGER = logging.getLogger() TRY_LIMIT = 5 + # Controls the transport protocol used to communicate with gerrit. # This is parameterized primarily to enable GerritTestCase. GERRIT_PROTOCOL = 'https' @@ -84,17 +67,141 @@ def GetConnectionClass(protocol=None): "Don't know how to work with protocol '%s'" % protocol) +class Authenticator(object): + """Base authenticator class for authenticator implementations to subclass.""" + + def get_auth_header(self, host): + raise NotImplementedError() + + @staticmethod + def get(): + """Returns: (Authenticator) The identified Authenticator to use. + + Probes the local system and its environment and identifies the + Authenticator instance to use. + """ + if GceAuthenticator.is_gce(): + return GceAuthenticator() + return NetrcAuthenticator() + + +class NetrcAuthenticator(Authenticator): + """Authenticator implementation that uses ".netrc" for token. + """ + + def __init__(self): + self.netrc = self._get_netrc() + + @staticmethod + def _get_netrc(): + path = '_netrc' if sys.platform.startswith('win') else '.netrc' + path = os.path.join(os.environ['HOME'], path) + try: + return netrc.netrc(path) + except IOError: + print >> sys.stderr, 'WARNING: Could not read netrc file %s' % path + return netrc.netrc(os.devnull) + except netrc.NetrcParseError as e: + st = os.stat(e.path) + if st.st_mode & (stat.S_IRWXG | stat.S_IRWXO): + print >> sys.stderr, ( + 'WARNING: netrc file %s cannot be used because its file ' + 'permissions are insecure. netrc file permissions should be ' + '600.' % path) + else: + print >> sys.stderr, ('ERROR: Cannot use netrc file %s due to a ' + 'parsing error.' % path) + raise + return netrc.netrc(os.devnull) + + def get_auth_header(self, host): + auth = self.netrc.authenticators(host) + if auth: + return 'Basic %s' % (base64.b64encode('%s:%s' % (auth[0], auth[2]))) + return None + + +class GceAuthenticator(Authenticator): + """Authenticator implementation that uses GCE metadata service for token. + """ + + _INFO_URL = 'http://metadata.google.internal' + _ACQUIRE_URL = ('http://metadata/computeMetadata/v1/instance/' + 'service-accounts/default/token') + _ACQUIRE_HEADERS = {"Metadata-Flavor": "Google"} + + _cache_is_gce = None + _token_cache = None + _token_expiration = None + + @classmethod + def is_gce(cls): + if cls._cache_is_gce is None: + cls._cache_is_gce = cls._test_is_gce() + return cls._cache_is_gce + + @classmethod + def _test_is_gce(cls): + # Based on https://cloud.google.com/compute/docs/metadata#runninggce + try: + resp = cls._get(cls._INFO_URL) + except socket.error: + # Could not resolve URL. + return False + return resp.getheader('Metadata-Flavor', None) == 'Google' + + @staticmethod + def _get(url, **kwargs): + next_delay_sec = 1 + for i in xrange(TRY_LIMIT): + if i > 0: + # Retry server error status codes. + LOGGER.info('Encountered server error; retrying after %d second(s).', + next_delay_sec) + time.sleep(next_delay_sec) + next_delay_sec *= 2 + + p = urlparse.urlparse(url) + c = GetConnectionClass(protocol=p.scheme)(p.netloc) + c.request('GET', url, **kwargs) + resp = c.getresponse() + LOGGER.debug('GET [%s] #%d/%d (%d)', url, i+1, TRY_LIMIT, resp.status) + if resp.status < httplib.INTERNAL_SERVER_ERROR: + return resp + + + @classmethod + def _get_token_dict(cls): + if cls._token_cache: + # If it expires within 25 seconds, refresh. + if cls._token_expiration < time.time() - 25: + return cls._token_cache + + resp = cls._get(cls._ACQUIRE_URL, headers=cls._ACQUIRE_HEADERS) + if resp.status != httplib.OK: + return None + cls._token_cache = json.load(resp) + cls._token_expiration = cls._token_cache['expires_in'] + time.time() + return cls._token_cache + + def get_auth_header(self, _host): + token_dict = self._get_token_dict() + if not token_dict: + return None + return '%(token_type)s %(access_token)s' % token_dict + + + def CreateHttpConn(host, path, reqtype='GET', headers=None, body=None): """Opens an https connection to a gerrit service, and sends a request.""" headers = headers or {} bare_host = host.partition(':')[0] - auth = NETRC.authenticators(bare_host) + auth = Authenticator.get().get_auth_header(bare_host) if auth: - headers.setdefault('Authorization', 'Basic %s' % ( - base64.b64encode('%s:%s' % (auth[0], auth[2])))) + headers.setdefault('Authorization', auth) else: - LOGGER.debug('No authorization found in netrc for %s.' % bare_host) + LOGGER.debug('No authorization found for %s.' % bare_host) if 'Authorization' in headers and not path.startswith('a/'): url = '/a/%s' % path