#!/usr/bin/env vpython3
# Copyright (c) 2017 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit Tests for auth.py"""

import calendar
import datetime
import json
import os
import unittest
import sys

if sys.version_info.major == 2:
  import mock
  BUILTIN_OPEN = '__builtin__.open'
else:
  from unittest import mock
  BUILTIN_OPEN = 'builtins.open'

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import auth
import subprocess2


# Fixed "current time" returned by the patched auth.datetime_now in the tests
# below, so expiry comparisons are deterministic.
NOW = datetime.datetime(
    year=2019, month=10, day=17, hour=12, minute=30, second=59, microsecond=0)
# An expiry 31 seconds after NOW. AccessTokenTest expects a token expiring
# then to be usable, while one expiring only 30 seconds after NOW needs a
# refresh.
VALID_EXPIRY = NOW + datetime.timedelta(seconds=31)


class AuthenticatorTest(unittest.TestCase):
  """Tests for auth.Authenticator with subprocess2 calls mocked out."""

  def setUp(self):
    """Patches subprocess2 entry points and freezes auth.datetime_now at NOW."""
    # Register the cleanup before starting any patch: cleanups still run when
    # setUp raises, so a failure partway through cannot leak already-started
    # patches into later tests.
    self.addCleanup(mock.patch.stopall)
    mock.patch('subprocess2.check_call').start()
    mock.patch('subprocess2.check_call_out').start()
    mock.patch('auth.datetime_now', return_value=NOW).start()

  def testHasCachedCredentials_NotLoggedIn(self):
    # A failing subprocess call is expected to be reported as "no cached
    # credentials" rather than propagating the error.
    subprocess2.check_call_out.side_effect = [
        subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', 'stderr')]
    self.assertFalse(auth.Authenticator().has_cached_credentials())

  def testHasCachedCredentials_LoggedIn(self):
    # check_call_out is mocked to return an (stdout, stderr) pair whose stdout
    # is a JSON document carrying a token and an expiry.
    subprocess2.check_call_out.return_value = (
        json.dumps({'token': 'token', 'expiry': 12345678}), '')
    self.assertTrue(auth.Authenticator().has_cached_credentials())

  def testGetAccessToken_NotLoggedIn(self):
    # When the subprocess call fails, get_access_token must raise
    # LoginRequiredError.
    subprocess2.check_call_out.side_effect = [
        subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', 'stderr')]
    self.assertRaises(
        auth.LoginRequiredError, auth.Authenticator().get_access_token)

  def testGetAccessToken_CachedToken(self):
    # A token already stored on the authenticator must be returned as-is,
    # without invoking any subprocess.
    authenticator = auth.Authenticator()
    authenticator._access_token = auth.AccessToken('token', None)
    self.assertEqual(
        auth.AccessToken('token', None), authenticator.get_access_token())
    subprocess2.check_call_out.assert_not_called()

  def testGetAccesstoken_LoggedIn(self):
    # The mocked output encodes the expiry as a UTC epoch timestamp; the
    # resulting AccessToken is expected to carry the equivalent datetime.
    expiry = calendar.timegm(VALID_EXPIRY.timetuple())
    subprocess2.check_call_out.return_value = (
        json.dumps({'token': 'token', 'expiry': expiry}), '')
    self.assertEqual(
        auth.AccessToken('token', VALID_EXPIRY),
        auth.Authenticator().get_access_token())
    # With no explicit scopes, the command line must use OAUTH_SCOPE_EMAIL.
    subprocess2.check_call_out.assert_called_with(
        ['luci-auth',
         'token',
         '-scopes', auth.OAUTH_SCOPE_EMAIL,
         '-json-output', '-'],
        stdout=subprocess2.PIPE, stderr=subprocess2.PIPE)

  def testGetAccessToken_DifferentScope(self):
    # Scopes passed to the constructor must be forwarded verbatim to the
    # luci-auth command line.
    expiry = calendar.timegm(VALID_EXPIRY.timetuple())
    subprocess2.check_call_out.return_value = (
        json.dumps({'token': 'token', 'expiry': expiry}), '')
    self.assertEqual(
        auth.AccessToken('token', VALID_EXPIRY),
        auth.Authenticator('custom scopes').get_access_token())
    subprocess2.check_call_out.assert_called_with(
        ['luci-auth', 'token', '-scopes', 'custom scopes', '-json-output', '-'],
        stdout=subprocess2.PIPE, stderr=subprocess2.PIPE)

  def testAuthorize(self):
    http = mock.Mock()
    http_request = http.request
    # Mock objects have no __name__ by default; presumably authorize() wraps
    # the request callable in a way that reads it -- confirm against auth.py.
    http_request.__name__ = '__name__'

    authenticator = auth.Authenticator()
    authenticator._access_token = auth.AccessToken('token', None)

    authorized = authenticator.authorize(http)
    authorized.request(
        'https://example.com', method='POST', body='body',
        headers={'header': 'value'})
    # The original request callable must receive the caller's arguments plus
    # an Authorization header built from the cached token.
    http_request.assert_called_once_with(
        'https://example.com', 'POST', 'body',
        {'header': 'value', 'Authorization': 'Bearer token'}, mock.ANY,
        mock.ANY)


class AccessTokenTest(unittest.TestCase):
  """Checks AccessToken.needs_refresh() with auth.datetime_now frozen at NOW."""

  def setUp(self):
    # Freeze the clock used by auth; all patches are undone after each test.
    clock_patch = mock.patch('auth.datetime_now', return_value=NOW)
    clock_patch.start()
    self.addCleanup(mock.patch.stopall)

  def testNeedsRefresh_NoExpiry(self):
    # A token created with no expiry is expected to report no refresh needed.
    token = auth.AccessToken('token', None)
    self.assertFalse(token.needs_refresh())

  def testNeedsRefresh_Expired(self):
    # An expiry only 30 seconds after the frozen NOW is expected to count as
    # needing a refresh.
    token = auth.AccessToken('token', NOW + datetime.timedelta(seconds=30))
    self.assertTrue(token.needs_refresh())

  def testNeedsRefresh_Valid(self):
    # VALID_EXPIRY is expected to be far enough out that no refresh is needed.
    token = auth.AccessToken('token', VALID_EXPIRY)
    self.assertFalse(token.needs_refresh())


class HasLuciContextLocalAuthTest(unittest.TestCase):
  """Tests for auth.has_luci_context_local_auth with env and open() mocked."""

  def setUp(self):
    """Patches os.environ and the builtin open for the duration of a test."""
    # Register the cleanup before starting any patch: cleanups still run when
    # setUp raises, so a failure partway through cannot leave os.environ or
    # the builtin open patched for later tests.
    self.addCleanup(mock.patch.stopall)
    mock.patch('os.environ').start()
    mock.patch(BUILTIN_OPEN, mock.mock_open()).start()

  # In the tests below, `open` resolves to the mock installed in setUp, and
  # each test replaces os.environ with a plain dict; stopping the patch
  # restores the real os.environ afterwards.

  def testNoLuciContextEnvVar(self):
    # Without a LUCI_CONTEXT variable there is no local auth.
    os.environ = {}
    self.assertFalse(auth.has_luci_context_local_auth())

  def testNonexistentPath(self):
    # Failing to open the context file must be reported as "no local auth".
    os.environ = {'LUCI_CONTEXT': 'path'}
    open.side_effect = OSError
    self.assertFalse(auth.has_luci_context_local_auth())
    open.assert_called_with('path')

  def testInvalidJsonFile(self):
    # Unparseable file contents must be reported as "no local auth".
    os.environ = {'LUCI_CONTEXT': 'path'}
    open().read.return_value = 'not-a-json-file'
    self.assertFalse(auth.has_luci_context_local_auth())
    open.assert_called_with('path')

  def testNoLocalAuth(self):
    # Valid JSON that lacks a 'local_auth' section is not enough.
    os.environ = {'LUCI_CONTEXT': 'path'}
    open().read.return_value = '{}'
    self.assertFalse(auth.has_luci_context_local_auth())
    open.assert_called_with('path')

  def testNoDefaultAccountId(self):
    # A 'local_auth' section without 'default_account_id' is not enough.
    os.environ = {'LUCI_CONTEXT': 'path'}
    open().read.return_value = json.dumps({
        'local_auth': {
            'secret': 'secret',
            'accounts': [{
                'email': 'bots@account.iam.gserviceaccount.com',
                'id': 'system',
            }],
            'rpc_port': 1234,
        }
    })
    self.assertFalse(auth.has_luci_context_local_auth())
    open.assert_called_with('path')

  def testHasLocalAuth(self):
    # A 'local_auth' section that names a default account is accepted.
    os.environ = {'LUCI_CONTEXT': 'path'}
    open().read.return_value = json.dumps({
        'local_auth': {
            'secret': 'secret',
            'accounts': [
                {
                    'email': 'bots@account.iam.gserviceaccount.com',
                    'id': 'system',
                },
                {
                    'email': 'builder@account.iam.gserviceaccount.com',
                    'id': 'task',
                },
            ],
            'rpc_port': 1234,
            'default_account_id': 'task',
        },
    })
    self.assertTrue(auth.has_luci_context_local_auth())
    open.assert_called_with('path')


if __name__ == '__main__':
  # -v additionally enables debug logging; unittest.main() parses sys.argv
  # itself, so the flag still selects verbose test output.
  if '-v' in sys.argv:
    # Imported locally: the module's top-level imports do not include logging,
    # so referencing it here would otherwise raise NameError.
    import logging
    logging.basicConfig(level=logging.DEBUG)
  unittest.main()