You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			1644 lines
		
	
	
		
			45 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			1644 lines
		
	
	
		
			45 KiB
		
	
	
	
		
			Python
		
	
#!/usr/bin/python2.4
 | 
						|
#
 | 
						|
# Copyright 2008 Google Inc.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
"""Mox, an object-mocking framework for Python.
 | 
						|
 | 
						|
Mox works in the record-replay-verify paradigm.  When you first create
 | 
						|
a mock object, it is in record mode.  You then programmatically set
 | 
						|
the expected behavior of the mock object (what methods are to be
 | 
						|
called on it, with what parameters, what they should return, and in
 | 
						|
what order).
 | 
						|
 | 
						|
Once you have set up the expected mock behavior, you put it in replay
 | 
						|
mode.  Now the mock responds to method calls just as you told it to.
 | 
						|
If an unexpected method (or an expected method with unexpected
 | 
						|
parameters) is called, then an exception will be raised.
 | 
						|
 | 
						|
Once you are done interacting with the mock, you need to verify that
 | 
						|
all the expected interactions occurred.  (Maybe your code exited
 | 
						|
prematurely without calling some cleanup method!)  The verify phase
 | 
						|
ensures that every expected method was called; otherwise, an exception
 | 
						|
will be raised.
 | 
						|
 | 
						|
WARNING! Mock objects created by Mox are not thread-safe.  If you
 | 
						|
call a mock in multiple threads, it should be guarded by a mutex.
 | 
						|
 | 
						|
TODO(stevepm): Add the option to make mocks thread-safe!
 | 
						|
 | 
						|
Suggested usage / workflow:
 | 
						|
 | 
						|
  # Create Mox factory
 | 
						|
  my_mox = Mox()
 | 
						|
 | 
						|
  # Create a mock data access object
 | 
						|
  mock_dao = my_mox.CreateMock(DAOClass)
 | 
						|
 | 
						|
  # Set up expected behavior
 | 
						|
  mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
 | 
						|
  mock_dao.DeletePerson(person)
 | 
						|
 | 
						|
  # Put mocks in replay mode
 | 
						|
  my_mox.ReplayAll()
 | 
						|
 | 
						|
  # Inject mock object and run test
 | 
						|
  controller.SetDao(mock_dao)
 | 
						|
  controller.DeletePersonById('1')
 | 
						|
 | 
						|
  # Verify all methods were called as expected
 | 
						|
  my_mox.VerifyAll()
 | 
						|
"""
 | 
						|
 | 
						|
from collections import deque
 | 
						|
import difflib
 | 
						|
import inspect
 | 
						|
import re
 | 
						|
import types
 | 
						|
import unittest
 | 
						|
 | 
						|
import stubout
 | 
						|
 | 
						|
class Error(AssertionError):
  """Base exception for this module.

  Subclasses AssertionError, so code that catches AssertionError also catches
  the failures raised by this module's mocks.
  """
 | 
						|
 | 
						|
 | 
						|
class ExpectedMethodCallsError(Error):
  """Raised when Verify() runs while expected methods are still uncalled."""

  def __init__(self, expected_methods):
    """Init exception.

    Args:
      expected_methods: [MockMethod].  A non-empty sequence of the MockMethod
        objects that should have been called but were not.

    Raises:
      ValueError: if expected_methods contains no methods.
    """

    if expected_methods:
      Error.__init__(self)
      self._expected_methods = expected_methods
    else:
      raise ValueError("There must be at least one expected method")

  def __str__(self):
    """Return a numbered, one-per-line listing of the uncalled methods."""

    numbered = []
    for index, method in enumerate(self._expected_methods):
      numbered.append("%3d.  %s" % (index, method))
    listing = "\n".join(numbered)
    return "Verify: Expected methods never called:\n%s" % (listing,)
 | 
						|
 | 
						|
 | 
						|
class UnexpectedMethodCallError(Error):
  """Raised when an unexpected method is called.

  This can occur if a method is called with incorrect parameters, or out of the
  specified order.
  """

  def __init__(self, unexpected_method, expected):
    """Init exception.

    The message is built once, here, and returned unchanged by __str__.

    Args:
      unexpected_method: MockMethod that was called but was not at the head of
        the expected_method queue.
      expected: MockMethod or UnorderedGroup the method should have been in,
        or None when there is nothing to compare against.
    """

    Error.__init__(self)
    if expected is not None:
      # Diff the textual forms line by line so the mismatch is easy to spot.
      actual_lines = str(unexpected_method).splitlines(True)
      expected_lines = str(expected).splitlines(True)
      delta = difflib.Differ().compare(actual_lines, expected_lines)
      self._str = ("Unexpected method call.  unexpected:-  expected:+\n%s"
                   % ("\n".join(delta),))
    else:
      self._str = "Unexpected method call %s" % (unexpected_method,)

  def __str__(self):
    """Return the message prepared by __init__."""

    return self._str
 | 
						|
 | 
						|
 | 
						|
class UnknownMethodCallError(Error):
  """Raised if an unknown method is requested of the mock object."""

  def __init__(self, unknown_method_name):
    """Init exception.

    Args:
      unknown_method_name: str.  Name of the requested method that is not part
        of the mocked class's public interface.
    """

    Error.__init__(self)
    self._unknown_method_name = unknown_method_name

  def __str__(self):
    """Return a message naming the method that was not found."""

    return ("Method called is not a member of the object: %s"
            % self._unknown_method_name)
 | 
						|
 | 
						|
 | 
						|
class Mox(object):
  """Mox: a factory for creating mock objects.

  Every mock created through this factory is remembered so that ReplayAll,
  VerifyAll and ResetAll can act on all of them at once.  Stubs installed by
  StubOutWithMock are tracked in self.stubs and removed by UnsetStubs.
  """

  # A list of types that should be stubbed out with MockObjects (as
  # opposed to MockAnythings).
  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
                      types.ObjectType, types.TypeType]

  def __init__(self):
    """Initialize a new Mox."""

    # Every mock handed out by this factory, in creation order.
    self._mock_objects = []
    self.stubs = stubout.StubOutForTesting()

  def CreateMock(self, class_to_mock):
    """Create a new mock object.

    Args:
      class_to_mock: the class to be mocked.

    Returns:
      MockObject that can be used as the class_to_mock would be.
    """

    new_mock = MockObject(class_to_mock)
    self._mock_objects.append(new_mock)
    return new_mock

  def CreateMockAnything(self, description=None):
    """Create a mock that will accept any method calls.

    This does not enforce an interface.

    Args:
      description: str. Optionally, a descriptive name for the mock object being
        created, for debugging output purposes.

    Returns:
      The new MockAnything.
    """

    new_mock = MockAnything(description=description)
    self._mock_objects.append(new_mock)
    return new_mock

  def ReplayAll(self):
    """Set all mock objects to replay mode."""

    for mock_obj in self._mock_objects:
      mock_obj._Replay()

  def VerifyAll(self):
    """Call verify on all mock objects created."""

    for mock_obj in self._mock_objects:
      mock_obj._Verify()

  def ResetAll(self):
    """Call reset on all mock objects.  This does not unset stubs."""

    for mock_obj in self._mock_objects:
      mock_obj._Reset()

  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    """Replace a method, attribute, etc. with a Mock.

    This will replace a class or module with a MockObject, and everything else
    (method, function, etc) with a MockAnything.  This can be overridden to
    always use a MockAnything by setting use_mock_anything to True.

    Args:
      obj: A Python object (class, module, instance, callable).
      attr_name: str.  The name of the attribute to replace with a mock.
      use_mock_anything: bool. True if a MockAnything should be used regardless
        of the type of attribute.

    Raises:
      TypeError: if the attribute is already a mock (a MockAnything or a
        MockObject), which usually means a previous test did not call
        UnsetStubs.
    """

    attr_to_replace = getattr(obj, attr_name)

    # Check for an existing mock. This could cause confusing problems later on.
    # An isinstance test is used rather than comparing against a fresh
    # MockAnything(): equality only matches a mock that is still in record mode
    # with no expectations, so it would miss a leftover stub that has recorded
    # calls or is in replay mode, and would never match a MockObject.
    if isinstance(attr_to_replace, MockAnything):
      raise TypeError('Cannot mock a MockAnything! Did you remember to '
                      'call UnsetStubs in your previous test?')

    if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
      stub = self.CreateMock(attr_to_replace)
    else:
      stub = self.CreateMockAnything(description='Stub for %s' % attr_to_replace)

    self.stubs.Set(obj, attr_name, stub)

  def UnsetStubs(self):
    """Restore stubs to their original state."""

    self.stubs.UnsetAll()
 | 
						|
 | 
						|
def Replay(*args):
  """Switch the given mocks into replay mode.

  Args:
    *args: any number of mocks; _Replay() is called on each one, in the order
      given.
  """

  for mock_obj in args:
    mock_obj._Replay()
 | 
						|
 | 
						|
 | 
						|
def Verify(*args):
  """Verify the given mocks.

  Args:
    *args: any number of mocks; _Verify() is called on each one, in the order
      given.
  """

  for mock_obj in args:
    mock_obj._Verify()
 | 
						|
 | 
						|
 | 
						|
def Reset(*args):
  """Reset the given mocks.

  Args:
    *args: any number of mocks; _Reset() is called on each one, in the order
      given.
  """

  for mock_obj in args:
    mock_obj._Reset()
 | 
						|
 | 
						|
 | 
						|
class MockAnything:
  """A mock that can be used to mock anything.

  This is helpful for mocking classes that do not provide a public interface.
  Any attribute that is not found on the instance is answered with a
  MockMethod bound to this mock's call queue and current mode.
  """

  def __init__(self, description=None):
    """Initialize a new MockAnything.

    Args:
      description: str. Optionally, a descriptive name for the mock object being
        created, for debugging output purposes.
    """
    self._description = description
    # Start out in record mode with nothing expected.
    self._Reset()

  def __str__(self):
    """Return a short tag that includes this instance's id()."""
    return "<MockAnything instance at %s>" % (id(self),)

  def __repr__(self):
    """Return a fixed tag that is the same for every instance."""
    return '<MockAnything instance>'

  def __getattr__(self, method_name):
    """Intercept method calls on this object.

    A new MockMethod is returned that is aware of the MockAnything's
    state (record or replay).  The call will be recorded or replayed
    by the MockMethod's __call__.

    Args:
      method_name: str.  The name of the method being called.

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    mock_method = self._CreateMockMethod(method_name)
    return mock_method

  def _CreateMockMethod(self, method_name, method_to_mock=None):
    """Create a new mock method call and return it.

    Args:
      method_name: str.  The name of the method being called.
      method_to_mock: a method object.  The actual method being mocked, used
        for introspection.

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    queue = self._expected_calls_queue
    in_replay = self._replay_mode
    return MockMethod(method_name, queue, in_replay,
                      method_to_mock=method_to_mock,
                      description=self._description)

  def __nonzero__(self):
    """Return 1 for nonzero so the mock can be used as a conditional."""

    return 1

  def __eq__(self, rhs):
    """Provide custom logic to compare objects.

    Two mocks are equal when both are MockAnythings in the same mode with
    equal expected-call queues.
    """

    if not isinstance(rhs, MockAnything):
      return False
    return (self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __ne__(self, rhs):
    """Provide custom logic to compare objects (negation of ==)."""

    is_equal = self == rhs
    return not is_equal

  def _Replay(self):
    """Start replaying expected method calls."""

    self._replay_mode = True

  def _Verify(self):
    """Verify that all of the expected calls have been made.

    Raises:
      ExpectedMethodCallsError: if there are still more method calls in the
        expected queue.
    """

    pending = self._expected_calls_queue

    # An empty queue means every expected call was made.
    if not pending:
      return

    # The last MultipleTimesGroup is not popped from the queue, so a queue
    # holding only one satisfied group still counts as fully verified.
    lone_satisfied_group = (len(pending) == 1 and
                            isinstance(pending[0], MultipleTimesGroup) and
                            pending[0].IsSatisfied())
    if not lone_satisfied_group:
      raise ExpectedMethodCallsError(pending)

  def _Reset(self):
    """Reset the state of this mock to record mode with an empty queue."""

    # Make sure we are in setup mode, not replay mode
    self._replay_mode = False

    # Maintain a list of method calls we are expecting
    self._expected_calls_queue = deque()
 | 
						|
 | 
						|
 | 
						|
class MockObject(MockAnything, object):
  """A mock object that simulates the public/protected interface of a class."""

  def __init__(self, class_to_mock):
    """Initialize a mock object.

    This determines the methods and properties of the class and stores them.

    Args:
      class_to_mock: the class to be mocked.
    """

    # This is used to hack around the mixin/inheritance of MockAnything, which
    # is not a proper object (it can be anything. :-)
    MockAnything.__dict__['__init__'](self)

    # Split everything dir() reports into callables (mocked as methods) and
    # non-callables (returned as-is by __getattr__).
    self._known_methods = set()
    self._known_vars = set()
    self._class_to_mock = class_to_mock
    for name in dir(class_to_mock):
      if callable(getattr(class_to_mock, name)):
        self._known_methods.add(name)
      else:
        self._known_vars.add(name)

  def __getattr__(self, name):
    """Intercept attribute request on this object.

    If the attribute is a public class variable, it will be returned and not
    recorded as a call.

    If the attribute is not a variable, it is handled like a method
    call. The method name is checked against the set of mockable
    methods, and a new MockMethod is returned that is aware of the
    MockObject's state (record or replay).  The call will be recorded
    or replayed by the MockMethod's __call__.

    Args:
      name: str.  The name of the attribute being requested.

    Returns:
      Either a class variable or a new MockMethod that is aware of the state
      of the mock (record or replay).

    Raises:
      UnknownMethodCallError if the MockObject does not mock the requested
          method.
    """

    if name in self._known_vars:
      return getattr(self._class_to_mock, name)

    if name in self._known_methods:
      return self._CreateMockMethod(
          name,
          method_to_mock=getattr(self._class_to_mock, name))

    raise UnknownMethodCallError(name)

  def __eq__(self, rhs):
    """Provide custom logic to compare objects.

    Equal only to another MockObject that mocks the same class, is in the
    same mode and has an equal expected-call queue.
    """

    return (isinstance(rhs, MockObject) and
            self._class_to_mock == rhs._class_to_mock and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __setitem__(self, key, value):
    """Provide custom logic for mocking classes that support item assignment.

    Args:
      key: Key to set the value for.
      value: Value to set.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __setitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not support item assignment.
      UnexpectedMethodCallError if the object does not expect the call to
        __setitem__.

    """
    # Verify the class supports item assignment.
    if '__setitem__' not in dir(self._class_to_mock):
      raise TypeError('object does not support item assignment')

    # If we are in replay mode then simply call the mock __setitem__ method.
    if self._replay_mode:
      return MockMethod('__setitem__', self._expected_calls_queue,
                        self._replay_mode)(key, value)

    # Otherwise, create a mock method __setitem__.
    return self._CreateMockMethod('__setitem__')(key, value)

  def __getitem__(self, key):
    """Provide custom logic for mocking classes that are subscriptable.

    Args:
      key: Key to return the value for.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __getitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class is not subscriptable.
      UnexpectedMethodCallError if the object does not expect the call to
        __getitem__.

    """
    # Verify the class is subscriptable.
    if '__getitem__' not in dir(self._class_to_mock):
      raise TypeError('unsubscriptable object')

    # If we are in replay mode then simply call the mock __getitem__ method.
    if self._replay_mode:
      return MockMethod('__getitem__', self._expected_calls_queue,
                        self._replay_mode)(key)

    # Otherwise, create a mock method __getitem__.
    return self._CreateMockMethod('__getitem__')(key)

  def __iter__(self):
    """Provide custom logic for mocking classes that are iterable.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __iter__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class is not iterable.
      UnexpectedMethodCallError if the object does not expect the call to
        __iter__.

    """
    methods = dir(self._class_to_mock)

    # Verify the class supports iteration.
    if '__iter__' not in methods:
      # If it doesn't have an iter method and we are in replay mode, then try
      # to iterate using subscripts.
      if '__getitem__' not in methods or not self._replay_mode:
        raise TypeError('not iterable object')
      else:
        # Replay __getitem__(0), __getitem__(1), ... until one of the replayed
        # calls raises IndexError, then iterate over what was collected.
        results = []
        index = 0
        try:
          while True:
            results.append(self[index])
            index += 1
        except IndexError:
          return iter(results)

    # If we are in replay mode then simply call the mock __iter__ method.
    if self._replay_mode:
      return MockMethod('__iter__', self._expected_calls_queue,
                        self._replay_mode)()

    # Otherwise, create a mock method __iter__.
    return self._CreateMockMethod('__iter__')()

  def __contains__(self, key):
    """Provide custom logic for mocking classes that contain items.

    Args:
      key: Key to look in container for.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __contains__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not implement __contains__
      UnexpectedMethodCallError if the object does not expect the call to
        __contains__.

    """
    # Verify the class implements __contains__.  dir() is used, as in the
    # other container methods, so that an inherited __contains__ is found and
    # so that a mocked object without a __dict__ gets the documented TypeError
    # instead of an AttributeError.
    if '__contains__' not in dir(self._class_to_mock):
      raise TypeError('unsubscriptable object')

    # If we are in replay mode then simply call the mock __contains__ method.
    if self._replay_mode:
      return MockMethod('__contains__', self._expected_calls_queue,
                        self._replay_mode)(key)

    # Otherwise, create a mock method __contains__.
    return self._CreateMockMethod('__contains__')(key)

  def __call__(self, *params, **named_params):
    """Provide custom logic for mocking classes that are callable.

    Args:
      *params: positional arguments of the call.
      **named_params: keyword arguments of the call.

    Returns:
      Whatever calling the '__call__' MockMethod with these arguments returns.

    Raises:
      TypeError if the underlying class has no __call__ attribute.
    """

    # Verify the class we are mocking is callable.
    is_callable = hasattr(self._class_to_mock, '__call__')
    if not is_callable:
      raise TypeError('Not callable')

    # Because the call is happening directly on this object instead of a method,
    # the call on the mock method is made right here
    mock_method = self._CreateMockMethod('__call__')
    return mock_method(*params, **named_params)

  @property
  def __class__(self):
    """Return the class that is being mocked."""

    return self._class_to_mock
 | 
						|
 | 
						|
 | 
						|
class MethodCallChecker(object):
  """Ensures that methods are called correctly."""

  # Per-argument states tracked while checking a single call.
  _NEEDED, _DEFAULT, _GIVEN = range(3)

  def __init__(self, method):
    """Creates a checker.

    Args:
      method: function.  A method to check.

    Raises:
      ValueError: method could not be inspected, so checks aren't possible.
        Some methods and functions like built-ins can't be inspected.
    """
    try:
      self._args, varargs, varkw, defaults = inspect.getargspec(method)
    except TypeError:
      raise ValueError('Could not get argument specification for %r'
                       % (method,))
    if inspect.ismethod(method):
      self._args = self._args[1:]  # Skip 'self'.
    self._method = method

    self._has_varargs = varargs is not None
    self._has_varkw = varkw is not None
    if defaults is None:
      self._required_args = self._args
      self._default_args = []
    else:
      # Defaults always belong to the trailing named arguments.
      self._required_args = self._args[:-len(defaults)]
      self._default_args = self._args[-len(defaults):]

  def _RecordArgumentGiven(self, arg_name, arg_status):
    """Mark an argument as being given.

    Args:
      arg_name: string.  The name of the argument to mark in arg_status.
      arg_status: dict.  Maps argument names to one of _NEEDED, _DEFAULT,
        _GIVEN.

    Raises:
      AttributeError: arg_name is already marked as _GIVEN.
    """
    if arg_status.get(arg_name, None) == MethodCallChecker._GIVEN:
      raise AttributeError('%s provided more than once' % (arg_name,))
    arg_status[arg_name] = MethodCallChecker._GIVEN

  def Check(self, params, named_params):
    """Ensures that the parameters used while recording a call are valid.

    Args:
      params: list.  The positional parameters.
      named_params: dict.  The named parameters.

    Raises:
      AttributeError: the given parameters don't work with the given method.
    """
    arg_status = dict((a, MethodCallChecker._NEEDED)
                      for a in self._required_args)
    for arg in self._default_args:
      arg_status[arg] = MethodCallChecker._DEFAULT

    # Check that each positional param is valid.  Positional params beyond the
    # named arguments are only acceptable when the method takes *varargs.
    num_named = len(self._args)
    if len(params) > num_named and not self._has_varargs:
      # num_named positional arguments are fine; num_named + 1 is the first
      # count the method cannot take.
      raise AttributeError('%s does not take %d or more positional '
                           'arguments' % (self._method.__name__,
                                          num_named + 1))
    for arg_name in self._args[:len(params)]:
      self._RecordArgumentGiven(arg_name, arg_status)

    # Check each keyword argument.
    for arg_name in named_params:
      if arg_name not in arg_status and not self._has_varkw:
        raise AttributeError('%s is not expecting keyword argument %s'
                             % (self._method.__name__, arg_name))
      self._RecordArgumentGiven(arg_name, arg_status)

    # Ensure all the required arguments have been given.
    still_needed = [name for name in arg_status
                    if arg_status[name] == MethodCallChecker._NEEDED]
    if still_needed:
      raise AttributeError('No values given for arguments %s'
                           % (' '.join(sorted(still_needed))))
 | 
						|
 | 
						|
 | 
						|
class MockMethod(object):
  """Callable that records, or replays, a single expected method call.

  In record mode, calling the object stores the arguments it was given and
  appends the object to the call queue.  In replay mode, calling it takes the
  next expectation off the queue, checks that it matches this call and then
  produces whatever was configured on that expectation: side effects first,
  then a raised exception, otherwise the return value.
  """

  def __init__(self, method_name, call_queue, replay_mode,
               method_to_mock=None, description=None):
    """Create a mock method.

    Args:
      method_name: str, the name of the method being mocked.
      call_queue: list or deque of expected calls.  Anything that is not
        already a deque is copied into a new deque.
      replay_mode: bool, False while expectations are being recorded, True
        while calls are verified against the queue.
      method_to_mock: the real method, handed to MethodCallChecker.  When the
        checker constructor raises ValueError no argument checking is done.
      description: str or None, optional prefix used by __str__ (typically the
        descriptive name of the method's class).
    """
    if isinstance(call_queue, deque):
      queue = call_queue
    else:
      queue = deque(call_queue)

    self._name = method_name
    self._call_queue = queue
    self._replay_mode = replay_mode
    self._description = description

    # Filled in later by __call__ and by AndReturn/AndRaise/WithSideEffects.
    self._params = None
    self._named_params = None
    self._return_value = None
    self._exception = None
    self._side_effects = None

    try:
      checker = MethodCallChecker(method_to_mock)
    except ValueError:
      checker = None
    self._checker = checker

  def __call__(self, *params, **named_params):
    """Record this call, or verify it and act out the expected behaviour.

    Record mode: the arguments are validated by the checker (when there is
    one), this object is appended to the call queue and returned.

    Replay mode: the matching expectation is looked up; its side effects are
    invoked with the call's arguments, its exception is raised if one was
    set, and otherwise its return value is returned.

    Raises:
      UnexpectedMethodCallError: in replay mode, when this call does not match
        the next expectation.
    """
    self._params = params
    self._named_params = named_params

    if self._replay_mode:
      matched = self._VerifyMethodCall()

      if matched._side_effects:
        matched._side_effects(*params, **named_params)

      if matched._exception:
        raise matched._exception

      return matched._return_value

    # Still recording: remember the call as an expectation.
    if self._checker is not None:
      self._checker.Check(params, named_params)
    self._call_queue.append(self)
    return self

  def __getattr__(self, name):
    """Reject unknown attributes with a hint about replay mode."""
    raise AttributeError('MockMethod has no attribute "%s". '
        'Did you remember to put your mocks in replay mode?' % name)

  def __iter__(self):
    """Reject iteration with a hint about replay mode."""
    raise TypeError('MockMethod cannot be iterated. '
                    'Did you remember to put your mocks in replay mode?')

  def next(self):
    """Reject iteration with a hint about replay mode."""
    raise TypeError('MockMethod cannot be iterated. '
                    'Did you remember to put your mocks in replay mode?')

  def _PopNextMethod(self):
    """Remove and return the entry at the head of the call queue.

    Raises:
      UnexpectedMethodCallError: the queue is empty.
    """
    queue = self._call_queue
    try:
      head = queue.popleft()
    except IndexError:
      raise UnexpectedMethodCallError(self, None)
    return head

  def _VerifyMethodCall(self):
    """Find the expectation that this call satisfies.

    The head of the queue is either a plain MockMethod, which must compare
    equal to this call, or a MethodGroup, which is asked whether the call
    belongs to it.

    Returns:
      The expected mock method.

    Raises:
      UnexpectedMethodCallError: the call was not expected.
    """
    candidate = self._PopNextMethod()

    # A group may hand back the entry that follows it, and that entry can be
    # another group, so keep going until a plain method is reached.
    while isinstance(candidate, MethodGroup):
      candidate, matched = candidate.MethodCalled(self)
      if matched is not None:
        return matched

    # A plain mock method: it has to equal this call.
    if candidate != self:
      raise UnexpectedMethodCallError(self, candidate)

    return candidate

  def __str__(self):
    positional = [repr(value) for value in self._params or []]
    keywords = ['%s=%r' % item
                for item in sorted((self._named_params or {}).items())]
    rendered_args = ', '.join(positional + keywords)

    text = "%s(%s) -> %r" % (self._name, rendered_args, self._return_value)
    if self._description:
      text = "%s.%s" % (self._description, text)
    return text

  def __eq__(self, rhs):
    """Compare by method name, positional and named parameters.

    Args:
      rhs: the object to compare with; anything that is not a MockMethod is
        unequal.
    """
    if not isinstance(rhs, MockMethod):
      return False

    return (self._name == rhs._name and
            self._params == rhs._params and
            self._named_params == rhs._named_params)

  def __ne__(self, rhs):
    """Negation of __eq__.

    Args:
      rhs: the object to compare with.
    """
    return not (self == rhs)

  def GetPossibleGroup(self):
    """Take this method off the tail of the queue and peek at the new tail.

    Returns:
      The entry now at the end of the call queue (which may be a group), or
      None when the queue is empty.
    """
    # This method was appended when it was recorded; remove it again so it can
    # be placed inside a group instead.
    popped = self._call_queue.pop()
    assert popped == self

    if self._call_queue:
      return self._call_queue[-1]
    return None

  def _CheckAndCreateNewGroup(self, group_name, group_class):
    """Put this method into a group at the tail of the queue.

    The entry at the tail is reused when it is an instance of group_class with
    the same name; otherwise a new group is created and appended.

    Args:
      group_name: the name of the group.
      group_class: the class used to create the group.

    Returns:
      self
    """
    tail = self.GetPossibleGroup()

    if isinstance(tail, group_class) and tail.group_name() == group_name:
      # The right kind of group is already waiting: join it.
      tail.AddMethod(self)
    else:
      fresh_group = group_class(group_name)
      fresh_group.AddMethod(self)
      self._call_queue.append(fresh_group)

    return self

  def InAnyOrder(self, group_name="default"):
    """Move this method into a group of unordered calls.

    The methods of an unordered group must be declared together and must all
    be called before the next expectation is reached.  Several groups can be
    expected one after another when they have different names, and a name can
    be reused once an ordinary call or a differently named group has been
    declared in between.

    Args:
      group_name: the name of the unordered group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

  def MultipleTimes(self, group_name="default"):
    """Move this method into a group of calls that may repeat.

    The methods of a repeating group must be declared together, and each must
    be called at least once before the next expected method can be called.

    Args:
      group_name: the name of the repeating group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

  def AndReturn(self, return_value):
    """Set the value returned when this method is called in replay mode.

    Args:
      return_value: any object.

    Returns:
      The same return_value that was passed in.
    """
    self._return_value = return_value
    return return_value

  def AndRaise(self, exception):
    """Set the exception raised when this method is called in replay mode.

    Args:
      exception: the exception to raise.
    """
    self._exception = exception

  def WithSideEffects(self, side_effects):
    """Set a callable that is run when this method is called in replay mode.

    Args:
      side_effects: a callable invoked with the call's arguments, e.g. to
        modify the parameters or other state the test depends on.

    Returns:
      self, so that AndReturn or AndRaise can be chained.
    """
    self._side_effects = side_effects
    return self
 | 
						|
 | 
						|
class Comparator:
  """Base class for all Mox comparators.

  A Comparator can be used as a parameter to a mocked method when the exact
  value is not known.  For example, the code you are testing might build up a
  long SQL string that is passed to your mock DAO. You're only interested that
  the IN clause contains the proper primary keys, so you can set your mock
  up as follows:

  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

  Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.

  A Comparator may replace one or more parameters, for example:
  # return at most 10 rows
  mock_dao.RunQuery(StrContains('SELECT'), 10)

  or

  # Return some non-deterministic number of rows
  mock_dao.RunQuery(StrContains('SELECT'), IsA(int))

  Both == and != are routed through equals(), so a subclass only has to
  override that one method.
  """

  def equals(self, rhs):
    """Special equals method that all comparators must implement.

    Args:
      rhs: any python object

    Raises:
      NotImplementedError: always, in this base class.
    """

    # The call form of raise behaves the same as the old
    # "raise Class, message" statement and is also valid on newer Pythons.
    raise NotImplementedError('method must be implemented by a subclass.')

  def __eq__(self, rhs):
    """Return the result of self.equals(rhs)."""
    return self.equals(rhs)

  def __ne__(self, rhs):
    """Return the negation of self.equals(rhs)."""
    return not self.equals(rhs)
 | 
						|
 | 
						|
 | 
						|
class IsA(Comparator):
  """Comparator that matches any parameter of a given type or class.

  Example:
  mock_dao.Connect(IsA(DbConnectInfo))
  """

  def __init__(self, class_name):
    """Initialize IsA.

    Args:
      class_name: basic python type or a class
    """
    self._class_name = class_name

  def equals(self, rhs):
    """Check whether rhs is an instance of the wrapped type.

    Args:
      rhs: object, the value being tested.

    Returns:
      bool
    """
    wanted = self._class_name
    try:
      return isinstance(rhs, wanted)
    except TypeError:
      # isinstance rejected the wrapped object (the original comment mentions
      # cStringIO.StringIO as an example), so compare the raw types instead.
      return type(rhs) == type(wanted)

  def __repr__(self):
    return str(self._class_name)
 | 
						|
 | 
						|
class IsAlmost(Comparator):
  """Comparator that matches values nearly equal to a given number.

  Generally useful for floating point numbers.

  Example mock_dao.SetTimeout((IsAlmost(3.9)))
  """

  def __init__(self, float_value, places=7):
    """Initialize IsAlmost.

    Args:
      float_value: The value for making the comparison.
      places: The number of decimal places to round to.
    """
    self._places = places
    self._float_value = float_value

  def equals(self, rhs):
    """Check whether rhs differs from float_value by (rounded) zero.

    Args:
      rhs: the value to compare to float_value

    Returns:
      bool
    """
    try:
      difference = rhs - self._float_value
      return round(difference, self._places) == 0
    except TypeError:
      # Subtraction or rounding failed, most likely because one of the two
      # values is not a number.
      return False

  def __repr__(self):
    return str(self._float_value)
 | 
						|
 | 
						|
class StrContains(Comparator):
  """Comparator that matches strings containing a given substring.

  This can be useful in mocking a database with SQL passed in as a string
  parameter, for example.

  Example:
  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
  """

  def __init__(self, search_string):
    """Initialize.

    Args:
      search_string: str, the text to look for.
    """
    self._search_string = search_string

  def equals(self, rhs):
    """Check whether the search string occurs inside rhs.

    Args:
      rhs: object, the value being tested.

    Returns:
      bool; False when rhs.find() cannot be used or raises.
    """
    try:
      position = rhs.find(self._search_string)
      return position > -1
    except Exception:
      return False

  def __repr__(self):
    return '<str containing \'%s\'>' % self._search_string
 | 
						|
 | 
						|
 | 
						|
class Regex(Comparator):
  """Checks if a string matches a regular expression.

  This uses a given regular expression to determine equality.
  """

  def __init__(self, pattern, flags=0):
    """Initialize.

    Args:
      pattern: str, the regular expression to search for.
      flags: int, passed to re.compile as the flags argument.
    """

    self.regex = re.compile(pattern, flags=flags)

  def equals(self, rhs):
    """Check to see if rhs matches regular expression pattern.

    Args:
      rhs: the value being tested.

    Returns:
      bool; False when rhs is of a type the pattern cannot search.
    """

    try:
      return self.regex.search(rhs) is not None
    except TypeError:
      # rhs is not something a regular expression can be applied to (None, an
      # int, ...).  Report a mismatch, as the other comparators do, rather
      # than letting the TypeError escape from the equality check.
      return False

  def __repr__(self):
    s = '<regular expression \'%s\'' % self.regex.pattern
    if self.regex.flags:
      s += ', flags=%d' % self.regex.flags
    s += '>'
    return s
 | 
						|
 | 
						|
 | 
						|
class In(Comparator):
  """Checks whether an item (or key) is in a list (or dict) parameter.

  Example:
  mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
  """

  def __init__(self, key):
    """Initialize.

    Args:
      key: anything that could be in a list or be a key in a dict.
    """

    self._key = key

  def equals(self, rhs):
    """Check to see whether key is in rhs.

    Args:
      rhs: the container (list, dict, ...) being tested.

    Returns:
      bool; False when rhs does not support the "in" test for this key.
    """

    try:
      return self._key in rhs
    except TypeError:
      # rhs is not a container (or the key cannot be looked up in it), so the
      # key is certainly not in it.  Report a mismatch, as the other
      # comparators do, rather than letting the TypeError escape.
      return False

  def __repr__(self):
    return '<sequence or map containing \'%s\'>' % self._key
 | 
						|
 | 
						|
 | 
						|
class Not(Comparator):
  """Comparator that inverts the result of another comparator.

  Example:
    mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', stevepm_user_info)))
  """

  def __init__(self, predicate):
    """Initialize.

    Args:
      predicate: a Comparator instance whose result is negated.
    """
    is_comparator = isinstance(predicate, Comparator)
    assert is_comparator, ("predicate %r must be a"
                           " Comparator." % predicate)
    self._predicate = predicate

  def equals(self, rhs):
    """Check whether the wrapped predicate rejects rhs.

    Args:
      rhs: A value that will be given in argument of the predicate.

    Returns:
      bool
    """
    wrapped_result = self._predicate.equals(rhs)
    return not wrapped_result

  def __repr__(self):
    return '<not \'%s\'>' % self._predicate
 | 
						|
 | 
						|
 | 
						|
class ContainsKeyValue(Comparator):
  """Comparator that matches mappings holding a given key/value pair.

  Example:
  mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
  """

  def __init__(self, key, value):
    """Initialize.

    Args:
      key: the key to look up in the parameter.
      value: the value expected under that key.
    """
    self._value = value
    self._key = key

  def equals(self, rhs):
    """Check whether rhs maps the key to the expected value.

    Args:
      rhs: the value being tested.

    Returns:
      bool; False when the lookup or the comparison raises.
    """
    try:
      stored = rhs[self._key]
      return stored == self._value
    except Exception:
      return False

  def __repr__(self):
    return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
 | 
						|
 | 
						|
 | 
						|
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  When every element is hashable the comparison is done on dict keys, so
  repeated elements count once.  Otherwise sorted lists are compared, and
  repetitions do matter.

  Example:
  mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Initialize.

    Args:
      expected_seq: a sequence (or any other iterable) of expected elements.
    """

    self._expected_seq = expected_seq
    # Snapshot the elements now: an iterator could otherwise be read only
    # once, and equals() may need to walk the elements more than once.
    self._expected_list = list(expected_seq)

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    Args:
      actual_seq: sequence

    Returns:
      bool; False when actual_seq cannot be iterated at all.
    """

    try:
      # Read actual_seq exactly once, so that an iterator is not left
      # half-consumed if the hashing attempt below fails part-way through.
      actual_list = list(actual_seq)
    except TypeError:
      # Not iterable, so it cannot contain the expected elements.
      return False

    try:
      expected = dict([(element, None) for element in self._expected_list])
      actual = dict([(element, None) for element in actual_list])
    except TypeError:
      # Fall back to slower list-compare if any of the objects are unhashable.
      expected = sorted(self._expected_list)
      actual = sorted(actual_list)
    return expected == actual

  def __repr__(self):
    # Wrap in a 1-tuple: a tuple-valued sequence would otherwise be taken as
    # the list of arguments for the format string.
    return '<sequence with same elements as \'%s\'>' % (self._expected_seq,)
 | 
						|
 | 
						|
 | 
						|
class And(Comparator):
  """Comparator that matches only when every wrapped comparator matches."""

  def __init__(self, *args):
    """Initialize.

    Args:
      *args: One or more Comparator
    """
    self._comparators = args

  def equals(self, rhs):
    """Check whether all wrapped comparators accept rhs.

    Evaluation stops at the first comparator that rejects rhs.

    Args:
      rhs: can be anything

    Returns:
      bool
    """
    everything_matched = True
    for candidate in self._comparators:
      if not candidate.equals(rhs):
        everything_matched = False
        break
    return everything_matched

  def __repr__(self):
    return '<AND %s>' % str(self._comparators)
 | 
						|
 | 
						|
 | 
						|
class Or(Comparator):
  """Comparator that matches when at least one wrapped comparator matches."""

  def __init__(self, *args):
    """Initialize.

    Args:
      *args: One or more Mox comparators
    """
    self._comparators = args

  def equals(self, rhs):
    """Check whether any wrapped comparator accepts rhs.

    Evaluation stops at the first comparator that accepts rhs.

    Args:
      rhs: can be anything

    Returns:
      bool
    """
    something_matched = False
    for candidate in self._comparators:
      if candidate.equals(rhs):
        something_matched = True
        break
    return something_matched

  def __repr__(self):
    return '<OR %s>' % str(self._comparators)
 | 
						|
 | 
						|
 | 
						|
class Func(Comparator):
  """Comparator that delegates the decision to a user supplied callable.

  Use this when validating a parameter needs more logic than the other
  comparators offer.  The callable receives the parameter and should return
  either True or False.

  Example:

  def myParamValidator(param):
    # Advanced logic here
    return True

  mock_dao.DoSomething(Func(myParamValidator), True)
  """

  def __init__(self, func):
    """Initialize.

    Args:
      func: callable that takes one parameter and returns a bool
    """
    self._func = func

  def equals(self, rhs):
    """Run the callable on rhs.

    Args:
      rhs: any python object

    Returns:
      the result of func(rhs)
    """
    validator = self._func
    return validator(rhs)

  def __repr__(self):
    return str(self._func)
 | 
						|
 | 
						|
 | 
						|
class IgnoreArg(Comparator):
  """Comparator that matches any argument at all.

  Use it for arguments of a method call whose value is irrelevant.

  Example:
  # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
  mymock.CastMagic(3, IgnoreArg(), 'disappear')
  """

  def equals(self, unused_rhs):
    """Accept the argument without looking at it.

    Args:
      unused_rhs: any python object

    Returns:
      always returns True
    """
    return True

  def __repr__(self):
    return '<IgnoreArg>'
 | 
						|
 | 
						|
 | 
						|
class MethodGroup(object):
  """Common base for groups of expected method calls.

  Stores the group name; AddMethod, MethodCalled and IsSatisfied must be
  provided by subclasses.
  """

  def __init__(self, group_name):
    self._group_name = group_name

  def group_name(self):
    """Return the name this group was created with."""
    return self._group_name

  def __str__(self):
    kind = self.__class__.__name__
    return '<%s "%s">' % (kind, self._group_name)

  def AddMethod(self, mock_method):
    """Add a mock method to the group; subclasses must override."""
    raise NotImplementedError

  def MethodCalled(self, mock_method):
    """Handle a call made while the group is active; subclasses override."""
    raise NotImplementedError

  def IsSatisfied(self):
    """Tell whether the group's expectations are met; subclasses override."""
    raise NotImplementedError
 | 
						|
 | 
						|
class UnorderedGroup(MethodGroup):
  """A group of expected method calls that may arrive in any order.

  This construct is helpful for non-deterministic events, such as iterating
  over the keys of a dict.
  """

  def __init__(self, group_name):
    super(UnorderedGroup, self).__init__(group_name)
    # Expectations that have not been called yet.
    self._methods = []

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """
    self._methods.append(mock_method)

  def MethodCalled(self, mock_method):
    """Consume the expectation that matches a call.

    Args:
      mock_method: the called mock method, which should equal one of the
        methods in the group.

    Returns:
      A (group, method) tuple: this group and the matching expectation.

    Raises:
      UnexpectedMethodCallError: no method in the group equals mock_method.
    """
    for expectation in self._methods:
      if not (expectation == mock_method):
        continue

      # Remove by way of the called mock_method rather than the stored
      # expectation.  The called method will match any comparators when
      # equality is checked during removal, whereas the stored one could end
      # up comparing a comparator with another comparator.
      self._methods.remove(mock_method)

      # Expectations remain, so the group goes back to the head of the queue.
      if not self.IsSatisfied():
        mock_method._call_queue.appendleft(self)

      return self, expectation

    raise UnexpectedMethodCallError(mock_method, self)

  def IsSatisfied(self):
    """Return True if there are not any methods in this group."""
    return not self._methods
 | 
						|
 | 
						|
 | 
						|
class MultipleTimesGroup(MethodGroup):
  """A group of methods that may each be called any number of times.

  Note: Each method must be called at least once.

  This is helpful, if you don't know or care how many times a method is called.
  """

  def __init__(self, group_name):
    super(MultipleTimesGroup, self).__init__(group_name)
    # Every method of the group, and the ones not called so far.
    self._methods = set()
    self._methods_left = set()

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """
    for bucket in (self._methods, self._methods_left):
      bucket.add(mock_method)

  def MethodCalled(self, mock_method):
    """Handle a call made while this group is at the head of the queue.

    Args:
      mock_method: the called mock method.

    Returns:
      (group, method) when the call matches a method of the group, or
      (next_entry, None) when it does not but the group is already satisfied;
      next_entry is the following entry popped from the call queue.

    Raises:
      UnexpectedMethodCallError: the call matches nothing in the group and
        some of the group's methods have not been called yet.
    """
    for expectation in self._methods:
      if expectation == mock_method:
        self._methods_left.discard(expectation)
        # Always put this group back on top of the queue, because we don't
        # know when we are done.
        mock_method._call_queue.appendleft(self)
        return self, expectation

    if not self.IsSatisfied():
      raise UnexpectedMethodCallError(mock_method, self)

    # Every method was seen at least once: hand over to whatever comes next.
    following = mock_method._PopNextMethod()
    return following, None

  def IsSatisfied(self):
    """Return True if all methods in this group are called at least once."""
    return not self._methods_left
 | 
						|
 | 
						|
 | 
						|
class MoxMetaTestBase(type):
  """Metaclass to add mox cleanup and verification to every test.

  As the mox unit testing class is being constructed (MoxTestBase or a
  subclass), this metaclass wraps every test function with CleanUpTest so
  that the test's Mox object is cleaned up after the function finishes. This
  means that unstubbing and verifying will happen for every test with no
  additional code, and any failures will result in test failures as opposed
  to errors.
  """

  def __init__(cls, name, bases, d):
    """Wraps all test methods of the class under construction.

    Args:
      cls: the class being constructed.
      name: string; the name of the class.
      bases: tuple; the base classes of the class.
      d: dict; the namespace of the class body. It is not modified.
    """
    type.__init__(cls, name, bases, d)

    # Work on a copy so the namespace dict handed in by the caller is left
    # untouched.
    candidates = dict(d)

    # Also get the test attributes from the base classes to account for a
    # case when the test class is not the immediate child of MoxTestBase.
    # Names already present are never replaced: a test defined in the class
    # body must win over the inherited one, and an earlier base must win over
    # a later one. Overwriting unconditionally would silently swap an
    # overriding test for the base class's version.
    for base in bases:
      for attr_name in dir(base):
        if attr_name.startswith('test') and attr_name not in candidates:
          candidates[attr_name] = getattr(base, attr_name)

    for func_name, func in candidates.items():
      if func_name.startswith('test') and callable(func):
        setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

  @staticmethod
  def CleanUpTest(cls, func):
    """Adds Mox cleanup code to any MoxTestBase method.

    Always unsets stubs after a test. Will verify all mocks for tests that
    otherwise pass.

    Args:
      cls: MoxTestBase or subclass; the class whose test method we are altering.
      func: method; the method of the MoxTestBase test class we wish to alter.

    Returns:
      The modified method.
    """
    def new_method(self, *args, **kwargs):
      # Cleanup only applies when the test instance carries a real Mox
      # object in its "mox" attribute at the time the test starts.
      mox_obj = getattr(self, 'mox', None)
      cleanup_mox = False
      if mox_obj and isinstance(mox_obj, Mox):
        cleanup_mox = True
      try:
        func(self, *args, **kwargs)
      finally:
        # Stubs are removed whether or not the test raised.
        if cleanup_mox:
          mox_obj.UnsetStubs()
      # Only reached when the test itself did not raise, so a verification
      # failure never masks the original error of a failing test.
      if cleanup_mox:
        mox_obj.VerifyAll()

    # Make the wrapper look like the wrapped test to test runners.
    new_method.__name__ = func.__name__
    new_method.__doc__ = func.__doc__
    new_method.__module__ = func.__module__
    return new_method
 | 
						|
 | 
						|
 | 
						|
class MoxTestBase(unittest.TestCase):
  """Base class for unit tests that use Mox.

  Each test gets a fresh Mox instance in its "mox" attribute, which any mox
  test will want. The MoxMetaTestBase metaclass wraps every test method so
  that stubs are unset after each test and all mocks are verified after
  tests that otherwise pass, eliminating that boilerplate from the tests.
  """

  # Python 2 style metaclass declaration; it installs the per-test cleanup.
  __metaclass__ = MoxMetaTestBase

  def setUp(self):
    """Runs the parent setUp, then creates the Mox instance for the test."""
    super(MoxTestBase, self).setUp()
    self.mox = Mox()
 |