diff --git a/owners_client.py b/owners_client.py index 6f6ecf29c..fc756928b 100644 --- a/owners_client.py +++ b/owners_client.py @@ -135,10 +135,14 @@ class DepotToolsClient(OwnersClient): self._branch = branch self._fopen = fopen self._os_path = os_path + self._db = None + self._db_lock = threading.Lock() - self._db = owners_db.Database(root, fopen, os_path) + def _ensure_db(self): + if self._db is not None: + return + self._db = owners_db.Database(self._root, self._fopen, self._os_path) self._db.override_files = self._GetOriginalOwnersFiles() - self._db_lock = threading.Lock() def _GetOriginalOwnersFiles(self): return { @@ -150,6 +154,7 @@ class DepotToolsClient(OwnersClient): def ListOwners(self, path): # all_possible_owners is not thread safe. with self._db_lock: + self._ensure_db() # all_possible_owners returns a dict {owner: [(path, distance)]}. We want # to return a list of owners sorted by increasing distance. distance_by_owner = self._db.all_possible_owners([path], None) diff --git a/tests/git_cl_test.py b/tests/git_cl_test.py index 00564cf7e..3ed937132 100755 --- a/tests/git_cl_test.py +++ b/tests/git_cl_test.py @@ -3922,12 +3922,21 @@ class CMDStatusTestCase(CMDTestCaseBase): class CMDOwnersTestCase(CMDTestCaseBase): def setUp(self): super(CMDOwnersTestCase, self).setUp() + self.owners_by_path = { + 'foo': ['a@example.com'], + 'bar': ['b@example.com', 'c@example.com'], + } mock.patch('git_cl.Settings.GetRoot', return_value='root').start() mock.patch('git_cl.Changelist.GetAuthor', return_value='author').start() + mock.patch( + 'git_cl.Changelist.GetAffectedFiles', + return_value=list(self.owners_by_path)).start() mock.patch( 'git_cl.Changelist.GetCommonAncestorWithUpstream', return_value='upstream').start() - mock.patch('owners_client.DepotToolsClient').start() + mock.patch( + 'owners_client.DepotToolsClient.BatchListOwners', + return_value=self.owners_by_path).start() self.addCleanup(mock.patch.stopall) def testShowAllNoArgs(self): @@ -3937,15 +3946,11 @@ class CMDOwnersTestCase(CMDTestCaseBase): git_cl.sys.stdout.getvalue()) def testShowAll(self): - batch_mock = owners_client.DepotToolsClient.return_value.BatchListOwners - batch_mock.return_value = { - 'foo': ['a@example.com'], - 'bar': ['b@example.com', 'c@example.com'], - } self.assertEqual( 0, git_cl.main(['owners', '--show-all', 'foo', 'bar', 'baz'])) - batch_mock.assert_called_once_with(['foo', 'bar', 'baz']) + owners_client.DepotToolsClient.BatchListOwners.assert_called_once_with( + ['foo', 'bar', 'baz']) self.assertEqual( '\n'.join([ 'Owners for foo:', @@ -3959,6 +3964,11 @@ class CMDOwnersTestCase(CMDTestCaseBase): ]), sys.stdout.getvalue()) + def testBatch(self): + self.assertEqual(0, git_cl.main(['owners', '--batch'])) + self.assertIn('a@example.com', sys.stdout.getvalue()) + self.assertIn('b@example.com', sys.stdout.getvalue()) + if __name__ == '__main__': logging.basicConfig(