diff --git a/auth.py b/auth.py index 6e0d2f32f4..11a2113b4e 100644 --- a/auth.py +++ b/auth.py @@ -16,6 +16,7 @@ import os import socket import sys import threading +import time import urllib import urlparse import webbrowser @@ -102,6 +103,119 @@ class LoginRequiredError(AuthenticationError): super(LoginRequiredError, self).__init__(msg) +class LuciContextAuthError(Exception): + """Raised on errors related to unsuccessful attempts to load LUCI_CONTEXT""" + + +def get_luci_context_access_token(): + """Returns a valid AccessToken from the local LUCI context auth server. + + Adapted from + https://chromium.googlesource.com/infra/luci/luci-py/+/master/client/libs/luci_context/luci_context.py + See the link above for more details. + + Returns: + AccessToken if LUCI_CONTEXT is present and attempt to load it is successful. + None if LUCI_CONTEXT is absent. + + Raises: + LuciContextAuthError if the attempt to load LUCI_CONTEXT + and request its access token is unsuccessful. + """ + return _get_luci_context_access_token(os.environ, datetime.datetime.utcnow()) + + +def _get_luci_context_access_token(env, now): + ctx_path = env.get('LUCI_CONTEXT') + if not ctx_path: + return None + ctx_path = ctx_path.decode(sys.getfilesystemencoding()) + logging.debug('Loading LUCI_CONTEXT: %r', ctx_path) + + def authErr(msg, *args): + error_msg = msg % args + ex = sys.exc_info()[1] + if not ex: + logging.error(error_msg) + raise LuciContextAuthError(error_msg) + logging.exception(error_msg) + raise LuciContextAuthError('%s: %s' % (error_msg, ex)) + + try: + loaded = _load_luci_context(ctx_path) + except (OSError, IOError, ValueError): + authErr('Failed to open, read or decode LUCI_CONTEXT') + try: + local_auth = loaded.get('local_auth') + except AttributeError: + authErr('LUCI_CONTEXT not in proper format') + # failed to grab local_auth from LUCI context + if not local_auth: + logging.debug('local_auth: no local auth found') + return None + try: + account_id = local_auth.get('default_account_id') + secret = local_auth.get('secret') + rpc_port = int(local_auth.get('rpc_port')) + except (AttributeError, ValueError): + authErr('local_auth: unexpected local auth format') + + if not secret: + authErr('local_auth: no secret returned') + # if account_id not specified, LUCI_CONTEXT should not be picked up + if not account_id: + return None + + logging.debug('local_auth: requesting an access token for account "%s"', + account_id) + http = httplib2.Http() + host = '127.0.0.1:%d' % rpc_port + resp, content = http.request( + uri='http://%s/rpc/LuciLocalAuthService.GetOAuthToken' % host, + method='POST', + body=json.dumps({ + 'account_id': account_id, + 'scopes': OAUTH_SCOPES.split(' '), + 'secret': secret, + }), + headers={'Content-Type': 'application/json'}) + if resp.status != 200: + err = ('local_auth: Failed to grab access token from ' + 'LUCI context server with status %d: %r') + authErr(err, resp.status, content) + try: + token = json.loads(content) + error_code = token.get('error_code') + error_message = token.get('error_message') + access_token = token.get('access_token') + expiry = token.get('expiry') + except (AttributeError, ValueError): + authErr('local_auth: Unexpected access token response format') + if error_code: + authErr('local_auth: Error %d in retrieving access token: %s', + error_code, error_message) + if not access_token: + authErr('local_auth: No access token returned from LUCI context server') + expiry_dt = None + if expiry: + try: + expiry_dt = datetime.datetime.utcfromtimestamp(expiry) + except (TypeError, ValueError): + authErr('Invalid expiry in returned token') + logging.debug( + 'local_auth: got an access token for account "%s" that expires in %d sec', + account_id, expiry - time.mktime(now.timetuple())) + access_token = AccessToken(access_token, expiry_dt) + if _needs_refresh(access_token, now=now): + authErr('local_auth: the returned access token needs to be refreshed') + return access_token + + +def _load_luci_context(ctx_path): + with open(ctx_path) as f: + return json.load(f) + + def make_auth_config( use_oauth2=None, save_cookies=None, @@ -219,6 +333,9 @@ def get_authenticator_for_host(hostname, config): Returns: Authenticator object. + + Raises: + AuthenticationError if hostname is invalid. """ hostname = hostname.lower().rstrip('/') # Append some scheme, otherwise urlparse puts hostname into parsed.path. @@ -303,23 +420,43 @@ class Authenticator(object): with self._lock: return bool(self._get_cached_credentials()) - def get_access_token(self, force_refresh=False, allow_user_interaction=False): + def get_access_token(self, force_refresh=False, allow_user_interaction=False, + use_local_auth=True): """Returns AccessToken, refreshing it if necessary. Args: force_refresh: forcefully refresh access token even if it is not expired. allow_user_interaction: True to enable blocking for user input if needed. + use_local_auth: default to local auth if needed. Raises: AuthenticationError on error or if authentication flow was interrupted. LoginRequiredError if user interaction is required, but allow_user_interaction is False. """ + def get_loc_auth_tkn(): + exi = sys.exc_info() + if not use_local_auth: + logging.error('Failed to create access token') + raise + try: + self._access_token = get_luci_context_access_token() + if not self._access_token: + logging.error('Failed to create access token') + raise + return self._access_token + except LuciContextAuthError: + logging.exception('Failed to use local auth') + raise exi[0], exi[1], exi[2] + with self._lock: if force_refresh: logging.debug('Forcing access token refresh') - self._access_token = self._create_access_token(allow_user_interaction) - return self._access_token + try: + self._access_token = self._create_access_token(allow_user_interaction) + return self._access_token + except LoginRequiredError: + return get_loc_auth_tkn() # Load from on-disk cache on a first access. if not self._access_token: @@ -331,7 +468,11 @@ class Authenticator(object): self._access_token = self._load_access_token() # Nope, still expired, need to run the refresh flow. if not self._access_token or _needs_refresh(self._access_token): - self._access_token = self._create_access_token(allow_user_interaction) + try: + self._access_token = self._create_access_token( + allow_user_interaction) + except LoginRequiredError: + get_loc_auth_tkn() return self._access_token @@ -548,11 +689,12 @@ def _read_refresh_token_json(path): 'Failed to read refresh token from %s: missing key %s' % (path, e)) -def _needs_refresh(access_token): +def _needs_refresh(access_token, now=None): """True if AccessToken should be refreshed.""" if access_token.expires_at is not None: + now = now or datetime.datetime.utcnow() # Allow 5 min of clock skew between client and backend. - now = datetime.datetime.utcnow() + datetime.timedelta(seconds=300) + now += datetime.timedelta(seconds=300) return now >= access_token.expires_at # Token without expiration time never expires. return False diff --git a/presubmit_canned_checks.py b/presubmit_canned_checks.py index cf04fbc076..21bb407b39 100644 --- a/presubmit_canned_checks.py +++ b/presubmit_canned_checks.py @@ -71,7 +71,7 @@ def CheckChangedConfigs(input_api, output_api): try: authenticator = auth.get_authenticator_for_host( LUCI_CONFIG_HOST_NAME, auth.make_auth_config()) - acc_tkn = authenticator.get_access_token(allow_user_interaction=True).token + acc_tkn = authenticator.get_access_token() except auth.AuthenticationError as e: return [output_api.PresubmitError( 'Error in authenticating user.', long_text=str(e))] @@ -80,7 +80,7 @@ def CheckChangedConfigs(input_api, output_api): api_url = ('https://%s/_ah/api/config/v1/%s' % (LUCI_CONFIG_HOST_NAME, endpoint)) req = urllib2.Request(api_url) - req.add_header('Authorization', 'Bearer %s' % acc_tkn) + req.add_header('Authorization', 'Bearer %s' % acc_tkn.token) if body is not None: req.add_header('Content-Type', 'application/json') req.add_data(json.dumps(body)) diff --git a/tests/auth_test.py b/tests/auth_test.py new file mode 100755 index 0000000000..565fe30a7b --- /dev/null +++ b/tests/auth_test.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# Copyright (c) 2017 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""Unit Tests for auth.py""" + +import __builtin__ +import datetime +import json +import logging +import os +import unittest +import sys +import time + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +from testing_support import auto_stub +from third_party import httplib2 +from third_party import mock + +import auth + + +class TestGetLuciContextAccessToken(auto_stub.TestCase): + mock_env = {'LUCI_CONTEXT': 'default/test/path'} + + def _mock_local_auth(self, account_id, secret, rpc_port): + self.mock(auth, '_load_luci_context', mock.Mock()) + auth._load_luci_context.return_value = { + 'local_auth': { + 'default_account_id': account_id, + 'secret': secret, + 'rpc_port': rpc_port, + } + } + + def _mock_loc_server_resp(self, status, content): + mock_resp = mock.Mock() + mock_resp.status = status + self.mock(httplib2.Http, 'request', mock.Mock()) + httplib2.Http.request.return_value = (mock_resp, content) + + def test_correct_local_auth_format(self): + self._mock_local_auth('dead', 'beef', 10) + expiry_time = datetime.datetime.min + datetime.timedelta(minutes=60) + resp_content = { + 'error_code': None, + 'error_message': None, + 'access_token': 'token', + 'expiry': time.mktime(expiry_time.timetuple()), + } + self._mock_loc_server_resp(200, json.dumps(resp_content)) + token = auth._get_luci_context_access_token( + self.mock_env, datetime.datetime.min) + self.assertEquals(token.token, 'token') + + def test_incorrect_port_format(self): + self._mock_local_auth('foo', 'bar', 'bar') + with self.assertRaises(auth.LuciContextAuthError): + auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min) + + def test_no_account_id(self): + self._mock_local_auth(None, 'bar', 10) + token = auth._get_luci_context_access_token( + self.mock_env, datetime.datetime.min) + self.assertIsNone(token) + + def test_expired_token(self): + self._mock_local_auth('dead', 'beef', 10) + resp_content = { + 'error_code': None, + 'error_message': None, + 'access_token': 'token', + 'expiry': 1, + } + self._mock_loc_server_resp(200, json.dumps(resp_content)) + with self.assertRaises(auth.LuciContextAuthError): + auth._get_luci_context_access_token( + self.mock_env, datetime.datetime.utcfromtimestamp(1)) + + def test_incorrect_expiry_format(self): + self._mock_local_auth('dead', 'beef', 10) + resp_content = { + 'error_code': None, + 'error_message': None, + 'access_token': 'token', + 'expiry': 'dead', + } + self._mock_loc_server_resp(200, json.dumps(resp_content)) + with self.assertRaises(auth.LuciContextAuthError): + auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min) + + def test_incorrect_response_content_format(self): + self._mock_local_auth('dead', 'beef', 10) + self._mock_loc_server_resp(200, '5') + with self.assertRaises(auth.LuciContextAuthError): + auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min) + + +if __name__ == '__main__': + if '-v' in sys.argv: + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/presubmit_unittest.py b/tests/presubmit_unittest.py index 775ed7d1e2..9238ffc774 100755 --- a/tests/presubmit_unittest.py +++ b/tests/presubmit_unittest.py @@ -1974,8 +1974,7 @@ class CannedChecksUnittest(PresubmitTestsBase): token_mock = self.mox.CreateMock(auth.AccessToken) token_mock.token = 123 auth_mock = self.mox.CreateMock(auth.Authenticator) - auth_mock.get_access_token( - allow_user_interaction=True).AndReturn(token_mock) + auth_mock.get_access_token().AndReturn(token_mock) self.mox.StubOutWithMock(auth, 'get_authenticator_for_host') auth.get_authenticator_for_host( mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(auth_mock)