#!/usr/bin/env vpython3
# Copyright (c) 2017 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unit Tests for auth.py"""

import calendar
import datetime
import json
import os
import unittest
import sys
from unittest import mock

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import auth
import subprocess2

NOW = datetime.datetime(2019, 10, 17, 12, 30, 59, 0)
VALID_EXPIRY = NOW + datetime.timedelta(seconds=31)


class AuthenticatorTest(unittest.TestCase):
    def setUp(self):
        mock.patch('subprocess2.check_call').start()
        mock.patch('subprocess2.check_call_out').start()
        mock.patch('auth.datetime_now', return_value=NOW).start()
        self.addCleanup(mock.patch.stopall)

    def testHasCachedCredentials_NotLoggedIn(self):
        subprocess2.check_call_out.side_effect = [
            subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout',
                                           'stderr')
        ]
        self.assertFalse(auth.Authenticator().has_cached_credentials())

    def testHasCachedCredentials_LoggedIn(self):
        subprocess2.check_call_out.return_value = (json.dumps({
            'token': 'token',
            'expiry': 12345678
        }), '')
        self.assertTrue(auth.Authenticator().has_cached_credentials())

    def testGetAccessToken_NotLoggedIn(self):
        subprocess2.check_call_out.side_effect = [
            subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout',
                                           'stderr')
        ]
        self.assertRaises(auth.LoginRequiredError,
                          auth.Authenticator().get_access_token)

    def testGetAccessToken_CachedToken(self):
        authenticator = auth.Authenticator()
        authenticator._access_token = auth.Token('token', None)
        self.assertEqual(auth.Token('token', None),
                         authenticator.get_access_token())
        subprocess2.check_call_out.assert_not_called()

    def testGetAccesstoken_LoggedIn(self):
        expiry = calendar.timegm(VALID_EXPIRY.timetuple())
        subprocess2.check_call_out.return_value = (json.dumps({
            'token': 'token',
            'expiry': expiry
        }), '')
        self.assertEqual(auth.Token('token', VALID_EXPIRY),
                         auth.Authenticator().get_access_token())
        subprocess2.check_call_out.assert_called_with([
            'luci-auth', 'token', '-scopes', auth.OAUTH_SCOPE_EMAIL,
            '-json-output', '-'
        ],
                                                      stdout=subprocess2.PIPE,
                                                      stderr=subprocess2.PIPE)

    def testGetAccessToken_DifferentScope(self):
        expiry = calendar.timegm(VALID_EXPIRY.timetuple())
        subprocess2.check_call_out.return_value = (json.dumps({
            'token': 'token',
            'expiry': expiry
        }), '')
        self.assertEqual(auth.Token('token', VALID_EXPIRY),
                         auth.Authenticator('custom scopes').get_access_token())
        subprocess2.check_call_out.assert_called_with([
            'luci-auth', 'token', '-scopes', 'custom scopes', '-json-output',
            '-'
        ],
                                                      stdout=subprocess2.PIPE,
                                                      stderr=subprocess2.PIPE)

    def testAuthorize_AccessToken(self):
        http = mock.Mock()
        http_request = http.request
        http_request.__name__ = '__name__'

        authenticator = auth.Authenticator()
        authenticator._access_token = auth.Token('access_token', None)
        authenticator._id_token = auth.Token('id_token', None)

        authorized = authenticator.authorize(http)
        authorized.request('https://example.com',
                           method='POST',
                           body='body',
                           headers={'header': 'value'})
        http_request.assert_called_once_with(
            'https://example.com', 'POST', 'body', {
                'header': 'value',
                'Authorization': 'Bearer access_token'
            }, mock.ANY, mock.ANY)

    def testGetIdToken_NotLoggedIn(self):
        subprocess2.check_call_out.side_effect = [
            subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout',
                                           'stderr')
        ]
        self.assertRaises(auth.LoginRequiredError,
                          auth.Authenticator().get_id_token)

    def testGetIdToken_CachedToken(self):
        authenticator = auth.Authenticator()
        authenticator._id_token = auth.Token('token', None)
        self.assertEqual(auth.Token('token', None),
                         authenticator.get_id_token())
        subprocess2.check_call_out.assert_not_called()

    def testGetIdToken_LoggedIn(self):
        expiry = calendar.timegm(VALID_EXPIRY.timetuple())
        subprocess2.check_call_out.return_value = (json.dumps({
            'token': 'token',
            'expiry': expiry
        }), '')
        self.assertEqual(
            auth.Token('token', VALID_EXPIRY),
            auth.Authenticator(audience='https://test.com').get_id_token())
        subprocess2.check_call_out.assert_called_with([
            'luci-auth', 'token', '-use-id-token', '-audience',
            'https://test.com', '-json-output', '-'
        ],
                                                      stdout=subprocess2.PIPE,
                                                      stderr=subprocess2.PIPE)

    def testAuthorize_IdToken(self):
        http = mock.Mock()
        http_request = http.request
        http_request.__name__ = '__name__'

        authenticator = auth.Authenticator()
        authenticator._access_token = auth.Token('access_token', None)
        authenticator._id_token = auth.Token('id_token', None)

        authorized = authenticator.authorize(http, use_id_token=True)
        authorized.request('https://example.com',
                           method='POST',
                           body='body',
                           headers={'header': 'value'})
        http_request.assert_called_once_with(
            'https://example.com', 'POST', 'body', {
                'header': 'value',
                'Authorization': 'Bearer id_token'
            }, mock.ANY, mock.ANY)


class TokenTest(unittest.TestCase):
    def setUp(self):
        mock.patch('auth.datetime_now', return_value=NOW).start()
        self.addCleanup(mock.patch.stopall)

    def testNeedsRefresh_NoExpiry(self):
        self.assertFalse(auth.Token('token', None).needs_refresh())

    def testNeedsRefresh_Expired(self):
        expired = NOW + datetime.timedelta(seconds=30)
        self.assertTrue(auth.Token('token', expired).needs_refresh())

    def testNeedsRefresh_Valid(self):
        self.assertFalse(auth.Token('token', VALID_EXPIRY).needs_refresh())


class HasLuciContextLocalAuthTest(unittest.TestCase):
    def setUp(self):
        mock.patch('os.environ').start()
        mock.patch('builtins.open', mock.mock_open()).start()
        self.addCleanup(mock.patch.stopall)

    def testNoLuciContextEnvVar(self):
        os.environ = {}
        self.assertFalse(auth.has_luci_context_local_auth())

    def testNonexistentPath(self):
        os.environ = {'LUCI_CONTEXT': 'path'}
        open.side_effect = OSError
        self.assertFalse(auth.has_luci_context_local_auth())
        open.assert_called_with('path')

    def testInvalidJsonFile(self):
        os.environ = {'LUCI_CONTEXT': 'path'}
        open().read.return_value = 'not-a-json-file'
        self.assertFalse(auth.has_luci_context_local_auth())
        open.assert_called_with('path')

    def testNoLocalAuth(self):
        os.environ = {'LUCI_CONTEXT': 'path'}
        open().read.return_value = '{}'
        self.assertFalse(auth.has_luci_context_local_auth())
        open.assert_called_with('path')

    def testNoDefaultAccountId(self):
        os.environ = {'LUCI_CONTEXT': 'path'}
        open().read.return_value = json.dumps({
            'local_auth': {
                'secret':
                'secret',
                'accounts': [{
                    'email': 'bots@account.iam.gserviceaccount.com',
                    'id': 'system',
                }],
                'rpc_port':
                1234,
            }
        })
        self.assertFalse(auth.has_luci_context_local_auth())
        open.assert_called_with('path')

    def testHasLocalAuth(self):
        os.environ = {'LUCI_CONTEXT': 'path'}
        open().read.return_value = json.dumps({
            'local_auth': {
                'secret':
                'secret',
                'accounts': [
                    {
                        'email': 'bots@account.iam.gserviceaccount.com',
                        'id': 'system',
                    },
                    {
                        'email': 'builder@account.iam.gserviceaccount.com',
                        'id': 'task',
                    },
                ],
                'rpc_port':
                1234,
                'default_account_id':
                'task',
            },
        })
        self.assertTrue(auth.has_luci_context_local_auth())
        open.assert_called_with('path')


if __name__ == '__main__':
    if '-v' in sys.argv:
        logging.basicConfig(level=logging.DEBUG)
    unittest.main()