# Copyright 2010 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for oauth2_client."""

import datetime
import logging
import os
import sys
import unittest
import urllib2
import urlparse
from stat import S_IMODE
from StringIO import StringIO

# Directory that holds this test module.  __file__ is preferred over
# sys.argv[0]: when the tests are loaded by a test runner, argv[0] names the
# runner rather than this file and the paths below would point at the wrong
# place.  sys.argv[0] is kept as a fallback for contexts where __file__ is
# not defined.
test_bin_dir = os.path.dirname(
    os.path.realpath(globals().get('__file__', sys.argv[0])))

# Put the parent directory at the front of the import path.
lib_dir = os.path.join(test_bin_dir, '..')
sys.path.insert(0, lib_dir)

# Needed for boto.cacerts
boto_lib_dir = os.path.join(test_bin_dir, '..', 'boto')
sys.path.insert(0, boto_lib_dir)

import oauth2_client

# Logger used by the tests in this module for debug output.
LOG = logging.getLogger('oauth2_client_test')

class MockOpener:
  """URL opener test double that records requests and replays a canned reply.

  Tests set and inspect these fields directly:
    open_error: when not None, open() raises it instead of returning.
    open_result: value wrapped in a StringIO and returned by open().
    open_capture_url: full URL of the most recent request given to open().
    open_capture_data: data payload of the most recent request given to open().
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Clears both the canned reply and everything captured so far."""
    self.open_error = self.open_result = None
    self.open_capture_url = self.open_capture_data = None

  def open(self, req, data=None):
    """Records the request, then raises open_error or returns open_result.

    Args:
      req: request object exposing get_full_url() and get_data().
      data: accepted so the signature matches callers; never read.

    Returns:
      A StringIO wrapping open_result.

    Raises:
      The object stored in open_error, when it is not None.
    """
    # Capture first so the request can be inspected even on the error path.
    self.open_capture_url = req.get_full_url()
    self.open_capture_data = req.get_data()
    if self.open_error is None:
      return StringIO(self.open_result)
    raise self.open_error


class MockDateTime:
  """Clock test double: utcnow() returns whatever a test stored in mock_now."""

  def __init__(self):
    # No time is configured until a test assigns one.
    self.mock_now = None

  def utcnow(self):
    """Returns the value currently held in mock_now."""
    current = self.mock_now
    return current


class OAuth2ClientTest(unittest.TestCase):
  """Tests OAuth2Client against a MockOpener and a MockDateTime clock."""

  def setUp(self):
    """Builds a client wired to a mock opener and a clock fixed at start_time."""
    self.opener = MockOpener()
    self.mock_datetime = MockDateTime()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.mock_datetime.mock_now = self.start_time
    self.client = oauth2_client.OAuth2Client(
        oauth2_client.OAuth2Provider(
            'Sample OAuth Provider',
            'https://provider.example.com/oauth/provider?mode=authorize',
            'https://provider.example.com/oauth/provider?mode=token'),
        'clid', 'clsecret',
        url_opener=self.opener, datetime_strategy=self.mock_datetime)

  def _CapturedPostParams(self):
    """Returns the form fields of the last request body seen by the opener."""
    return urlparse.parse_qs(
        self.opener.open_capture_data, keep_blank_values=True,
        strict_parsing=True)

  def _AssertRefreshRequest(self, refresh_token):
    """Asserts that the last captured request was a refresh-token grant.

    Args:
      refresh_token: the refresh token string expected in the request body.
    """
    self.assertEqual({
        'grant_type': ['refresh_token'],
        'client_id': ['clid'],
        'client_secret': ['clsecret'],
        'refresh_token': [refresh_token]},
        self._CapturedPostParams())

  def testFetchAccessToken(self):
    """FetchAccessToken posts a refresh grant and parses token and expiry."""
    refresh_token = '1/ZaBrxdPl77Bi4jbsO7x-NmATiaQZnWPB51nTvo8n9Sw'
    access_token = '1/aalskfja-asjwerwj'
    self.opener.open_result = (
        '{"access_token":"%s","expires_in":3600}' % access_token)
    cred = oauth2_client.RefreshToken(self.client, refresh_token)
    token = self.client.FetchAccessToken(cred)

    self.assertEqual(
        self.opener.open_capture_url,
        'https://provider.example.com/oauth/provider?mode=token')
    self._AssertRefreshRequest(refresh_token)
    self.assertEqual(access_token, token.token)
    # expires_in of 3600 seconds puts the expiry one hour after start_time.
    self.assertEqual(
        datetime.datetime(2011, 3, 1, 11, 25, 13, 300826),
        token.expiry)

  def testFetchAccessTokenFailsForBadJsonResponse(self):
    """A response body that is not JSON raises AccessTokenRefreshError."""
    self.opener.open_result = 'blah'
    cred = oauth2_client.RefreshToken(self.client, 'abc123')
    self.assertRaises(
        oauth2_client.AccessTokenRefreshError,
        self.client.FetchAccessToken, cred)

  def testFetchAccessTokenFailsForErrorResponse(self):
    """An HTTP 400 carrying a JSON error body raises AccessTokenRefreshError."""
    self.opener.open_error = urllib2.HTTPError(
        None, 400, 'Bad Request', None, StringIO('{"error": "invalid token"}'))
    cred = oauth2_client.RefreshToken(self.client, 'abc123')
    self.assertRaises(
        oauth2_client.AccessTokenRefreshError,
        self.client.FetchAccessToken, cred)

  def testFetchAccessTokenFailsForHttpError(self):
    """An HTTP 400 with no response body raises AccessTokenRefreshError."""
    # The error has to go in open_error so that MockOpener.open() raises it.
    # Stored in open_result it would merely be handed back as the response
    # body, and this test would duplicate the bad-JSON case above.
    self.opener.open_error = urllib2.HTTPError(
        'foo', 400, 'Bad Request', None, None)
    cred = oauth2_client.RefreshToken(self.client, 'abc123')
    self.assertRaises(
        oauth2_client.AccessTokenRefreshError,
        self.client.FetchAccessToken, cred)

  def testGetAccessToken(self):
    """GetAccessToken fetches, then serves from cache, then refreshes."""
    refresh_token = 'ref_token'
    access_token_1 = 'abc123'
    self.opener.open_result = (
        '{"access_token":"%s","expires_in":3600}' % access_token_1)
    cred = oauth2_client.RefreshToken(self.client, refresh_token)

    token_1 = self.client.GetAccessToken(cred)

    # There's no access token in the cache; verify that we fetched a fresh
    # token.
    self._AssertRefreshRequest(refresh_token)
    self.assertEqual(access_token_1, token_1.token)
    self.assertEqual(self.start_time + datetime.timedelta(minutes=60),
                     token_1.expiry)

    # Advance time by less than expiry time, and fetch another token.
    self.opener.reset()
    self.mock_datetime.mock_now = (
        self.start_time + datetime.timedelta(minutes=55))
    token_2 = self.client.GetAccessToken(cred)

    # Since the access token wasn't expired, we get the cache token, and there
    # was no refresh request.
    self.assertEqual(token_1, token_2)
    self.assertEqual(access_token_1, token_2.token)
    self.assertEqual(None, self.opener.open_capture_url)
    self.assertEqual(None, self.opener.open_capture_data)

    # Advance time past expiry time, and fetch another token.
    self.opener.reset()
    self.mock_datetime.mock_now = (
        self.start_time + datetime.timedelta(minutes=55, seconds=1))
    access_token_2 = 'zyx456'
    self.opener.open_result = (
        '{"access_token":"%s","expires_in":3600}' % access_token_2)
    token_3 = self.client.GetAccessToken(cred)

    # This should have resulted in a refresh request and a fresh access token.
    self._AssertRefreshRequest(refresh_token)
    self.assertEqual(access_token_2, token_3.token)
    self.assertEqual(
        self.mock_datetime.mock_now + datetime.timedelta(minutes=60),
        token_3.expiry)

  def testGetAuthorizationUri(self):
    """GetAuthorizationUri keeps the provider URL and adds encoded params."""
    authn_uri = self.client.GetAuthorizationUri(
        'https://www.example.com/oauth/redir?mode=approve%20me',
        ('scope_foo', 'scope_bar'),
        {'state': 'this and that & sundry'})

    uri_parts = urlparse.urlsplit(authn_uri)
    self.assertEqual(('https', 'provider.example.com', '/oauth/provider'),
                     uri_parts[:3])

    # The provider's own 'mode' parameter must survive next to the added ones,
    # and values containing spaces and '&' must round-trip through the query.
    self.assertEqual({
        'response_type': ['code'],
        'client_id': ['clid'],
        'redirect_uri':
            ['https://www.example.com/oauth/redir?mode=approve%20me'],
        'scope': ['scope_foo scope_bar'],
        'state': ['this and that & sundry'],
        'mode': ['authorize']},
        urlparse.parse_qs(uri_parts[3]))

  def testExchangeAuthorizationCode(self):
    """ExchangeAuthorizationCode returns both tokens and caches the access one."""
    code = 'codeABQ1234'
    exp_refresh_token = 'ref_token42'
    exp_access_token = 'access_tokenXY123'
    self.opener.open_result = (
        '{"access_token":"%s","expires_in":3600,"refresh_token":"%s"}'
        % (exp_access_token, exp_refresh_token))

    refresh_token, access_token = self.client.ExchangeAuthorizationCode(
        code, 'urn:ietf:wg:oauth:2.0:oob', ('scope1', 'scope2'))

    self.assertEqual({
        'grant_type': ['authorization_code'],
        'client_id': ['clid'],
        'client_secret': ['clsecret'],
        'code': [code],
        'redirect_uri': ['urn:ietf:wg:oauth:2.0:oob'],
        'scope': ['scope1 scope2']},
        self._CapturedPostParams())
    self.assertEqual(exp_access_token, access_token.token)
    self.assertEqual(self.start_time + datetime.timedelta(minutes=60),
                     access_token.expiry)

    self.assertEqual(self.client, refresh_token.oauth2_client)
    self.assertEqual(exp_refresh_token, refresh_token.refresh_token)

    # Check that the access token was put in the cache.
    cached_token = self.client.access_token_cache.GetToken(
        refresh_token.CacheKey())
    self.assertEqual(access_token, cached_token)


class AccessTokenTest(unittest.TestCase):
  """Tests AccessToken refresh timing and serialization round-tripping."""

  def testShouldRefresh(self):
    """ShouldRefresh flips to True only once the refresh window is entered."""
    mock_datetime = MockDateTime()
    start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    expiry = start + datetime.timedelta(minutes=60)
    token = oauth2_client.AccessToken(
        'foo', expiry, datetime_strategy=mock_datetime)

    mock_datetime.mock_now = start
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(minutes=54)
    self.assertFalse(token.ShouldRefresh())

    # With no time_delta argument the boundary sits exactly 5 minutes before
    # expiry: still fresh at 55:00, due for refresh at 55:01.
    mock_datetime.mock_now = start + datetime.timedelta(minutes=55)
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=55, seconds=1)
    self.assertTrue(token.ShouldRefresh())

    # A token that is already past its expiry must be refreshed as well.
    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=61)
    self.assertTrue(token.ShouldRefresh())

    # time_delta=120 moves the boundary to 2 minutes before expiry.
    mock_datetime.mock_now = start + datetime.timedelta(minutes=58)
    self.assertFalse(token.ShouldRefresh(time_delta=120))

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=58, seconds=1)
    self.assertTrue(token.ShouldRefresh(time_delta=120))

  def testShouldRefreshNoExpiry(self):
    """A token created with expiry=None never asks to be refreshed."""
    mock_datetime = MockDateTime()
    start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    token = oauth2_client.AccessToken(
        'foo', None, datetime_strategy=mock_datetime)

    mock_datetime.mock_now = start
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=472)
    self.assertFalse(token.ShouldRefresh())

  def testSerialization(self):
    """Serialize followed by UnSerialize yields an equal token."""
    expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    token = oauth2_client.AccessToken('foo', expiry)
    serialized_token = token.Serialize()
    # Pass the value as a logging argument so the message is only formatted
    # when DEBUG output is actually enabled.
    LOG.debug('testSerialization: serialized_token=%s', serialized_token)

    token2 = oauth2_client.AccessToken.UnSerialize(serialized_token)
    self.assertEqual(token, token2)


class RefreshTokenTest(unittest.TestCase):
  """Tests RefreshToken against a client backed by mock opener and clock."""

  def setUp(self):
    """Builds the mock-backed client and a RefreshToken credential for it."""
    self.opener = MockOpener()
    self.mock_datetime = MockDateTime()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.mock_datetime.mock_now = self.start_time
    self.client = oauth2_client.OAuth2Client(
        oauth2_client.OAuth2Provider(
            'Sample OAuth Provider',
            'https://provider.example.com/oauth/provider?mode=authorize',
            'https://provider.example.com/oauth/provider?mode=token'),
        'clid', 'clsecret',
        url_opener=self.opener, datetime_strategy=self.mock_datetime)

    self.cred = oauth2_client.RefreshToken(self.client, 'ref_token_abc123')

  # NOTE: the method name is misspelled ("Uniqe"); it is left unchanged so
  # that existing test selectors keep matching.
  def testUniqeId(self):
    """CacheKey() returns the expected fixed hex string for this credential."""
    cred_id = self.cred.CacheKey()
    self.assertEqual('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id)

  def testGetAuthorizationHeader(self):
    """GetAuthorizationHeader returns 'Bearer ' plus the fetched access token."""
    access_token = 'access_123'
    self.opener.open_result = (
        '{"access_token":"%s","expires_in":3600}' % access_token)

    self.assertEqual('Bearer %s' % access_token,
                     self.cred.GetAuthorizationHeader())


class FileSystemTokenCacheTest(unittest.TestCase):
  """Tests FileSystemTokenCache storage, permissions and file naming."""

  def setUp(self):
    """Creates a default cache plus two distinct tokens and one cache key."""
    self.cache = oauth2_client.FileSystemTokenCache()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.token_1 = oauth2_client.AccessToken('token1', self.start_time)
    self.token_2 = oauth2_client.AccessToken(
        'token2', self.start_time + datetime.timedelta(seconds=492))
    self.key = 'token1key'

  def tearDown(self):
    """Removes the cache file for self.key if a test left one behind."""
    try:
      os.unlink(self.cache.CacheFileName(self.key))
    except OSError:
      # Best-effort cleanup: the file is legitimately absent when a test never
      # wrote it.  Only OSError is ignored so that unrelated failures (and
      # KeyboardInterrupt) are no longer swallowed.
      pass

  def testPut(self):
    """PutToken writes a cache file readable and writable by the owner only."""
    self.cache.PutToken(self.key, self.token_1)
    # Assert that the cache file exists and has correct permissions.
    self.assertEqual(
        0o600, S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode))

  def testPutGet(self):
    """GetToken returns None when empty, then the most recently put token."""
    # No cache file present.
    self.assertEqual(None, self.cache.GetToken(self.key))

    # Put a token
    self.cache.PutToken(self.key, self.token_1)
    cached_token = self.cache.GetToken(self.key)
    self.assertEqual(self.token_1, cached_token)

    # Put a different token
    self.cache.PutToken(self.key, self.token_2)
    cached_token = self.cache.GetToken(self.key)
    self.assertEqual(self.token_2, cached_token)

  def testGetBadFile(self):
    """GetToken returns None when the cache file holds unparseable content."""
    # The with-statement closes the file even if the write raises.
    with open(self.cache.CacheFileName(self.key), 'w') as f:
      f.write('blah')
    self.assertEqual(None, self.cache.GetToken(self.key))

  def testCacheFileName(self):
    """CacheFileName fills in the %(uid)s and %(key)s pattern placeholders."""
    cache = oauth2_client.FileSystemTokenCache(
        path_pattern='/var/run/ccache/token.%(uid)s.%(key)s')
    self.assertEqual('/var/run/ccache/token.%d.abc123' % os.getuid(),
                     cache.CacheFileName('abc123'))

    cache = oauth2_client.FileSystemTokenCache(
        path_pattern='/var/run/ccache/token.%(key)s')
    self.assertEqual('/var/run/ccache/token.abc123',
                     cache.CacheFileName('abc123'))


if __name__ == '__main__':
  # When executed as a script, enable DEBUG-level logging and run the test
  # cases defined in this module.
  logging.basicConfig(level=logging.DEBUG)
  unittest.main()