diff --git a/auth.py b/auth.py index 84011b3ad..e0afec988 100644 --- a/auth.py +++ b/auth.py @@ -21,6 +21,7 @@ import urlparse import subprocess2 from third_party import httplib2 +from third_party.oauth2client import client # depot_tools/. @@ -54,8 +55,8 @@ class AccessToken(collections.namedtuple('AccessToken', [ """True if this AccessToken should be refreshed.""" if self.expires_at is not None: now = now or datetime.datetime.utcnow() - # Allow 30s of clock skew between client and backend. - now += datetime.timedelta(seconds=30) + # Allow 3 min of clock skew between client and backend. + now += datetime.timedelta(seconds=180) return now >= self.expires_at # Token without expiration time never expires. return False @@ -99,8 +100,6 @@ def has_luci_context_local_auth(): return bool(params.default_account_id) -# TODO(crbug.com/1001756): Remove. luci-auth uses local auth if available, -# making this unnecessary. def get_luci_context_access_token(scopes=OAUTH_SCOPE_EMAIL): """Returns a valid AccessToken from the local LUCI context auth server. @@ -292,18 +291,18 @@ def add_auth_options(parser, default_config=None): help='Do not save authentication cookies to local disk.') # OAuth2 related options. - # TODO(crbug.com/1001756): Remove. No longer supported. parser.auth_group.add_option( '--auth-no-local-webserver', action='store_false', dest='use_local_webserver', default=default_config.use_local_webserver, - help='DEPRECATED. Do not use') + help='Do not run a local web server when performing OAuth2 login flow.') parser.auth_group.add_option( '--auth-host-port', type=int, default=default_config.webserver_port, - help='DEPRECATED. Do not use') + help='Port a local web server should listen on. Used only if ' + '--auth-no-local-webserver is not set. [default: %default]') parser.auth_group.add_option( '--auth-refresh-token-json', help='DEPRECATED. Do not use') @@ -373,25 +372,27 @@ class Authenticator(object): logging.debug('Using auth config %r', config) def has_cached_credentials(self): - """Returns True if credentials can be obtained. + """Returns True if long term credentials (refresh token) are in cache. + + Doesn't make network calls. - If returns False, get_access_token() later will probably ask for interactive - login by raising LoginRequiredError, unless local auth in configured. + If returns False, get_access_token() later will ask for interactive login by + raising LoginRequiredError. If returns True, most probably get_access_token() won't ask for interactive - login, unless an external token is provided that has been revoked. + login, though it is not guaranteed, since cached token can be already + revoked and there's no way to figure this out without actually trying to use + it. """ with self._lock: - return bool(self._get_luci_auth_token()) + return bool(self._get_cached_credentials()) def get_access_token(self, force_refresh=False, allow_user_interaction=False, use_local_auth=True): """Returns AccessToken, refreshing it if necessary. Args: - TODO(crbug.com/1001756): Remove. luci-auth doesn't support - force-refreshing tokens. - force_refresh: Ignored, + force_refresh: forcefully refresh access token even if it is not expired. allow_user_interaction: True to enable blocking for user input if needed. use_local_auth: default to local auth if needed. @@ -400,41 +401,53 @@ class Authenticator(object): LoginRequiredError if user interaction is required, but allow_user_interaction is False. """ - with self._lock: - if self._access_token and not self._access_token.needs_refresh(): - return self._access_token - - # Token expired or missing. Maybe some other process already updated it, - # reload from the cache. - self._access_token = self._get_luci_auth_token() - if self._access_token and not self._access_token.needs_refresh(): + def get_loc_auth_tkn(): + exi = sys.exc_info() + if not use_local_auth: + logging.error('Failed to create access token') + raise + try: + self._access_token = get_luci_context_access_token() + if not self._access_token: + logging.error('Failed to create access token') + raise return self._access_token + except LuciContextAuthError: + logging.exception('Failed to use local auth') + raise exi[0], exi[1], exi[2] - # Nope, still expired, need to run the refresh flow. - if not self._external_token and allow_user_interaction: - logging.debug('Launching luci-auth login') - self._access_token = self._run_oauth_dance() - if self._access_token and not self._access_token.needs_refresh(): - return self._access_token - - # TODO(crbug.com/1001756): Remove. luci-auth uses local auth if it exists. - # Refresh flow failed. Try local auth. - if use_local_auth: + with self._lock: + if force_refresh: + logging.debug('Forcing access token refresh') try: - self._access_token = get_luci_context_access_token() - except LuciContextAuthError: - logging.exception('Failed to use local auth') - if self._access_token and not self._access_token.needs_refresh(): - return self._access_token - - # Give up. - logging.error('Failed to create access token') - raise LoginRequiredError(self._scopes) + self._access_token = self._create_access_token(allow_user_interaction) + return self._access_token + except LoginRequiredError: + return get_loc_auth_tkn() + + # Load from on-disk cache on a first access. + if not self._access_token: + self._access_token = self._load_access_token() + + # Refresh if expired or missing. + if not self._access_token or self._access_token.needs_refresh(): + # Maybe some other process already updated it, reload from the cache. + self._access_token = self._load_access_token() + # Nope, still expired, need to run the refresh flow. + if not self._access_token or self._access_token.needs_refresh(): + try: + self._access_token = self._create_access_token( + allow_user_interaction) + except LoginRequiredError: + get_loc_auth_tkn() + + return self._access_token def authorize(self, http): """Monkey patches authentication logic of httplib2.Http instance. The modified http.request method will add authentication headers to each + request and will refresh access_tokens when a 401 is received on a request. Args: @@ -444,6 +457,7 @@ class Authenticator(object): A modified instance of http that was passed in. """ # Adapted from oauth2client.OAuth2Credentials.authorize. + request_orig = http.request @functools.wraps(request_orig) @@ -453,37 +467,92 @@ class Authenticator(object): connection_type=None): headers = (headers or {}).copy() headers['Authorization'] = 'Bearer %s' % self.get_access_token().token - return request_orig( + resp, content = request_orig( uri, method, body, headers, redirections, connection_type) + if resp.status in client.REFRESH_STATUS_CODES: + logging.info('Refreshing due to a %s', resp.status) + access_token = self.get_access_token(force_refresh=True) + headers['Authorization'] = 'Bearer %s' % access_token.token + return request_orig( + uri, method, body, headers, redirections, connection_type) + else: + return (resp, content) http.request = new_request return http ## Private methods. - def _run_luci_auth_login(self): - """Run luci-auth login. + def _get_cached_credentials(self): + """Returns oauth2client.Credentials loaded from luci-auth.""" + credentials = _get_luci_auth_credentials(self._scopes) - Returns: - AccessToken with credentials. - """ - logging.debug('Running luci-auth login') - subprocess2.check_call(['luci-auth', 'login', '-scopes', self._scopes]) - return self._get_luci_auth_token() + if not credentials: + logging.debug('No cached token') + else: + _log_credentials_info('cached token', credentials) - def _get_luci_auth_token(self): - logging.debug('Running luci-auth token') - try: - out, err = subprocess2.check_call_out( - ['luci-auth', 'token', '-scopes', self._scopes, '-json-output', '-'], - stdout=subprocess2.PIPE, stderr=subprocess2.PIPE) - logging.debug('luci-auth token stderr:\n%s', err) - token_info = json.loads(out) - return AccessToken( - token_info['token'], - datetime.datetime.utcfromtimestamp(token_info['expiry'])) - except subprocess2.CalledProcessError: + return credentials if (credentials and not credentials.invalid) else None + + def _load_access_token(self): + """Returns cached AccessToken if it is not expired yet.""" + logging.debug('Reloading access token from cache') + creds = self._get_cached_credentials() + if not creds or not creds.access_token or creds.access_token_expired: + logging.debug('Access token is missing or expired') return None + return AccessToken(str(creds.access_token), creds.token_expiry) + + def _create_access_token(self, allow_user_interaction=False): + """Mints and caches a new access token, launching OAuth2 dance if necessary. + + Uses cached refresh token, if present. In that case user interaction is not + required and function will finish quietly. Otherwise it will launch 3-legged + OAuth2 flow, that needs user interaction. + + Args: + allow_user_interaction: if True, allow interaction with the user (e.g. + reading standard input, or launching a browser). + + Returns: + AccessToken. + + Raises: + AuthenticationError on error or if authentication flow was interrupted. + LoginRequiredError if user interaction is required, but + allow_user_interaction is False. + """ + logging.debug( + 'Making new access token (allow_user_interaction=%r)', + allow_user_interaction) + credentials = self._get_cached_credentials() + + # 3-legged flow with (perhaps cached) refresh token. + refreshed = False + if credentials and not credentials.invalid: + try: + logging.debug('Attempting to refresh access_token') + credentials.refresh(httplib2.Http()) + _log_credentials_info('refreshed token', credentials) + refreshed = True + except client.Error as err: + logging.warning( + 'OAuth error during access token refresh (%s). ' + 'Attempting a full authentication flow.', err) + + # Refresh token is missing or invalid, go through the full flow. + if not refreshed: + if not allow_user_interaction: + logging.debug('Requesting user to login') + raise LoginRequiredError(self._scopes) + logging.debug('Launching OAuth browser flow') + credentials = _run_oauth_dance(self._scopes) + _log_credentials_info('new token', credentials) + + logging.info( + 'OAuth access_token refreshed. Expires in %s.', + credentials.token_expiry - datetime.datetime.utcnow()) + return AccessToken(str(credentials.access_token), credentials.token_expiry) ## Private functions. @@ -492,3 +561,44 @@ class Authenticator(object): def _is_headless(): """True if machine doesn't seem to have a display.""" return sys.platform == 'linux2' and not os.environ.get('DISPLAY') + + +def _log_credentials_info(title, credentials): + """Dumps (non sensitive) part of client.Credentials object to debug log.""" + if credentials: + logging.debug('%s info: %r', title, { + 'access_token_expired': credentials.access_token_expired, + 'has_access_token': bool(credentials.access_token), + 'invalid': credentials.invalid, + 'utcnow': datetime.datetime.utcnow(), + 'token_expiry': credentials.token_expiry, + }) + + +def _get_luci_auth_credentials(scopes): + try: + token_info = json.loads(subprocess2.check_output( + ['luci-auth', 'token', '-scopes', scopes, '-json-output', '-'], + stderr=subprocess2.VOID)) + except subprocess2.CalledProcessError: + return None + + return client.OAuth2Credentials( + access_token=token_info['token'], + client_id=None, + client_secret=None, + refresh_token=None, + token_expiry=datetime.datetime.utcfromtimestamp(token_info['expiry']), + token_uri=None, + user_agent=None, + revoke_uri=None) + + +def _run_oauth_dance(scopes): + """Perform full 3-legged OAuth2 flow with the browser. + + Returns: + oauth2client.Credentials. + """ + subprocess2.check_call(['luci-auth', 'login', '-scopes', scopes]) + return _get_luci_auth_credentials(scopes) diff --git a/tests/auth_test.py b/tests/auth_test.py index cb4d56437..c70b36218 100755 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -5,7 +5,7 @@ """Unit Tests for auth.py""" -import contextlib +import __builtin__ import datetime import json import logging @@ -16,35 +16,37 @@ import time sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from third_party import mock + +from testing_support import auto_stub from third_party import httplib2 +from third_party import mock import auth -def _mockLocalAuth(account_id, secret, rpc_port): - mock_luci_context = { +class TestLuciContext(auto_stub.TestCase): + def setUp(self): + auth._get_luci_context_local_auth_params.clear_cache() + + def _mock_local_auth(self, account_id, secret, rpc_port): + self.mock(os, 'environ', {'LUCI_CONTEXT': 'default/test/path'}) + self.mock(auth, '_load_luci_context', mock.Mock()) + auth._load_luci_context.return_value = { 'local_auth': { - 'default_account_id': account_id, - 'secret': secret, - 'rpc_port': rpc_port, + 'default_account_id': account_id, + 'secret': secret, + 'rpc_port': rpc_port, } - } - mock.patch('auth._load_luci_context', return_value=mock_luci_context).start() - mock.patch('os.environ', {'LUCI_CONTEXT': 'default/test/path'}).start() - - -def _mockResponse(status, content): - mock_response = (mock.Mock(status=status), content) - mock.patch('auth.httplib2.Http.request', return_value=mock_response).start() - + } -class TestLuciContext(unittest.TestCase): - def setUp(self): - auth._get_luci_context_local_auth_params.clear_cache() + def _mock_loc_server_resp(self, status, content): + mock_resp = mock.Mock() + mock_resp.status = status + self.mock(httplib2.Http, 'request', mock.Mock()) + httplib2.Http.request.return_value = (mock_resp, content) def test_all_good(self): - _mockLocalAuth('account', 'secret', 8080) + self._mock_local_auth('account', 'secret', 8080) self.assertTrue(auth.has_luci_context_local_auth()) expiry_time = datetime.datetime.min + datetime.timedelta(hours=1) @@ -55,18 +57,18 @@ class TestLuciContext(unittest.TestCase): 'expiry': (expiry_time - datetime.datetime.utcfromtimestamp(0)).total_seconds(), } - _mockResponse(200, json.dumps(resp_content)) + self._mock_loc_server_resp(200, json.dumps(resp_content)) params = auth._get_luci_context_local_auth_params() token = auth._get_luci_context_access_token(params, datetime.datetime.min) self.assertEqual(token.token, 'token') def test_no_account_id(self): - _mockLocalAuth(None, 'secret', 8080) + self._mock_local_auth(None, 'secret', 8080) self.assertFalse(auth.has_luci_context_local_auth()) self.assertIsNone(auth.get_luci_context_access_token()) def test_incorrect_port_format(self): - _mockLocalAuth('account', 'secret', 'port') + self._mock_local_auth('account', 'secret', 'port') self.assertFalse(auth.has_luci_context_local_auth()) with self.assertRaises(auth.LuciContextAuthError): auth.get_luci_context_access_token() @@ -79,7 +81,7 @@ class TestLuciContext(unittest.TestCase): 'access_token': 'token', 'expiry': 1, } - _mockResponse(200, json.dumps(resp_content)) + self._mock_loc_server_resp(200, json.dumps(resp_content)) with self.assertRaises(auth.LuciContextAuthError): auth._get_luci_context_access_token( params, datetime.datetime.utcfromtimestamp(1)) @@ -92,13 +94,13 @@ class TestLuciContext(unittest.TestCase): 'access_token': 'token', 'expiry': 'dead', } - _mockResponse(200, json.dumps(resp_content)) + self._mock_loc_server_resp(200, json.dumps(resp_content)) with self.assertRaises(auth.LuciContextAuthError): auth._get_luci_context_access_token(params, datetime.datetime.min) def test_incorrect_response_content_format(self): params = auth._LuciContextLocalAuthParams('account', 'secret', 8080) - _mockResponse(200, '5') + self._mock_loc_server_resp(200, '5') with self.assertRaises(auth.LuciContextAuthError): auth._get_luci_context_access_token(params, datetime.datetime.min)