diff --git a/testing_support/.style.yapf b/testing_support/.style.yapf new file mode 100644 index 000000000..4741fb4f3 --- /dev/null +++ b/testing_support/.style.yapf @@ -0,0 +1,3 @@ +[style] +based_on_style = pep8 +column_limit = 80 diff --git a/testing_support/coverage_utils.py b/testing_support/coverage_utils.py index 92d562db3..c3f9b18b4 100644 --- a/testing_support/coverage_utils.py +++ b/testing_support/coverage_utils.py @@ -10,22 +10,26 @@ import sys import textwrap import unittest -ROOT_PATH = os.path.abspath(os.path.join( - os.path.dirname(os.path.dirname(__file__)))) +ROOT_PATH = os.path.abspath( + os.path.join(os.path.dirname(os.path.dirname(__file__)))) def native_error(msg, version): - print(textwrap.dedent("""\ + print( + textwrap.dedent("""\ ERROR: Native python-coverage (version: %s) is required to be installed on your PYTHONPATH to run this test. Recommendation: sudo apt-get install pip sudo pip install --upgrade coverage %s""") % (version, msg)) - sys.exit(1) + sys.exit(1) -def covered_main(includes, require_native=None, required_percentage=100.0, + +def covered_main(includes, + require_native=None, + required_percentage=100.0, disable_coverage=True): - """Equivalent of unittest.main(), except that it gathers coverage data, and + """Equivalent of unittest.main(), except that it gathers coverage data, and asserts if the test is not at 100% coverage. Args: @@ -37,43 +41,45 @@ def covered_main(includes, require_native=None, required_percentage=100.0, disable_coverage (bool) - If True, just run unittest.main() without any coverage tracking. Bug: crbug.com/662277 """ - if disable_coverage: - unittest.main() - return + if disable_coverage: + unittest.main() + return - try: - import coverage - if require_native is not None: - got_ver = coverage.__version__ - if not getattr(coverage.collector, 'CTracer', None): - native_error(( - "Native python-coverage module required.\n" - "Pure-python implementation (version: %s) found: %s" - ) % (got_ver, coverage), require_native) - if got_ver < distutils.version.LooseVersion(require_native): - native_error("Wrong version (%s) found: %s" % (got_ver, coverage), - require_native) - except ImportError: - if require_native is None: - sys.path.insert(0, os.path.join(ROOT_PATH, 'third_party')) - import coverage - else: - print("ERROR: python-coverage (%s) is required to be installed on your " - "PYTHONPATH to run this test." % require_native) - sys.exit(1) + try: + import coverage + if require_native is not None: + got_ver = coverage.__version__ + if not getattr(coverage.collector, 'CTracer', None): + native_error( + ("Native python-coverage module required.\n" + "Pure-python implementation (version: %s) found: %s") % + (got_ver, coverage), require_native) + if got_ver < distutils.version.LooseVersion(require_native): + native_error( + "Wrong version (%s) found: %s" % (got_ver, coverage), + require_native) + except ImportError: + if require_native is None: + sys.path.insert(0, os.path.join(ROOT_PATH, 'third_party')) + import coverage + else: + print( + "ERROR: python-coverage (%s) is required to be installed on " + "your PYTHONPATH to run this test." 
% require_native) + sys.exit(1) - COVERAGE = coverage.coverage(include=includes) - COVERAGE.start() + COVERAGE = coverage.coverage(include=includes) + COVERAGE.start() - retcode = 0 - try: - unittest.main() - except SystemExit as e: - retcode = e.code or retcode + retcode = 0 + try: + unittest.main() + except SystemExit as e: + retcode = e.code or retcode - COVERAGE.stop() - if COVERAGE.report() < required_percentage: - print('FATAL: not at required %f%% coverage.' % required_percentage) - retcode = 2 + COVERAGE.stop() + if COVERAGE.report() < required_percentage: + print('FATAL: not at required %f%% coverage.' % required_percentage) + retcode = 2 - return retcode + return retcode diff --git a/testing_support/fake_cipd.py b/testing_support/fake_cipd.py index f44aa6e64..cb9532c63 100644 --- a/testing_support/fake_cipd.py +++ b/testing_support/fake_cipd.py @@ -74,129 +74,131 @@ DESCRIBE_JSON_TEMPLATE = """{ def parse_cipd(root, contents): - tree = {} - current_subdir = None - for line in contents: - line = line.strip() - match = re.match(CIPD_SUBDIR_RE, line) - if match: - print('match') - current_subdir = os.path.join(root, *match.group(1).split('/')) - if not root: - current_subdir = match.group(1) - elif line and current_subdir: - print('no match') - tree.setdefault(current_subdir, []).append(line) - return tree + tree = {} + current_subdir = None + for line in contents: + line = line.strip() + match = re.match(CIPD_SUBDIR_RE, line) + if match: + print('match') + current_subdir = os.path.join(root, *match.group(1).split('/')) + if not root: + current_subdir = match.group(1) + elif line and current_subdir: + print('no match') + tree.setdefault(current_subdir, []).append(line) + return tree def expand_package_name_cmd(package_name): - package_split = package_name.split("/") - suffix = package_split[-1] - # Any use of var equality should return empty for testing. - if "=" in suffix: - if suffix != "${platform=fake-platform-ok}": - return "" - package_name = "/".join(package_split[:-1] + ["${platform}"]) - for v in [ARCH_VAR, OS_VAR, PLATFORM_VAR]: - var = "${%s}" % v - if package_name.endswith(var): - package_name = package_name.replace(var, "%s-expanded-test-only" % v) - return package_name + package_split = package_name.split("/") + suffix = package_split[-1] + # Any use of var equality should return empty for testing. 
+ if "=" in suffix: + if suffix != "${platform=fake-platform-ok}": + return "" + package_name = "/".join(package_split[:-1] + ["${platform}"]) + for v in [ARCH_VAR, OS_VAR, PLATFORM_VAR]: + var = "${%s}" % v + if package_name.endswith(var): + package_name = package_name.replace(var, + "%s-expanded-test-only" % v) + return package_name def ensure_file_resolve(): - resolved = {"result": {}} - parser = argparse.ArgumentParser() - parser.add_argument('-ensure-file', required=True) - parser.add_argument('-json-output') - args, _ = parser.parse_known_args() - with io.open(args.ensure_file, 'r', encoding='utf-8') as f: - new_content = parse_cipd("", f.readlines()) - for path, packages in new_content.items(): - resolved_packages = [] - for package in packages: - package_name = expand_package_name_cmd(package.split(" ")[0]) - resolved_packages.append({ - "package": package_name, - "pin": { - "package": package_name, - "instance_id": package_name + "-fake-resolved-id", - } - }) - resolved["result"][path] = resolved_packages - with io.open(args.json_output, 'w', encoding='utf-8') as f: - f.write(json.dumps(resolved, indent=4)) + resolved = {"result": {}} + parser = argparse.ArgumentParser() + parser.add_argument('-ensure-file', required=True) + parser.add_argument('-json-output') + args, _ = parser.parse_known_args() + with io.open(args.ensure_file, 'r', encoding='utf-8') as f: + new_content = parse_cipd("", f.readlines()) + for path, packages in new_content.items(): + resolved_packages = [] + for package in packages: + package_name = expand_package_name_cmd(package.split(" ")[0]) + resolved_packages.append({ + "package": package_name, + "pin": { + "package": package_name, + "instance_id": package_name + "-fake-resolved-id", + } + }) + resolved["result"][path] = resolved_packages + with io.open(args.json_output, 'w', encoding='utf-8') as f: + f.write(json.dumps(resolved, indent=4)) def describe_cmd(package_name): - parser = argparse.ArgumentParser() - parser.add_argument('-json-output') - parser.add_argument('-version', required=True) - args, _ = parser.parse_known_args() - json_template = Template(DESCRIBE_JSON_TEMPLATE).substitute( - package=package_name) - cli_out = Template(DESCRIBE_STDOUT_TEMPLATE).substitute(package=package_name) - json_out = json.loads(json_template) - found = False - for tag in json_out['result']['tags']: - if tag['tag'] == args.version: - found = True - break - for tag in json_out['result']['refs']: - if tag['ref'] == args.version: - found = True - break - if found: - if args.json_output: - with io.open(args.json_output, 'w', encoding='utf-8') as f: - f.write(json.dumps(json_out, indent=4)) - sys.stdout.write(cli_out) - return 0 - sys.stdout.write('Error: no such ref.\n') - return 1 + parser = argparse.ArgumentParser() + parser.add_argument('-json-output') + parser.add_argument('-version', required=True) + args, _ = parser.parse_known_args() + json_template = Template(DESCRIBE_JSON_TEMPLATE).substitute( + package=package_name) + cli_out = Template(DESCRIBE_STDOUT_TEMPLATE).substitute( + package=package_name) + json_out = json.loads(json_template) + found = False + for tag in json_out['result']['tags']: + if tag['tag'] == args.version: + found = True + break + for tag in json_out['result']['refs']: + if tag['ref'] == args.version: + found = True + break + if found: + if args.json_output: + with io.open(args.json_output, 'w', encoding='utf-8') as f: + f.write(json.dumps(json_out, indent=4)) + sys.stdout.write(cli_out) + return 0 + sys.stdout.write('Error: no such ref.\n') + 
return 1 def main(): - cmd = sys.argv[1] - assert cmd in [ - CIPD_DESCRIBE, CIPD_ENSURE, CIPD_ENSURE_FILE_RESOLVE, CIPD_EXPAND_PKG, - CIPD_EXPORT - ] - # Handle cipd expand-package-name - if cmd == CIPD_EXPAND_PKG: - # Expecting argument after cmd - assert len(sys.argv) == 3 - # Write result to stdout - sys.stdout.write(expand_package_name_cmd(sys.argv[2])) - return 0 - if cmd == CIPD_DESCRIBE: - # Expecting argument after cmd - assert len(sys.argv) >= 3 - return describe_cmd(sys.argv[2]) - if cmd == CIPD_ENSURE_FILE_RESOLVE: - return ensure_file_resolve() - - parser = argparse.ArgumentParser() - parser.add_argument('-ensure-file') - parser.add_argument('-root') - args, _ = parser.parse_known_args() - - with io.open(args.ensure_file, 'r', encoding='utf-8') as f: - new_content = parse_cipd(args.root, f.readlines()) - - # Install new packages - for path, packages in new_content.items(): - if not os.path.exists(path): - os.makedirs(path) - with io.open(os.path.join(path, '_cipd'), 'w', encoding='utf-8') as f: - f.write('\n'.join(packages)) + cmd = sys.argv[1] + assert cmd in [ + CIPD_DESCRIBE, CIPD_ENSURE, CIPD_ENSURE_FILE_RESOLVE, CIPD_EXPAND_PKG, + CIPD_EXPORT + ] + # Handle cipd expand-package-name + if cmd == CIPD_EXPAND_PKG: + # Expecting argument after cmd + assert len(sys.argv) == 3 + # Write result to stdout + sys.stdout.write(expand_package_name_cmd(sys.argv[2])) + return 0 + if cmd == CIPD_DESCRIBE: + # Expecting argument after cmd + assert len(sys.argv) >= 3 + return describe_cmd(sys.argv[2]) + if cmd == CIPD_ENSURE_FILE_RESOLVE: + return ensure_file_resolve() + + parser = argparse.ArgumentParser() + parser.add_argument('-ensure-file') + parser.add_argument('-root') + args, _ = parser.parse_known_args() + + with io.open(args.ensure_file, 'r', encoding='utf-8') as f: + new_content = parse_cipd(args.root, f.readlines()) + + # Install new packages + for path, packages in new_content.items(): + if not os.path.exists(path): + os.makedirs(path) + with io.open(os.path.join(path, '_cipd'), 'w', encoding='utf-8') as f: + f.write('\n'.join(packages)) - # Save the ensure file that we got - shutil.copy(args.ensure_file, os.path.join(args.root, '_cipd')) + # Save the ensure file that we got + shutil.copy(args.ensure_file, os.path.join(args.root, '_cipd')) - return 0 + return 0 if __name__ == '__main__': - sys.exit(main()) + sys.exit(main()) diff --git a/testing_support/fake_repos.py b/testing_support/fake_repos.py index 31c5a3d27..06f3c3528 100755 --- a/testing_support/fake_repos.py +++ b/testing_support/fake_repos.py @@ -3,7 +3,6 @@ # Copyright (c) 2011 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. - """Generate fake repositories for testing.""" from __future__ import print_function @@ -33,54 +32,55 @@ DEFAULT_BRANCH = 'main' def write(path, content): - f = open(path, 'wb') - f.write(content.encode()) - f.close() + f = open(path, 'wb') + f.write(content.encode()) + f.close() join = os.path.join def read_tree(tree_root): - """Returns a dict of all the files in a tree. 
Defaults to self.root_dir.""" - tree = {} - for root, dirs, files in os.walk(tree_root): - for d in filter(lambda x: x.startswith('.'), dirs): - dirs.remove(d) - for f in [join(root, f) for f in files if not f.startswith('.')]: - filepath = f[len(tree_root) + 1:].replace(os.sep, '/') - assert len(filepath) > 0, f - with io.open(join(root, f), encoding='utf-8') as f: - tree[filepath] = f.read() - return tree + """Returns a dict of all the files in a tree. Defaults to self.root_dir.""" + tree = {} + for root, dirs, files in os.walk(tree_root): + for d in filter(lambda x: x.startswith('.'), dirs): + dirs.remove(d) + for f in [join(root, f) for f in files if not f.startswith('.')]: + filepath = f[len(tree_root) + 1:].replace(os.sep, '/') + assert len(filepath) > 0, f + with io.open(join(root, f), encoding='utf-8') as f: + tree[filepath] = f.read() + return tree def dict_diff(dict1, dict2): - diff = {} - for k, v in dict1.items(): - if k not in dict2: - diff[k] = v - elif v != dict2[k]: - diff[k] = (v, dict2[k]) - for k, v in dict2.items(): - if k not in dict1: - diff[k] = v - return diff + diff = {} + for k, v in dict1.items(): + if k not in dict2: + diff[k] = v + elif v != dict2[k]: + diff[k] = (v, dict2[k]) + for k, v in dict2.items(): + if k not in dict1: + diff[k] = v + return diff def commit_git(repo): - """Commits the changes and returns the new hash.""" - subprocess2.check_call(['git', 'add', '-A', '-f'], cwd=repo) - subprocess2.check_call(['git', 'commit', '-q', '--message', 'foo'], cwd=repo) - rev = subprocess2.check_output( - ['git', 'show-ref', '--head', 'HEAD'], cwd=repo).split(b' ', 1)[0] - rev = rev.decode('utf-8') - logging.debug('At revision %s' % rev) - return rev + """Commits the changes and returns the new hash.""" + subprocess2.check_call(['git', 'add', '-A', '-f'], cwd=repo) + subprocess2.check_call(['git', 'commit', '-q', '--message', 'foo'], + cwd=repo) + rev = subprocess2.check_output(['git', 'show-ref', '--head', 'HEAD'], + cwd=repo).split(b' ', 1)[0] + rev = rev.decode('utf-8') + logging.debug('At revision %s' % rev) + return rev class FakeReposBase(object): - """Generate git repositories to test gclient functionality. + """Generate git repositories to test gclient functionality. Many DEPS functionalities need to be tested: Var, deps_os, hooks, use_relative_paths. @@ -89,154 +89,158 @@ class FakeReposBase(object): populateGit() needs to be implemented by the subclass. """ - # Hostname - NB_GIT_REPOS = 1 - USERS = [ - ('user1@example.com', 'foo Fuß'), - ('user2@example.com', 'bar'), - ] - - def __init__(self, host=None): - self.trial = trial_dir.TrialDir('repos') - self.host = host or '127.0.0.1' - # Format is { repo: [ None, (hash, tree), (hash, tree), ... ], ... } - # so reference looks like self.git_hashes[repo][rev][0] for hash and - # self.git_hashes[repo][rev][1] for it's tree snapshot. - # It is 1-based too. - self.git_hashes = {} - self.git_pid_file_name = None - self.git_base = None - self.initialized = False - - @property - def root_dir(self): - return self.trial.root_dir - - def set_up(self): - """All late initialization comes here.""" - if not self.root_dir: - try: - # self.root_dir is not set before this call. - self.trial.set_up() - self.git_base = join(self.root_dir, 'git') + os.sep - finally: - # Registers cleanup. - atexit.register(self.tear_down) - - def tear_down(self): - """Kills the servers and delete the directories.""" - self.tear_down_git() - # This deletes the directories. 
- self.trial.tear_down() - self.trial = None - - def tear_down_git(self): - if self.trial.SHOULD_LEAK: - return False - logging.debug('Removing %s' % self.git_base) - gclient_utils.rmtree(self.git_base) - return True - - @staticmethod - def _genTree(root, tree_dict): - """For a dictionary of file contents, generate a filesystem.""" - if not os.path.isdir(root): - os.makedirs(root) - for (k, v) in tree_dict.items(): - k_os = k.replace('/', os.sep) - k_arr = k_os.split(os.sep) - if len(k_arr) > 1: - p = os.sep.join([root] + k_arr[:-1]) - if not os.path.isdir(p): - os.makedirs(p) - if v is None: - os.remove(join(root, k)) - else: - write(join(root, k), v) - - def set_up_git(self): - """Creates git repositories and start the servers.""" - self.set_up() - if self.initialized: - return True - try: - subprocess2.check_output(['git', '--version']) - except (OSError, subprocess2.CalledProcessError): - return False - for repo in ['repo_%d' % r for r in range(1, self.NB_GIT_REPOS + 1)]: - # TODO(crbug.com/114712) use git.init -b and remove 'checkout' once git is - # upgraded to 2.28 on all builders. - subprocess2.check_call(['git', 'init', '-q', join(self.git_base, repo)]) - subprocess2.check_call(['git', 'checkout', '-q', '-b', DEFAULT_BRANCH], - cwd=join(self.git_base, repo)) - self.git_hashes[repo] = [(None, None)] - self.populateGit() - self.initialized = True - return True - - def _git_rev_parse(self, path): - return subprocess2.check_output( - ['git', 'rev-parse', 'HEAD'], cwd=path).strip() - - def _commit_git(self, repo, tree, base=None): - repo_root = join(self.git_base, repo) - if base: - base_commit = self.git_hashes[repo][base][0] - subprocess2.check_call( - ['git', 'checkout', base_commit], cwd=repo_root) - self._genTree(repo_root, tree) - commit_hash = commit_git(repo_root) - base = base or -1 - if self.git_hashes[repo][base][1]: - new_tree = self.git_hashes[repo][base][1].copy() - new_tree.update(tree) - else: - new_tree = tree.copy() - self.git_hashes[repo].append((commit_hash, new_tree)) - - def _create_ref(self, repo, ref, revision): - repo_root = join(self.git_base, repo) - subprocess2.check_call( - ['git', 'update-ref', ref, self.git_hashes[repo][revision][0]], - cwd=repo_root) - - def _fast_import_git(self, repo, data): - repo_root = join(self.git_base, repo) - logging.debug('%s: fast-import %s', repo, data) - subprocess2.check_call( - ['git', 'fast-import', '--quiet'], cwd=repo_root, stdin=data.encode()) - - def populateGit(self): - raise NotImplementedError() + # Hostname + NB_GIT_REPOS = 1 + USERS = [ + ('user1@example.com', 'foo Fuß'), + ('user2@example.com', 'bar'), + ] + + def __init__(self, host=None): + self.trial = trial_dir.TrialDir('repos') + self.host = host or '127.0.0.1' + # Format is { repo: [ None, (hash, tree), (hash, tree), ... ], ... } + # so reference looks like self.git_hashes[repo][rev][0] for hash and + # self.git_hashes[repo][rev][1] for its tree snapshot. + # It is 1-based too. + self.git_hashes = {} + self.git_pid_file_name = None + self.git_base = None + self.initialized = False + + @property + def root_dir(self): + return self.trial.root_dir + + def set_up(self): + """All late initialization comes here.""" + if not self.root_dir: + try: + # self.root_dir is not set before this call. + self.trial.set_up() + self.git_base = join(self.root_dir, 'git') + os.sep + finally: + # Registers cleanup.
+ atexit.register(self.tear_down) + + def tear_down(self): + """Kills the servers and deletes the directories.""" + self.tear_down_git() + # This deletes the directories. + self.trial.tear_down() + self.trial = None + + def tear_down_git(self): + if self.trial.SHOULD_LEAK: + return False + logging.debug('Removing %s' % self.git_base) + gclient_utils.rmtree(self.git_base) + return True + + @staticmethod + def _genTree(root, tree_dict): + """For a dictionary of file contents, generate a filesystem.""" + if not os.path.isdir(root): + os.makedirs(root) + for (k, v) in tree_dict.items(): + k_os = k.replace('/', os.sep) + k_arr = k_os.split(os.sep) + if len(k_arr) > 1: + p = os.sep.join([root] + k_arr[:-1]) + if not os.path.isdir(p): + os.makedirs(p) + if v is None: + os.remove(join(root, k)) + else: + write(join(root, k), v) + + def set_up_git(self): + """Creates git repositories and starts the servers.""" + self.set_up() + if self.initialized: + return True + try: + subprocess2.check_output(['git', '--version']) + except (OSError, subprocess2.CalledProcessError): + return False + for repo in ['repo_%d' % r for r in range(1, self.NB_GIT_REPOS + 1)]: + # TODO(crbug.com/114712) use git.init -b and remove 'checkout' once + # git is upgraded to 2.28 on all builders. + subprocess2.check_call( + ['git', 'init', '-q', + join(self.git_base, repo)]) + subprocess2.check_call( + ['git', 'checkout', '-q', '-b', DEFAULT_BRANCH], + cwd=join(self.git_base, repo)) + self.git_hashes[repo] = [(None, None)] + self.populateGit() + self.initialized = True + return True + + def _git_rev_parse(self, path): + return subprocess2.check_output(['git', 'rev-parse', 'HEAD'], + cwd=path).strip() + + def _commit_git(self, repo, tree, base=None): + repo_root = join(self.git_base, repo) + if base: + base_commit = self.git_hashes[repo][base][0] + subprocess2.check_call(['git', 'checkout', base_commit], + cwd=repo_root) + self._genTree(repo_root, tree) + commit_hash = commit_git(repo_root) + base = base or -1 + if self.git_hashes[repo][base][1]: + new_tree = self.git_hashes[repo][base][1].copy() + new_tree.update(tree) + else: + new_tree = tree.copy() + self.git_hashes[repo].append((commit_hash, new_tree)) + + def _create_ref(self, repo, ref, revision): + repo_root = join(self.git_base, repo) + subprocess2.check_call( + ['git', 'update-ref', ref, self.git_hashes[repo][revision][0]], + cwd=repo_root) + + def _fast_import_git(self, repo, data): + repo_root = join(self.git_base, repo) + logging.debug('%s: fast-import %s', repo, data) + subprocess2.check_call(['git', 'fast-import', '--quiet'], + cwd=repo_root, + stdin=data.encode()) + + def populateGit(self): + raise NotImplementedError() class FakeRepos(FakeReposBase): - """Implements populateGit().""" - NB_GIT_REPOS = 20 - - def populateGit(self): - # Testing: - # - dependency disappear - # - dependency renamed - # - versioned and unversioned reference - # - relative and full reference - # - deps_os - # - var - # - hooks - # TODO(maruel): - # - use_relative_paths - self._commit_git('repo_3', { - 'origin': 'git/repo_3@1\n', - }) - - self._commit_git('repo_3', { - 'origin': 'git/repo_3@2\n', - }) - - self._commit_git( - 'repo_1', - { - 'DEPS': """ + """Implements populateGit().""" + NB_GIT_REPOS = 20 + + def populateGit(self): + # Testing: + # - dependency disappear + # - dependency renamed + # - versioned and unversioned reference + # - relative and full reference + # - deps_os + # - var + # - hooks + # TODO(maruel): + # - use_relative_paths + self._commit_git('repo_3', { + 'origin':
'git/repo_3@1\n', + }) + + self._commit_git('repo_3', { + 'origin': 'git/repo_3@2\n', + }) + + self._commit_git( + 'repo_1', + { + 'DEPS': """ vars = { 'DummyVariable': 'repo', 'false_var': False, @@ -274,20 +278,23 @@ deps_os = { 'src/repo4': '/repo_4', }, }""" % { - 'git_base': self.git_base, - # See self.__init__() for the format. Grab's the hash of the - # first commit in repo_2. Only keep the first 7 character - # because of: TODO(maruel): http://crosbug.com/3591 We need to - # strip the hash.. duh. - 'hash3': self.git_hashes['repo_3'][1][0][:7] - }, - 'origin': 'git/repo_1@1\n', - 'foo bar': 'some file with a space', - }) - - self._commit_git('repo_2', { - 'origin': 'git/repo_2@1\n', - 'DEPS': """ + 'git_base': self.git_base, + # See self.__init__() for the format. Grabs the hash of the + # first commit in repo_3. Only keep the first 7 characters + # because of: TODO(maruel): http://crosbug.com/3591 We need + # to strip the hash.. duh. + 'hash3': self.git_hashes['repo_3'][1][0][:7] + }, + 'origin': 'git/repo_1@1\n', + 'foo bar': 'some file with a space', + }) + + self._commit_git( + 'repo_2', { + 'origin': + 'git/repo_2@1\n', + 'DEPS': + """ vars = { 'repo2_false_var': 'False', } @@ -299,22 +306,24 @@ deps = { } } """, - }) + }) - self._commit_git('repo_2', { - 'origin': 'git/repo_2@2\n', - }) + self._commit_git('repo_2', { + 'origin': 'git/repo_2@2\n', + }) - self._commit_git('repo_4', { - 'origin': 'git/repo_4@1\n', - }) + self._commit_git('repo_4', { + 'origin': 'git/repo_4@1\n', + }) - self._commit_git('repo_4', { - 'origin': 'git/repo_4@2\n', - }) + self._commit_git('repo_4', { + 'origin': 'git/repo_4@2\n', + }) - self._commit_git('repo_1', { - 'DEPS': """ + self._commit_git( + 'repo_1', + { + 'DEPS': """ deps = { 'src/repo2': %(git_base)r + 'repo_2@%(hash)s', 'src/repo2/repo_renamed': '/repo_3', @@ -339,18 +348,20 @@ hooks = [ }, ] """ % { - 'git_base': self.git_base, - # See self.__init__() for the format. Grab's the hash of the first - # commit in repo_2. Only keep the first 7 character because of: - # TODO(maruel): http://crosbug.com/3591 We need to strip the hash.. duh. - 'hash': self.git_hashes['repo_2'][1][0][:7] - }, - 'origin': 'git/repo_1@2\n', - }) - - self._commit_git('repo_5', {'origin': 'git/repo_5@1\n'}) - self._commit_git('repo_5', { - 'DEPS': """ + 'git_base': self.git_base, + # See self.__init__() for the format. Grabs the hash of the + # first commit in repo_2. Only keep the first 7 characters + # because of: TODO(maruel): http://crosbug.com/3591 We need + # to strip the hash.. duh.
+ 'hash': self.git_hashes['repo_2'][1][0][:7] + }, + 'origin': 'git/repo_1@2\n', + }) + + self._commit_git('repo_5', {'origin': 'git/repo_5@1\n'}) + self._commit_git( + 'repo_5', { + 'DEPS': """ deps = { 'src/repo1': %(git_base)r + 'repo_1@%(hash1)s', 'src/repo2': %(git_base)r + 'repo_2@%(hash2)s', @@ -365,14 +376,15 @@ pre_deps_hooks = [ } ] """ % { - 'git_base': self.git_base, - 'hash1': self.git_hashes['repo_1'][2][0][:7], - 'hash2': self.git_hashes['repo_2'][1][0][:7], - }, - 'origin': 'git/repo_5@2\n', - }) - self._commit_git('repo_5', { - 'DEPS': """ + 'git_base': self.git_base, + 'hash1': self.git_hashes['repo_1'][2][0][:7], + 'hash2': self.git_hashes['repo_2'][1][0][:7], + }, + 'origin': 'git/repo_5@2\n', + }) + self._commit_git( + 'repo_5', { + 'DEPS': """ deps = { 'src/repo1': %(git_base)r + 'repo_1@%(hash1)s', 'src/repo2': %(git_base)r + 'repo_2@%(hash2)s', @@ -390,15 +402,16 @@ pre_deps_hooks = [ } ] """ % { - 'git_base': self.git_base, - 'hash1': self.git_hashes['repo_1'][2][0][:7], - 'hash2': self.git_hashes['repo_2'][1][0][:7], - }, - 'origin': 'git/repo_5@3\n', - }) - - self._commit_git('repo_6', { - 'DEPS': """ + 'git_base': self.git_base, + 'hash1': self.git_hashes['repo_1'][2][0][:7], + 'hash2': self.git_hashes['repo_2'][1][0][:7], + }, + 'origin': 'git/repo_5@3\n', + }) + + self._commit_git( + 'repo_6', { + 'DEPS': """ vars = { 'DummyVariable': 'repo', 'git_base': %(git_base)r, @@ -492,14 +505,15 @@ recursedeps = [ 'src/repo15', 'src/repo16', ]""" % { - 'git_base': self.git_base, - 'hash': self.git_hashes['repo_2'][1][0][:7] - }, - 'origin': 'git/repo_6@1\n', - }) - - self._commit_git('repo_7', { - 'DEPS': """ + 'git_base': self.git_base, + 'hash': self.git_hashes['repo_2'][1][0][:7] + }, + 'origin': 'git/repo_6@1\n', + }) + + self._commit_git( + 'repo_7', { + 'DEPS': """ vars = { 'true_var': 'True', 'false_var': 'true_var and False', @@ -516,11 +530,12 @@ hooks = [ 'condition': 'false_var', }, ]""", - 'origin': 'git/repo_7@1\n', - }) + 'origin': 'git/repo_7@1\n', + }) - self._commit_git('repo_8', { - 'DEPS': """ + self._commit_git( + 'repo_8', { + 'DEPS': """ deps_os ={ 'mac': { 'src/recursed_os_repo': '/repo_5', @@ -529,11 +544,12 @@ deps_os ={ 'src/recursed_os_repo': '/repo_5', }, }""", - 'origin': 'git/repo_8@1\n', - }) + 'origin': 'git/repo_8@1\n', + }) - self._commit_git('repo_9', { - 'DEPS': """ + self._commit_git( + 'repo_9', { + 'DEPS': """ vars = { 'str_var': 'xyz', } @@ -560,11 +576,12 @@ recursedeps = [ 'src/repo4', 'src/repo8', ]""", - 'origin': 'git/repo_9@1\n', - }) + 'origin': 'git/repo_9@1\n', + }) - self._commit_git('repo_10', { - 'DEPS': """ + self._commit_git( + 'repo_10', { + 'DEPS': """ gclient_gn_args_from = 'src/repo9' deps = { 'src/repo9': '/repo_9', @@ -586,22 +603,24 @@ recursedeps = [ 'src/repo9', 'src/repo11', ]""", - 'origin': 'git/repo_10@1\n', - }) + 'origin': 'git/repo_10@1\n', + }) - self._commit_git('repo_11', { - 'DEPS': """ + self._commit_git( + 'repo_11', { + 'DEPS': """ deps = { 'src/repo12': '/repo_12', }""", - 'origin': 'git/repo_11@1\n', - }) + 'origin': 'git/repo_11@1\n', + }) - self._commit_git('repo_12', { - 'origin': 'git/repo_12@1\n', - }) + self._commit_git('repo_12', { + 'origin': 'git/repo_12@1\n', + }) - self._fast_import_git('repo_12', """blob + self._fast_import_git( + 'repo_12', """blob mark :1 data 6 Hello @@ -622,25 +641,28 @@ M 100644 :1 a M 100644 :2 b """) - self._commit_git('repo_13', { - 'DEPS': """ + self._commit_git( + 'repo_13', { + 'DEPS': """ deps = { 'src/repo12': '/repo_12', }""", - 'origin': 
'git/repo_13@1\n', - }) + 'origin': 'git/repo_13@1\n', + }) - self._commit_git('repo_13', { - 'DEPS': """ + self._commit_git( + 'repo_13', { + 'DEPS': """ deps = { 'src/repo12': '/repo_12@refs/changes/1212', }""", - 'origin': 'git/repo_13@2\n', - }) + 'origin': 'git/repo_13@2\n', + }) - # src/repo12 is now a CIPD dependency. - self._commit_git('repo_13', { - 'DEPS': """ + # src/repo12 is now a CIPD dependency. + self._commit_git( + 'repo_13', { + 'DEPS': """ deps = { 'src/repo12': { 'packages': [ @@ -657,11 +679,13 @@ hooks = [{ 'action': ['python3', '-c', 'with open("src/repo12/_cipd"): pass'], }] """, - 'origin': 'git/repo_13@3\n' - }) + 'origin': 'git/repo_13@3\n' + }) - self._commit_git('repo_14', { - 'DEPS': textwrap.dedent("""\ + self._commit_git( + 'repo_14', { + 'DEPS': + textwrap.dedent("""\ vars = {} deps = { 'src/cipd_dep': { @@ -696,34 +720,44 @@ hooks = [{ 'dep_type': 'cipd', }, }"""), - 'origin': 'git/repo_14@2\n' - }) - - # A repo with a hook to be recursed in, without use_relative_paths - self._commit_git('repo_15', { - 'DEPS': textwrap.dedent("""\ + 'origin': + 'git/repo_14@2\n' + }) + + # A repo with a hook to be recursed in, without use_relative_paths + self._commit_git( + 'repo_15', { + 'DEPS': + textwrap.dedent("""\ hooks = [{ "name": "absolute_cwd", "pattern": ".", "action": ["python3", "-c", "pass"] }]"""), - 'origin': 'git/repo_15@2\n' - }) - # A repo with a hook to be recursed in, with use_relative_paths - self._commit_git('repo_16', { - 'DEPS': textwrap.dedent("""\ + 'origin': + 'git/repo_15@2\n' + }) + # A repo with a hook to be recursed in, with use_relative_paths + self._commit_git( + 'repo_16', { + 'DEPS': + textwrap.dedent("""\ use_relative_paths=True hooks = [{ "name": "relative_cwd", "pattern": ".", "action": ["python3", "relative.py"] }]"""), - 'relative.py': 'pass', - 'origin': 'git/repo_16@2\n' - }) - # A repo with a gclient_gn_args_file and use_relative_paths - self._commit_git('repo_17', { - 'DEPS': textwrap.dedent("""\ + 'relative.py': + 'pass', + 'origin': + 'git/repo_16@2\n' + }) + # A repo with a gclient_gn_args_file and use_relative_paths + self._commit_git( + 'repo_17', { + 'DEPS': + textwrap.dedent("""\ use_relative_paths=True vars = { 'toto': 'tata', @@ -732,13 +766,14 @@ hooks = [{ gclient_gn_args = [ 'toto', ]"""), - 'origin': 'git/repo_17@2\n' - }) - - self._commit_git( - 'repo_18', { - 'DEPS': - textwrap.dedent("""\ + 'origin': + 'git/repo_17@2\n' + }) + + self._commit_git( + 'repo_18', { + 'DEPS': + textwrap.dedent("""\ deps = { 'src/cipd_dep': { 'packages': [ @@ -758,14 +793,14 @@ hooks = [{ 'dep_type': 'cipd', }, }"""), - 'origin': - 'git/repo_18@2\n' - }) + 'origin': + 'git/repo_18@2\n' + }) - # a relative path repo - self._commit_git( - 'repo_19', { - 'DEPS': """ + # a relative path repo + self._commit_git( + 'repo_19', { + 'DEPS': """ git_dependencies = "SUBMODULES" use_relative_paths = True @@ -791,15 +826,15 @@ deps = { "dep_type": "cipd", }, }""" % { - 'hash_2': self.git_hashes['repo_2'][1][0], - 'hash_3': self.git_hashes['repo_3'][1][0], - }, - }) + 'hash_2': self.git_hashes['repo_2'][1][0], + 'hash_3': self.git_hashes['repo_3'][1][0], + }, + }) - # a non-relative_path repo - self._commit_git( - 'repo_20', { - 'DEPS': """ + # a non-relative_path repo + self._commit_git( + 'repo_20', { + 'DEPS': """ git_dependencies = "SUBMODULES" vars = { @@ -824,232 +859,252 @@ deps = { "dep_type": "cipd", }, }""" % { - 'hash_2': self.git_hashes['repo_2'][1][0], - 'hash_3': self.git_hashes['repo_3'][1][0], - }, - }) + 'hash_2': 
self.git_hashes['repo_2'][1][0], + 'hash_3': self.git_hashes['repo_3'][1][0], + }, + }) class FakeRepoSkiaDEPS(FakeReposBase): - """Simulates the Skia DEPS transition in Chrome.""" + """Simulates the Skia DEPS transition in Chrome.""" - NB_GIT_REPOS = 5 + NB_GIT_REPOS = 5 - DEPS_git_pre = """deps = { + DEPS_git_pre = """deps = { 'src/third_party/skia/gyp': %(git_base)r + 'repo_3', 'src/third_party/skia/include': %(git_base)r + 'repo_4', 'src/third_party/skia/src': %(git_base)r + 'repo_5', }""" - DEPS_post = """deps = { + DEPS_post = """deps = { 'src/third_party/skia': %(git_base)r + 'repo_1', }""" - def populateGit(self): - # Skia repo. - self._commit_git('repo_1', { - 'skia_base_file': 'root-level file.', - 'gyp/gyp_file': 'file in the gyp directory', - 'include/include_file': 'file in the include directory', - 'src/src_file': 'file in the src directory', - }) - self._commit_git('repo_3', { # skia/gyp - 'gyp_file': 'file in the gyp directory', - }) - self._commit_git('repo_4', { # skia/include - 'include_file': 'file in the include directory', - }) - self._commit_git('repo_5', { # skia/src - 'src_file': 'file in the src directory', - }) - - # Chrome repo. - self._commit_git('repo_2', { - 'DEPS': self.DEPS_git_pre % {'git_base': self.git_base}, - 'myfile': 'src/trunk/src@1' - }) - self._commit_git('repo_2', { - 'DEPS': self.DEPS_post % {'git_base': self.git_base}, - 'myfile': 'src/trunk/src@2' - }) + def populateGit(self): + # Skia repo. + self._commit_git( + 'repo_1', { + 'skia_base_file': 'root-level file.', + 'gyp/gyp_file': 'file in the gyp directory', + 'include/include_file': 'file in the include directory', + 'src/src_file': 'file in the src directory', + }) + self._commit_git( + 'repo_3', + { # skia/gyp + 'gyp_file': 'file in the gyp directory', + }) + self._commit_git('repo_4', { # skia/include + 'include_file': 'file in the include directory', + }) + self._commit_git( + 'repo_5', + { # skia/src + 'src_file': 'file in the src directory', + }) + + # Chrome repo. + self._commit_git( + 'repo_2', { + 'DEPS': self.DEPS_git_pre % { + 'git_base': self.git_base + }, + 'myfile': 'src/trunk/src@1' + }) + self._commit_git( + 'repo_2', { + 'DEPS': self.DEPS_post % { + 'git_base': self.git_base + }, + 'myfile': 'src/trunk/src@2' + }) class FakeRepoBlinkDEPS(FakeReposBase): - """Simulates the Blink DEPS transition in Chrome.""" - - NB_GIT_REPOS = 2 - DEPS_pre = 'deps = {"src/third_party/WebKit": "%(git_base)srepo_2",}' - DEPS_post = 'deps = {}' - - def populateGit(self): - # Blink repo. - self._commit_git('repo_2', { - 'OWNERS': 'OWNERS-pre', - 'Source/exists_always': '_ignored_', - 'Source/exists_before_but_not_after': '_ignored_', - }) - - # Chrome repo. - self._commit_git('repo_1', { - 'DEPS': self.DEPS_pre % {'git_base': self.git_base}, - 'myfile': 'myfile@1', - '.gitignore': '/third_party/WebKit', - }) - self._commit_git('repo_1', { - 'DEPS': self.DEPS_post % {'git_base': self.git_base}, - 'myfile': 'myfile@2', - '.gitignore': '', - 'third_party/WebKit/OWNERS': 'OWNERS-post', - 'third_party/WebKit/Source/exists_always': '_ignored_', - 'third_party/WebKit/Source/exists_after_but_not_before': '_ignored', - }) - - def populateSvn(self): - raise NotImplementedError() + """Simulates the Blink DEPS transition in Chrome.""" + + NB_GIT_REPOS = 2 + DEPS_pre = 'deps = {"src/third_party/WebKit": "%(git_base)srepo_2",}' + DEPS_post = 'deps = {}' + + def populateGit(self): + # Blink repo. 
+ self._commit_git( + 'repo_2', { + 'OWNERS': 'OWNERS-pre', + 'Source/exists_always': '_ignored_', + 'Source/exists_before_but_not_after': '_ignored_', + }) + + # Chrome repo. + self._commit_git( + 'repo_1', { + 'DEPS': self.DEPS_pre % { + 'git_base': self.git_base + }, + 'myfile': 'myfile@1', + '.gitignore': '/third_party/WebKit', + }) + self._commit_git( + 'repo_1', { + 'DEPS': self.DEPS_post % { + 'git_base': self.git_base + }, + 'myfile': 'myfile@2', + '.gitignore': '', + 'third_party/WebKit/OWNERS': 'OWNERS-post', + 'third_party/WebKit/Source/exists_always': '_ignored_', + 'third_party/WebKit/Source/exists_after_but_not_before': + '_ignored', + }) + + def populateSvn(self): + raise NotImplementedError() class FakeRepoNoSyncDEPS(FakeReposBase): - """Simulates a repo with some DEPS changes.""" + """Simulates a repo with some DEPS changes.""" - NB_GIT_REPOS = 2 + NB_GIT_REPOS = 2 - def populateGit(self): - self._commit_git('repo_2', {'myfile': 'then egg'}) - self._commit_git('repo_2', {'myfile': 'before egg!'}) + def populateGit(self): + self._commit_git('repo_2', {'myfile': 'then egg'}) + self._commit_git('repo_2', {'myfile': 'before egg!'}) - self._commit_git( - 'repo_1', { - 'DEPS': - textwrap.dedent( - """\ + self._commit_git( + 'repo_1', { + 'DEPS': + textwrap.dedent( + """\ deps = { 'src/repo2': { 'url': %(git_base)r + 'repo_2@%(repo2hash)s', }, }""" % { - 'git_base': self.git_base, - 'repo2hash': self.git_hashes['repo_2'][1][0][:7] - }) - }) - self._commit_git( - 'repo_1', { - 'DEPS': - textwrap.dedent( - """\ + 'git_base': self.git_base, + 'repo2hash': self.git_hashes['repo_2'][1][0][:7] + }) + }) + self._commit_git( + 'repo_1', { + 'DEPS': + textwrap.dedent( + """\ deps = { 'src/repo2': { 'url': %(git_base)r + 'repo_2@%(repo2hash)s', }, }""" % { - 'git_base': self.git_base, - 'repo2hash': self.git_hashes['repo_2'][2][0][:7] - }) - }) - self._commit_git( - 'repo_1', { - 'foo_file': - 'chicken content', - 'DEPS': - textwrap.dedent( - """\ + 'git_base': self.git_base, + 'repo2hash': self.git_hashes['repo_2'][2][0][:7] + }) + }) + self._commit_git( + 'repo_1', { + 'foo_file': + 'chicken content', + 'DEPS': + textwrap.dedent( + """\ deps = { 'src/repo2': { 'url': %(git_base)r + 'repo_2@%(repo2hash)s', }, }""" % { - 'git_base': self.git_base, - 'repo2hash': self.git_hashes['repo_2'][1][0][:7] - }) - }) + 'git_base': self.git_base, + 'repo2hash': self.git_hashes['repo_2'][1][0][:7] + }) + }) - self._commit_git('repo_1', {'foo_file': 'chicken content@4'}) + self._commit_git('repo_1', {'foo_file': 'chicken content@4'}) class FakeReposTestBase(trial_dir.TestCase): - """This is vaguely inspired by twisted.""" - # Static FakeRepos instances. Lazy loaded. - CACHED_FAKE_REPOS = {} - # Override if necessary. - FAKE_REPOS_CLASS = FakeRepos - - def setUp(self): - super(FakeReposTestBase, self).setUp() - if not self.FAKE_REPOS_CLASS in self.CACHED_FAKE_REPOS: - self.CACHED_FAKE_REPOS[self.FAKE_REPOS_CLASS] = self.FAKE_REPOS_CLASS() - self.FAKE_REPOS = self.CACHED_FAKE_REPOS[self.FAKE_REPOS_CLASS] - # No need to call self.FAKE_REPOS.setUp(), it will be called by the child - # class. - # Do not define tearDown(), since super's version does the right thing and - # self.FAKE_REPOS is kept across tests. 
- - @property - def git_base(self): - """Shortcut.""" - return self.FAKE_REPOS.git_base - - def checkString(self, expected, result, msg=None): - """Prints the diffs to ease debugging.""" - self.assertEqual(expected.splitlines(), result.splitlines(), msg) - if expected != result: - # Strip the beginning - while expected and result and expected[0] == result[0]: - expected = expected[1:] - result = result[1:] - # The exception trace makes it hard to read so dump it too. - if '\n' in result: - print(result) - self.assertEqual(expected, result, msg) - - def check(self, expected, results): - """Checks stdout, stderr, returncode.""" - self.checkString(expected[0], results[0]) - self.checkString(expected[1], results[1]) - self.assertEqual(expected[2], results[2]) - - def assertTree(self, tree, tree_root=None): - """Diff the checkout tree with a dict.""" - if not tree_root: - tree_root = self.root_dir - actual = read_tree(tree_root) - self.assertEqual(sorted(tree.keys()), sorted(actual.keys())) - self.assertEqual(tree, actual) - - def mangle_git_tree(self, *args): - """Creates a 'virtual directory snapshot' to compare with the actual result - on disk.""" - result = {} - for item, new_root in args: - repo, rev = item.split('@', 1) - tree = self.gittree(repo, rev) - for k, v in tree.items(): - path = join(new_root, k).replace(os.sep, '/') - result[path] = v - return result - - def githash(self, repo, rev): - """Sort-hand: Returns the hash for a git 'revision'.""" - return self.FAKE_REPOS.git_hashes[repo][int(rev)][0] - - def gittree(self, repo, rev): - """Sort-hand: returns the directory tree for a git 'revision'.""" - return self.FAKE_REPOS.git_hashes[repo][int(rev)][1] - - def gitrevparse(self, repo): - """Returns the actual revision for a given repo.""" - return self.FAKE_REPOS._git_rev_parse(repo).decode('utf-8') + """This is vaguely inspired by twisted.""" + # Static FakeRepos instances. Lazy loaded. + CACHED_FAKE_REPOS = {} + # Override if necessary. + FAKE_REPOS_CLASS = FakeRepos + + def setUp(self): + super(FakeReposTestBase, self).setUp() + if not self.FAKE_REPOS_CLASS in self.CACHED_FAKE_REPOS: + self.CACHED_FAKE_REPOS[ + self.FAKE_REPOS_CLASS] = self.FAKE_REPOS_CLASS() + self.FAKE_REPOS = self.CACHED_FAKE_REPOS[self.FAKE_REPOS_CLASS] + # No need to call self.FAKE_REPOS.setUp(), it will be called by the + # child class. Do not define tearDown(), since super's version does the + # right thing and self.FAKE_REPOS is kept across tests. + + @property + def git_base(self): + """Shortcut.""" + return self.FAKE_REPOS.git_base + + def checkString(self, expected, result, msg=None): + """Prints the diffs to ease debugging.""" + self.assertEqual(expected.splitlines(), result.splitlines(), msg) + if expected != result: + # Strip the beginning + while expected and result and expected[0] == result[0]: + expected = expected[1:] + result = result[1:] + # The exception trace makes it hard to read so dump it too. 
+ if '\n' in result: + print(result) + self.assertEqual(expected, result, msg) + + def check(self, expected, results): + """Checks stdout, stderr, returncode.""" + self.checkString(expected[0], results[0]) + self.checkString(expected[1], results[1]) + self.assertEqual(expected[2], results[2]) + + def assertTree(self, tree, tree_root=None): + """Diff the checkout tree with a dict.""" + if not tree_root: + tree_root = self.root_dir + actual = read_tree(tree_root) + self.assertEqual(sorted(tree.keys()), sorted(actual.keys())) + self.assertEqual(tree, actual) + + def mangle_git_tree(self, *args): + """Creates a 'virtual directory snapshot' to compare with the actual + result on disk.""" + result = {} + for item, new_root in args: + repo, rev = item.split('@', 1) + tree = self.gittree(repo, rev) + for k, v in tree.items(): + path = join(new_root, k).replace(os.sep, '/') + result[path] = v + return result + + def githash(self, repo, rev): + """Short-hand: Returns the hash for a git 'revision'.""" + return self.FAKE_REPOS.git_hashes[repo][int(rev)][0] + + def gittree(self, repo, rev): + """Short-hand: returns the directory tree for a git 'revision'.""" + return self.FAKE_REPOS.git_hashes[repo][int(rev)][1] + + def gitrevparse(self, repo): + """Returns the actual revision for a given repo.""" + return self.FAKE_REPOS._git_rev_parse(repo).decode('utf-8') def main(argv): - fake = FakeRepos() - print('Using %s' % fake.root_dir) - try: - fake.set_up_git() - print('Fake setup, press enter to quit or Ctrl-C to keep the checkouts.') - sys.stdin.readline() - except KeyboardInterrupt: - trial_dir.TrialDir.SHOULD_LEAK.leak = True - return 0 + fake = FakeRepos() + print('Using %s' % fake.root_dir) + try: + fake.set_up_git() + print( + 'Fake setup, press enter to quit or Ctrl-C to keep the checkouts.') + sys.stdin.readline() + except KeyboardInterrupt: + trial_dir.TrialDir.SHOULD_LEAK.leak = True + return 0 if __name__ == '__main__': - sys.exit(main(sys.argv)) + sys.exit(main(sys.argv)) diff --git a/testing_support/filesystem_mock.py b/testing_support/filesystem_mock.py index 563447c60..660abbb5d 100644 --- a/testing_support/filesystem_mock.py +++ b/testing_support/filesystem_mock.py @@ -10,98 +10,98 @@ from io import StringIO def _RaiseNotFound(path): - raise IOError(errno.ENOENT, path, os.strerror(errno.ENOENT)) + raise IOError(errno.ENOENT, path, os.strerror(errno.ENOENT)) class MockFileSystem(object): - """Stripped-down version of WebKit's webkitpy.common.system.filesystem_mock + """Stripped-down version of WebKit's webkitpy.common.system.filesystem_mock Implements a filesystem-like interface on top of a dict of filenames -> file contents.
A file content value of None indicates that the file should not exist (IOError will be raised if it is opened; reading from a missing key raises a KeyError, not an IOError.""" - - def __init__(self, files=None): - self.files = files or {} - self.written_files = {} - self._sep = '/' - - @property - def sep(self): - return self._sep - - def abspath(self, path): - if path.endswith(self.sep): - return path[:-1] - return path - - def basename(self, path): - if self.sep not in path: - return '' - return self.split(path)[-1] or self.sep - - def dirname(self, path): - if self.sep not in path: - return '' - return self.split(path)[0] or self.sep - - def exists(self, path): - return self.isfile(path) or self.isdir(path) - - def isabs(self, path): - return path.startswith(self.sep) - - def isfile(self, path): - return path in self.files and self.files[path] is not None - - def isdir(self, path): - if path in self.files: - return False - if not path.endswith(self.sep): - path += self.sep - - # We need to use a copy of the keys here in order to avoid switching - # to a different thread and potentially modifying the dict in - # mid-iteration. - files = list(self.files.keys())[:] - return any(f.startswith(path) for f in files) - - def join(self, *comps): - # TODO: Might want tests for this and/or a better comment about how - # it works. - return re.sub(re.escape(os.path.sep), self.sep, os.path.join(*comps)) - - def glob(self, path): - return fnmatch.filter(self.files.keys(), path) - - def open_for_reading(self, path): - return StringIO(self.read_binary_file(path)) - - def normpath(self, path): - # This is not a complete implementation of normpath. Only covers what we - # use in tests. - result = [] - for part in path.split(self.sep): - if part == '..': - result.pop() - elif part == '.': - continue - else: - result.append(part) - return self.sep.join(result) - - def read_binary_file(self, path): - # Intentionally raises KeyError if we don't recognize the path. - if self.files[path] is None: - _RaiseNotFound(path) - return self.files[path] - - def relpath(self, path, base): - # This implementation is wrong in many ways; assert to check them for now. - if not base.endswith(self.sep): - base += self.sep - assert path.startswith(base) - return path[len(base):] - - def split(self, path): - return path.rsplit(self.sep, 1) + def __init__(self, files=None): + self.files = files or {} + self.written_files = {} + self._sep = '/' + + @property + def sep(self): + return self._sep + + def abspath(self, path): + if path.endswith(self.sep): + return path[:-1] + return path + + def basename(self, path): + if self.sep not in path: + return '' + return self.split(path)[-1] or self.sep + + def dirname(self, path): + if self.sep not in path: + return '' + return self.split(path)[0] or self.sep + + def exists(self, path): + return self.isfile(path) or self.isdir(path) + + def isabs(self, path): + return path.startswith(self.sep) + + def isfile(self, path): + return path in self.files and self.files[path] is not None + + def isdir(self, path): + if path in self.files: + return False + if not path.endswith(self.sep): + path += self.sep + + # We need to use a copy of the keys here in order to avoid switching + # to a different thread and potentially modifying the dict in + # mid-iteration. + files = list(self.files.keys())[:] + return any(f.startswith(path) for f in files) + + def join(self, *comps): + # TODO: Might want tests for this and/or a better comment about how + # it works. 
+ return re.sub(re.escape(os.path.sep), self.sep, os.path.join(*comps)) + + def glob(self, path): + return fnmatch.filter(self.files.keys(), path) + + def open_for_reading(self, path): + return StringIO(self.read_binary_file(path)) + + def normpath(self, path): + # This is not a complete implementation of normpath. Only covers what we + # use in tests. + result = [] + for part in path.split(self.sep): + if part == '..': + result.pop() + elif part == '.': + continue + else: + result.append(part) + return self.sep.join(result) + + def read_binary_file(self, path): + # Intentionally raises KeyError if we don't recognize the path. + if self.files[path] is None: + _RaiseNotFound(path) + return self.files[path] + + def relpath(self, path, base): + # This implementation is wrong in many ways; assert to check them for + # now. + if not base.endswith(self.sep): + base += self.sep + assert path.startswith(base) + return path[len(base):] + + def split(self, path): + return path.rsplit(self.sep, 1) diff --git a/testing_support/git_test_utils.py b/testing_support/git_test_utils.py index 2a7a9c22d..a6aa4426d 100644 --- a/testing_support/git_test_utils.py +++ b/testing_support/git_test_utils.py @@ -17,108 +17,107 @@ import unittest import gclient_utils - DEFAULT_BRANCH = 'main' def git_hash_data(data, typ='blob'): - """Calculate the git-style SHA1 for some data. + """Calculate the git-style SHA1 for some data. Only supports 'blob' type data at the moment. """ - assert typ == 'blob', 'Only support blobs for now' - return hashlib.sha1(b'blob %d\0%s' % (len(data), data)).hexdigest() + assert typ == 'blob', 'Only support blobs for now' + return hashlib.sha1(b'blob %d\0%s' % (len(data), data)).hexdigest() class OrderedSet(collections.MutableSet): - # from http://code.activestate.com/recipes/576694/ - def __init__(self, iterable=None): - self.end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.data = {} # key --> [key, prev, next] - if iterable is not None: - self |= iterable - - def __contains__(self, key): - return key in self.data - - def __eq__(self, other): - if isinstance(other, OrderedSet): - return len(self) == len(other) and list(self) == list(other) - return set(self) == set(other) - - def __ne__(self, other): - if isinstance(other, OrderedSet): - return len(self) != len(other) or list(self) != list(other) - return set(self) != set(other) - - def __len__(self): - return len(self.data) - - def __iter__(self): - end = self.end - curr = end[2] - while curr is not end: - yield curr[0] - curr = curr[2] - - def __repr__(self): - if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self)) - - def __reversed__(self): - end = self.end - curr = end[1] - while curr is not end: - yield curr[0] - curr = curr[1] - - def add(self, key): - if key not in self.data: - end = self.end - curr = end[1] - curr[2] = end[1] = self.data[key] = [key, curr, end] - - def difference_update(self, *others): - for other in others: - for i in other: - self.discard(i) - - def discard(self, key): - if key in self.data: - key, prev, nxt = self.data.pop(key) - prev[2] = nxt - nxt[1] = prev - - def pop(self, last=True): # pylint: disable=arguments-differ - if not self: - raise KeyError('set is empty') - key = self.end[1][0] if last else self.end[2][0] - self.discard(key) - return key + # from http://code.activestate.com/recipes/576694/ + def __init__(self, iterable=None): + self.end = end = [] + end += [None, end, end] # sentinel node for doubly 
linked list + self.data = {} # key --> [key, prev, next] + if iterable is not None: + self |= iterable + + def __contains__(self, key): + return key in self.data + + def __eq__(self, other): + if isinstance(other, OrderedSet): + return len(self) == len(other) and list(self) == list(other) + return set(self) == set(other) + + def __ne__(self, other): + if isinstance(other, OrderedSet): + return len(self) != len(other) or list(self) != list(other) + return set(self) != set(other) + + def __len__(self): + return len(self.data) + + def __iter__(self): + end = self.end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__, ) + return '%s(%r)' % (self.__class__.__name__, list(self)) + + def __reversed__(self): + end = self.end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def add(self, key): + if key not in self.data: + end = self.end + curr = end[1] + curr[2] = end[1] = self.data[key] = [key, curr, end] + + def difference_update(self, *others): + for other in others: + for i in other: + self.discard(i) + + def discard(self, key): + if key in self.data: + key, prev, nxt = self.data.pop(key) + prev[2] = nxt + nxt[1] = prev + + def pop(self, last=True): # pylint: disable=arguments-differ + if not self: + raise KeyError('set is empty') + key = self.end[1][0] if last else self.end[2][0] + self.discard(key) + return key class UTC(datetime.tzinfo): - """UTC time zone. + """UTC time zone. from https://docs.python.org/2/library/datetime.html#tzinfo-objects """ - def utcoffset(self, dt): - return datetime.timedelta(0) + def utcoffset(self, dt): + return datetime.timedelta(0) - def tzname(self, dt): - return "UTC" + def tzname(self, dt): + return "UTC" - def dst(self, dt): - return datetime.timedelta(0) + def dst(self, dt): + return datetime.timedelta(0) UTC = UTC() class GitRepoSchema(object): - """A declarative git testing repo. + """A declarative git testing repo. Pass a schema to __init__ in the form of: A B C D @@ -141,11 +140,10 @@ class GitRepoSchema(object): in the schema) get earlier timestamps. Stamps start at the Unix Epoch, and increment by 1 day each. """ - COMMIT = collections.namedtuple('COMMIT', 'name parents is_branch is_root') + COMMIT = collections.namedtuple('COMMIT', 'name parents is_branch is_root') - def __init__(self, repo_schema='', - content_fn=lambda v: {v: {'data': v}}): - """Builds a new GitRepoSchema. + def __init__(self, repo_schema='', content_fn=lambda v: {v: {'data': v}}): + """Builds a new GitRepoSchema. Args: repo_schema (str) - Initial schema for this repo. See class docstring for @@ -156,88 +154,88 @@ class GitRepoSchema(object): commit_name). See the docstring on the GitRepo class for the format of the data returned by this function. """ - self.main = None - self.par_map = {} - self.data_cache = {} - self.content_fn = content_fn - self.add_commits(repo_schema) + self.main = None + self.par_map = {} + self.data_cache = {} + self.content_fn = content_fn + self.add_commits(repo_schema) - def walk(self): - """(Generator) Walks the repo schema from roots to tips. + def walk(self): + """(Generator) Walks the repo schema from roots to tips. Generates GitRepoSchema.COMMIT objects for each commit. Throws an AssertionError if it detects a cycle. """ - is_root = True - par_map = copy.deepcopy(self.par_map) - while par_map: - empty_keys = set(k for k, v in par_map.items() if not v) - assert empty_keys, 'Cycle detected! 
%s' % par_map - - for k in sorted(empty_keys): - yield self.COMMIT(k, self.par_map[k], - not any(k in v for v in self.par_map.values()), - is_root) - del par_map[k] - for v in par_map.values(): - v.difference_update(empty_keys) - is_root = False - - def add_partial(self, commit, parent=None): - if commit not in self.par_map: - self.par_map[commit] = OrderedSet() - if parent is not None: - self.par_map[commit].add(parent) - - def add_commits(self, schema): - """Adds more commits from a schema into the existing Schema. + is_root = True + par_map = copy.deepcopy(self.par_map) + while par_map: + empty_keys = set(k for k, v in par_map.items() if not v) + assert empty_keys, 'Cycle detected! %s' % par_map + + for k in sorted(empty_keys): + yield self.COMMIT( + k, self.par_map[k], + not any(k in v for v in self.par_map.values()), is_root) + del par_map[k] + for v in par_map.values(): + v.difference_update(empty_keys) + is_root = False + + def add_partial(self, commit, parent=None): + if commit not in self.par_map: + self.par_map[commit] = OrderedSet() + if parent is not None: + self.par_map[commit].add(parent) + + def add_commits(self, schema): + """Adds more commits from a schema into the existing Schema. Args: schema (str) - See class docstring for info on schema format. Throws an AssertionError if it detects a cycle. """ - for commits in (l.split() for l in schema.splitlines() if l.strip()): - parent = None - for commit in commits: - self.add_partial(commit, parent) - parent = commit - if parent and not self.main: - self.main = parent - for _ in self.walk(): # This will throw if there are any cycles. - pass - - def reify(self): - """Returns a real GitRepo for this GitRepoSchema""" - return GitRepo(self) - - def data_for(self, commit): - """Obtains the data for |commit|. + for commits in (l.split() for l in schema.splitlines() if l.strip()): + parent = None + for commit in commits: + self.add_partial(commit, parent) + parent = commit + if parent and not self.main: + self.main = parent + for _ in self.walk(): # This will throw if there are any cycles. + pass + + def reify(self): + """Returns a real GitRepo for this GitRepoSchema""" + return GitRepo(self) + + def data_for(self, commit): + """Obtains the data for |commit|. See the docstring on the GitRepo class for the format of the returned data. Caches the result on this GitRepoSchema instance. """ - if commit not in self.data_cache: - self.data_cache[commit] = self.content_fn(commit) - return self.data_cache[commit] + if commit not in self.data_cache: + self.data_cache[commit] = self.content_fn(commit) + return self.data_cache[commit] - def simple_graph(self): - """Returns a dictionary of {commit_subject: {parent commit_subjects}} + def simple_graph(self): + """Returns a dictionary of {commit_subject: {parent commit_subjects}} This allows you to get a very simple connection graph over the whole repo for comparison purposes. Only commit subjects (not ids, not content/data) are considered """ - ret = {} - for commit in self.walk(): - ret.setdefault(commit.name, set()).update(commit.parents) - return ret + ret = {} + for commit in self.walk(): + ret.setdefault(commit.name, set()).update(commit.parents) + return ret class GitRepo(object): - """Creates a real git repo for a GitRepoSchema. + """Creates a real git repo for a GitRepoSchema. Obtains schema and content information from the GitRepoSchema. @@ -260,26 +258,26 @@ class GitRepo(object): For file content, if 'data' is None, then this commit will `git rm` that file. 
""" - BASE_TEMP_DIR = tempfile.mkdtemp(suffix='base', prefix='git_repo') - atexit.register(gclient_utils.rmtree, BASE_TEMP_DIR) + BASE_TEMP_DIR = tempfile.mkdtemp(suffix='base', prefix='git_repo') + atexit.register(gclient_utils.rmtree, BASE_TEMP_DIR) - # Singleton objects to specify specific data in a commit dictionary. - AUTHOR_NAME = object() - AUTHOR_EMAIL = object() - AUTHOR_DATE = object() - COMMITTER_NAME = object() - COMMITTER_EMAIL = object() - COMMITTER_DATE = object() + # Singleton objects to specify specific data in a commit dictionary. + AUTHOR_NAME = object() + AUTHOR_EMAIL = object() + AUTHOR_DATE = object() + COMMITTER_NAME = object() + COMMITTER_EMAIL = object() + COMMITTER_DATE = object() - DEFAULT_AUTHOR_NAME = 'Author McAuthorly' - DEFAULT_AUTHOR_EMAIL = 'author@example.com' - DEFAULT_COMMITTER_NAME = 'Charles Committish' - DEFAULT_COMMITTER_EMAIL = 'commitish@example.com' + DEFAULT_AUTHOR_NAME = 'Author McAuthorly' + DEFAULT_AUTHOR_EMAIL = 'author@example.com' + DEFAULT_COMMITTER_NAME = 'Charles Committish' + DEFAULT_COMMITTER_EMAIL = 'commitish@example.com' - COMMAND_OUTPUT = collections.namedtuple('COMMAND_OUTPUT', 'retcode stdout') + COMMAND_OUTPUT = collections.namedtuple('COMMAND_OUTPUT', 'retcode stdout') - def __init__(self, schema): - """Makes new GitRepo. + def __init__(self, schema): + """Makes new GitRepo. Automatically creates a temp folder under GitRepo.BASE_TEMP_DIR. It's recommended that you clean this repo up by calling nuke() on it, but if not, @@ -289,194 +287,198 @@ class GitRepo(object): Args: schema - An instance of GitRepoSchema """ - self.repo_path = os.path.realpath(tempfile.mkdtemp(dir=self.BASE_TEMP_DIR)) - self.commit_map = {} - self._date = datetime.datetime(1970, 1, 1, tzinfo=UTC) + self.repo_path = os.path.realpath( + tempfile.mkdtemp(dir=self.BASE_TEMP_DIR)) + self.commit_map = {} + self._date = datetime.datetime(1970, 1, 1, tzinfo=UTC) - self.to_schema_refs = ['--branches'] + self.to_schema_refs = ['--branches'] - self.git('init', '-b', DEFAULT_BRANCH) - self.git('config', 'user.name', 'testcase') - self.git('config', 'user.email', 'testcase@example.com') - for commit in schema.walk(): - self._add_schema_commit(commit, schema.data_for(commit.name)) - self.last_commit = self[commit.name] - if schema.main: - self.git('update-ref', 'refs/heads/main', self[schema.main]) + self.git('init', '-b', DEFAULT_BRANCH) + self.git('config', 'user.name', 'testcase') + self.git('config', 'user.email', 'testcase@example.com') + for commit in schema.walk(): + self._add_schema_commit(commit, schema.data_for(commit.name)) + self.last_commit = self[commit.name] + if schema.main: + self.git('update-ref', 'refs/heads/main', self[schema.main]) - def __getitem__(self, commit_name): - """Gets the hash of a commit by its schema name. + def __getitem__(self, commit_name): + """Gets the hash of a commit by its schema name. 
>>> r = GitRepo(GitRepoSchema('A B C')) >>> r['B'] '7381febe1da03b09da47f009963ab7998a974935' """ - return self.commit_map[commit_name] - - def _add_schema_commit(self, commit, commit_data): - commit_data = commit_data or {} - - if commit.parents: - parents = list(commit.parents) - self.git('checkout', '--detach', '-q', self[parents[0]]) - if len(parents) > 1: - self.git('merge', '--no-commit', '-q', *[self[x] for x in parents[1:]]) - else: - self.git('checkout', '--orphan', 'root_%s' % commit.name) - self.git('rm', '-rf', '.') - - env = self.get_git_commit_env(commit_data) - - for fname, file_data in commit_data.items(): - # If it isn't a string, it's one of the special keys. - if not isinstance(fname, str): - continue - - deleted = False - if 'data' in file_data: - data = file_data.get('data') - if data is None: - deleted = True - self.git('rm', fname) - else: - path = os.path.join(self.repo_path, fname) - pardir = os.path.dirname(path) - if not os.path.exists(pardir): - os.makedirs(pardir) - with open(path, 'wb') as f: - f.write(data) - - mode = file_data.get('mode') - if mode and not deleted: - os.chmod(path, mode) - - self.git('add', fname) - - rslt = self.git('commit', '--allow-empty', '-m', commit.name, env=env) - assert rslt.retcode == 0, 'Failed to commit %s' % str(commit) - self.commit_map[commit.name] = self.git('rev-parse', 'HEAD').stdout.strip() - self.git('tag', 'tag_%s' % commit.name, self[commit.name]) - if commit.is_branch: - self.git('branch', '-f', 'branch_%s' % commit.name, self[commit.name]) - - def get_git_commit_env(self, commit_data=None): - commit_data = commit_data or {} - env = os.environ.copy() - for prefix in ('AUTHOR', 'COMMITTER'): - for suffix in ('NAME', 'EMAIL', 'DATE'): - singleton = '%s_%s' % (prefix, suffix) - key = getattr(self, singleton) - if key in commit_data: - val = commit_data[key] - elif suffix == 'DATE': - val = self._date - self._date += datetime.timedelta(days=1) + return self.commit_map[commit_name] + + def _add_schema_commit(self, commit, commit_data): + commit_data = commit_data or {} + + if commit.parents: + parents = list(commit.parents) + self.git('checkout', '--detach', '-q', self[parents[0]]) + if len(parents) > 1: + self.git('merge', '--no-commit', '-q', + *[self[x] for x in parents[1:]]) else: - val = getattr(self, 'DEFAULT_%s' % singleton) - if not isinstance(val, str) and not isinstance(val, bytes): - val = str(val) - env['GIT_%s' % singleton] = val - return env - - def git(self, *args, **kwargs): - """Runs a git command specified by |args| in this repo.""" - assert self.repo_path is not None - try: - with open(os.devnull, 'wb') as devnull: - shell = sys.platform == 'win32' - output = subprocess.check_output( - ('git', ) + args, - shell=shell, - cwd=self.repo_path, - stderr=devnull, - **kwargs) - output = output.decode('utf-8') - return self.COMMAND_OUTPUT(0, output) - except subprocess.CalledProcessError as e: - return self.COMMAND_OUTPUT(e.returncode, e.output) - - def show_commit(self, commit_name, format_string): - """Shows a commit (by its schema name) with a given format string.""" - return self.git('show', '-q', '--pretty=format:%s' % format_string, - self[commit_name]).stdout - - def git_commit(self, message): - return self.git('commit', '-am', message, env=self.get_git_commit_env()) - - def nuke(self): - """Obliterates the git repo on disk. 
+ self.git('checkout', '--orphan', 'root_%s' % commit.name) + self.git('rm', '-rf', '.') + + env = self.get_git_commit_env(commit_data) + + for fname, file_data in commit_data.items(): + # If it isn't a string, it's one of the special keys. + if not isinstance(fname, str): + continue + + deleted = False + if 'data' in file_data: + data = file_data.get('data') + if data is None: + deleted = True + self.git('rm', fname) + else: + path = os.path.join(self.repo_path, fname) + pardir = os.path.dirname(path) + if not os.path.exists(pardir): + os.makedirs(pardir) + with open(path, 'wb') as f: + f.write(data) + + mode = file_data.get('mode') + if mode and not deleted: + os.chmod(path, mode) + + self.git('add', fname) + + rslt = self.git('commit', '--allow-empty', '-m', commit.name, env=env) + assert rslt.retcode == 0, 'Failed to commit %s' % str(commit) + self.commit_map[commit.name] = self.git('rev-parse', + 'HEAD').stdout.strip() + self.git('tag', 'tag_%s' % commit.name, self[commit.name]) + if commit.is_branch: + self.git('branch', '-f', 'branch_%s' % commit.name, + self[commit.name]) + + def get_git_commit_env(self, commit_data=None): + commit_data = commit_data or {} + env = os.environ.copy() + for prefix in ('AUTHOR', 'COMMITTER'): + for suffix in ('NAME', 'EMAIL', 'DATE'): + singleton = '%s_%s' % (prefix, suffix) + key = getattr(self, singleton) + if key in commit_data: + val = commit_data[key] + elif suffix == 'DATE': + val = self._date + self._date += datetime.timedelta(days=1) + else: + val = getattr(self, 'DEFAULT_%s' % singleton) + if not isinstance(val, str) and not isinstance(val, bytes): + val = str(val) + env['GIT_%s' % singleton] = val + return env + + def git(self, *args, **kwargs): + """Runs a git command specified by |args| in this repo.""" + assert self.repo_path is not None + try: + with open(os.devnull, 'wb') as devnull: + shell = sys.platform == 'win32' + output = subprocess.check_output(('git', ) + args, + shell=shell, + cwd=self.repo_path, + stderr=devnull, + **kwargs) + output = output.decode('utf-8') + return self.COMMAND_OUTPUT(0, output) + except subprocess.CalledProcessError as e: + return self.COMMAND_OUTPUT(e.returncode, e.output) + + def show_commit(self, commit_name, format_string): + """Shows a commit (by its schema name) with a given format string.""" + return self.git('show', '-q', '--pretty=format:%s' % format_string, + self[commit_name]).stdout + + def git_commit(self, message): + return self.git('commit', '-am', message, env=self.get_git_commit_env()) + + def nuke(self): + """Obliterates the git repo on disk. Causes this GitRepo to be unusable. """ - gclient_utils.rmtree(self.repo_path) - self.repo_path = None - - def run(self, fn, *args, **kwargs): - """Run a python function with the given args and kwargs with the cwd set to - the git repo.""" - assert self.repo_path is not None - curdir = os.getcwd() - try: - os.chdir(self.repo_path) - return fn(*args, **kwargs) - finally: - os.chdir(curdir) - - def capture_stdio(self, fn, *args, **kwargs): - """Run a python function with the given args and kwargs with the cwd set to - the git repo. 
+        gclient_utils.rmtree(self.repo_path)
+        self.repo_path = None
+
+    def run(self, fn, *args, **kwargs):
+        """Run a python function with the given args and kwargs with the cwd
+        set to the git repo."""
+        assert self.repo_path is not None
+        curdir = os.getcwd()
+        try:
+            os.chdir(self.repo_path)
+            return fn(*args, **kwargs)
+        finally:
+            os.chdir(curdir)
+
+    def capture_stdio(self, fn, *args, **kwargs):
+        """Run a python function with the given args and kwargs with the cwd set
+        to the git repo.
 
   Returns the (stdout, stderr) of whatever ran, instead of what |fn|
   returned.
   """
-    stdout = sys.stdout
-    stderr = sys.stderr
-    try:
-      with tempfile.TemporaryFile('w+') as out:
-        with tempfile.TemporaryFile('w+') as err:
-          sys.stdout = out
-          sys.stderr = err
-          try:
-            self.run(fn, *args, **kwargs)
-          except SystemExit:
-            pass
-          out.seek(0)
-          err.seek(0)
-          return out.read(), err.read()
-    finally:
-      sys.stdout = stdout
-      sys.stderr = stderr
-
-  def open(self, path, mode='rb'):
-    return open(os.path.join(self.repo_path, path), mode)
-
-  def to_schema(self):
-    lines = self.git('rev-list', '--parents', '--reverse', '--topo-order',
-                     '--format=%s', *self.to_schema_refs).stdout.splitlines()
-    hash_to_msg = {}
-    ret = GitRepoSchema()
-    current = None
-    parents = []
-    for line in lines:
-      if line.startswith('commit'):
-        assert current is None
-        tokens = line.split()
-        current, parents = tokens[1], tokens[2:]
-        assert all(p in hash_to_msg for p in parents)
-      else:
-        assert current is not None
-        hash_to_msg[current] = line
-        ret.add_partial(line)
-        for parent in parents:
-          ret.add_partial(line, hash_to_msg[parent])
-        current = None
-        parents = []
-    assert current is None
-    return ret
+        stdout = sys.stdout
+        stderr = sys.stderr
+        try:
+            with tempfile.TemporaryFile('w+') as out:
+                with tempfile.TemporaryFile('w+') as err:
+                    sys.stdout = out
+                    sys.stderr = err
+                    try:
+                        self.run(fn, *args, **kwargs)
+                    except SystemExit:
+                        pass
+                    out.seek(0)
+                    err.seek(0)
+                    return out.read(), err.read()
+        finally:
+            sys.stdout = stdout
+            sys.stderr = stderr
+
+    def open(self, path, mode='rb'):
+        return open(os.path.join(self.repo_path, path), mode)
+
+    def to_schema(self):
+        lines = self.git('rev-list', '--parents', '--reverse', '--topo-order',
+                         '--format=%s',
+                         *self.to_schema_refs).stdout.splitlines()
+        hash_to_msg = {}
+        ret = GitRepoSchema()
+        current = None
+        parents = []
+        for line in lines:
+            if line.startswith('commit'):
+                assert current is None
+                tokens = line.split()
+                current, parents = tokens[1], tokens[2:]
+                assert all(p in hash_to_msg for p in parents)
+            else:
+                assert current is not None
+                hash_to_msg[current] = line
+                ret.add_partial(line)
+                for parent in parents:
+                    ret.add_partial(line, hash_to_msg[parent])
+                current = None
+                parents = []
+        assert current is None
+        return ret
 
 
 class GitRepoSchemaTestBase(unittest.TestCase):
-  """A TestCase with a built-in GitRepoSchema.
+    """A TestCase with a built-in GitRepoSchema.
 
   Expects a class variable REPO_SCHEMA to be a GitRepoSchema string in the
   form described by that class.
 
@@ -487,61 +489,62 @@ class GitRepoSchemaTestBase(unittest.TestCase):
   You probably will end up using either GitRepoReadOnlyTestBase or
   GitRepoReadWriteTestBase for real tests.
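 
   A minimal sketch of such a subclass (the schema and the assertion are
   illustrative only):
 
     class TwoBranchTest(GitRepoReadWriteTestBase):
       REPO_SCHEMA = 'A B C\nB D'
 
       def testTopology(self):
        self.assertSchema('A B C\nB D')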
""" - REPO_SCHEMA = None + REPO_SCHEMA = None - @classmethod - def getRepoContent(cls, commit): - commit = 'COMMIT_%s' % commit - return getattr(cls, commit, None) + @classmethod + def getRepoContent(cls, commit): + commit = 'COMMIT_%s' % commit + return getattr(cls, commit, None) - @classmethod - def setUpClass(cls): - super(GitRepoSchemaTestBase, cls).setUpClass() - assert cls.REPO_SCHEMA is not None - cls.r_schema = GitRepoSchema(cls.REPO_SCHEMA, cls.getRepoContent) + @classmethod + def setUpClass(cls): + super(GitRepoSchemaTestBase, cls).setUpClass() + assert cls.REPO_SCHEMA is not None + cls.r_schema = GitRepoSchema(cls.REPO_SCHEMA, cls.getRepoContent) class GitRepoReadOnlyTestBase(GitRepoSchemaTestBase): - """Injects a GitRepo object given the schema and content from + """Injects a GitRepo object given the schema and content from GitRepoSchemaTestBase into TestCase classes which subclass this. This GitRepo will appear as self.repo, and will be deleted and recreated once for the duration of all the tests in the subclass. """ - REPO_SCHEMA = None + REPO_SCHEMA = None - @classmethod - def setUpClass(cls): - super(GitRepoReadOnlyTestBase, cls).setUpClass() - assert cls.REPO_SCHEMA is not None - cls.repo = cls.r_schema.reify() + @classmethod + def setUpClass(cls): + super(GitRepoReadOnlyTestBase, cls).setUpClass() + assert cls.REPO_SCHEMA is not None + cls.repo = cls.r_schema.reify() - def setUp(self): - self.repo.git('checkout', '-f', self.repo.last_commit) + def setUp(self): + self.repo.git('checkout', '-f', self.repo.last_commit) - @classmethod - def tearDownClass(cls): - cls.repo.nuke() - super(GitRepoReadOnlyTestBase, cls).tearDownClass() + @classmethod + def tearDownClass(cls): + cls.repo.nuke() + super(GitRepoReadOnlyTestBase, cls).tearDownClass() class GitRepoReadWriteTestBase(GitRepoSchemaTestBase): - """Injects a GitRepo object given the schema and content from + """Injects a GitRepo object given the schema and content from GitRepoSchemaTestBase into TestCase classes which subclass this. This GitRepo will appear as self.repo, and will be deleted and recreated for each test function in the subclass. """ - REPO_SCHEMA = None + REPO_SCHEMA = None - def setUp(self): - super(GitRepoReadWriteTestBase, self).setUp() - self.repo = self.r_schema.reify() + def setUp(self): + super(GitRepoReadWriteTestBase, self).setUp() + self.repo = self.r_schema.reify() - def tearDown(self): - self.repo.nuke() - super(GitRepoReadWriteTestBase, self).tearDown() + def tearDown(self): + self.repo.nuke() + super(GitRepoReadWriteTestBase, self).tearDown() - def assertSchema(self, schema_string): - self.assertEqual(GitRepoSchema(schema_string).simple_graph(), - self.repo.to_schema().simple_graph()) + def assertSchema(self, schema_string): + self.assertEqual( + GitRepoSchema(schema_string).simple_graph(), + self.repo.to_schema().simple_graph()) diff --git a/testing_support/presubmit_canned_checks_test_mocks.py b/testing_support/presubmit_canned_checks_test_mocks.py index a1053b0ed..93196aa19 100644 --- a/testing_support/presubmit_canned_checks_test_mocks.py +++ b/testing_support/presubmit_canned_checks_test_mocks.py @@ -15,10 +15,12 @@ from presubmit_canned_checks import _ReportErrorFileAndLine class MockCannedChecks(object): - def _FindNewViolationsOfRule(self, callable_rule, input_api, - source_file_filter=None, - error_formatter=_ReportErrorFileAndLine): - """Find all newly introduced violations of a per-line rule (a callable). 
+ def _FindNewViolationsOfRule(self, + callable_rule, + input_api, + source_file_filter=None, + error_formatter=_ReportErrorFileAndLine): + """Find all newly introduced violations of a per-line rule (a callable). Arguments: callable_rule: a callable taking a file extension and line of input and @@ -32,232 +34,246 @@ class MockCannedChecks(object): Returns: A list of the newly-introduced violations reported by the rule. """ - errors = [] - for f in input_api.AffectedFiles(include_deletes=False, - file_filter=source_file_filter): - # For speed, we do two passes, checking first the full file. Shelling out - # to the SCM to determine the changed region can be quite expensive on - # Win32. Assuming that most files will be kept problem-free, we can - # skip the SCM operations most of the time. - extension = str(f.LocalPath()).rsplit('.', 1)[-1] - if all(callable_rule(extension, line) for line in f.NewContents()): - continue # No violation found in full text: can skip considering diff. - - for line_num, line in f.ChangedContents(): - if not callable_rule(extension, line): - errors.append(error_formatter(f.LocalPath(), line_num, line)) - - return errors + errors = [] + for f in input_api.AffectedFiles(include_deletes=False, + file_filter=source_file_filter): + # For speed, we do two passes, checking first the full file. + # Shelling out to the SCM to determine the changed region can be + # quite expensive on Win32. Assuming that most files will be kept + # problem-free, we can skip the SCM operations most of the time. + extension = str(f.LocalPath()).rsplit('.', 1)[-1] + if all(callable_rule(extension, line) for line in f.NewContents()): + # No violation found in full text: can skip considering diff. + continue + + for line_num, line in f.ChangedContents(): + if not callable_rule(extension, line): + errors.append(error_formatter(f.LocalPath(), line_num, + line)) + + return errors class MockInputApi(object): - """Mock class for the InputApi class. + """Mock class for the InputApi class. This class can be used for unittests for presubmit by initializing the files attribute as the list of changed files. 
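 
   For example (the file names and contents below are made up):
 
     input_api = MockInputApi()
     input_api.files = [
       MockFile('foo/bar.py', ['import os', 'print(os.sep)']),
       MockFile('foo/baz.py', ['# TODO: fill in']),
     ]
     affected = list(input_api.AffectedFiles(include_deletes=False))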
""" - DEFAULT_FILES_TO_SKIP = () - - def __init__(self): - self.canned_checks = MockCannedChecks() - self.fnmatch = fnmatch - self.json = json - self.re = re - self.os_path = os.path - self.platform = sys.platform - self.python_executable = sys.executable - self.platform = sys.platform - self.subprocess = subprocess - self.sys = sys - self.files = [] - self.is_committing = False - self.no_diffs = False - self.change = MockChange([]) - self.presubmit_local_path = os.path.dirname(__file__) - self.logging = logging.getLogger('PRESUBMIT') - - def CreateMockFileInPath(self, f_list): - self.os_path.exists = lambda x: x in f_list - - def AffectedFiles(self, file_filter=None, include_deletes=True): - for file in self.files: # pylint: disable=redefined-builtin - if file_filter and not file_filter(file): - continue - if not include_deletes and file.Action() == 'D': - continue - yield file - - def AffectedSourceFiles(self, file_filter=None): - return self.AffectedFiles(file_filter=file_filter) - - def FilterSourceFile(self, file, # pylint: disable=redefined-builtin - files_to_check=(), files_to_skip=()): - local_path = file.LocalPath() - found_in_files_to_check = not files_to_check - if files_to_check: - if isinstance(files_to_check, str): - raise TypeError('files_to_check should be an iterable of strings') - for pattern in files_to_check: - compiled_pattern = re.compile(pattern) - if compiled_pattern.search(local_path): - found_in_files_to_check = True - break - if files_to_skip: - if isinstance(files_to_skip, str): - raise TypeError('files_to_skip should be an iterable of strings') - for pattern in files_to_skip: - compiled_pattern = re.compile(pattern) - if compiled_pattern.search(local_path): - return False - return found_in_files_to_check - - def LocalPaths(self): - return [file.LocalPath() for file in self.files] # pylint: disable=redefined-builtin - - def PresubmitLocalPath(self): - return self.presubmit_local_path - - def ReadFile(self, filename, mode='rU'): - if hasattr(filename, 'AbsoluteLocalPath'): - filename = filename.AbsoluteLocalPath() - for file_ in self.files: - if file_.LocalPath() == filename: - return '\n'.join(file_.NewContents()) - # Otherwise, file is not in our mock API. 
-    raise IOError("No such file or directory: '%s'" % filename)
+    DEFAULT_FILES_TO_SKIP = ()
+
+    def __init__(self):
+        self.canned_checks = MockCannedChecks()
+        self.fnmatch = fnmatch
+        self.json = json
+        self.re = re
+        self.os_path = os.path
+        self.platform = sys.platform
+        self.python_executable = sys.executable
+        self.subprocess = subprocess
+        self.sys = sys
+        self.files = []
+        self.is_committing = False
+        self.no_diffs = False
+        self.change = MockChange([])
+        self.presubmit_local_path = os.path.dirname(__file__)
+        self.logging = logging.getLogger('PRESUBMIT')
+
+    def CreateMockFileInPath(self, f_list):
+        self.os_path.exists = lambda x: x in f_list
+
+    def AffectedFiles(self, file_filter=None, include_deletes=True):
+        for file in self.files:  # pylint: disable=redefined-builtin
+            if file_filter and not file_filter(file):
+                continue
+            if not include_deletes and file.Action() == 'D':
+                continue
+            yield file
+
+    def AffectedSourceFiles(self, file_filter=None):
+        return self.AffectedFiles(file_filter=file_filter)
+
+    def FilterSourceFile(
+            self,
+            file,  # pylint: disable=redefined-builtin
+            files_to_check=(),
+            files_to_skip=()):
+        local_path = file.LocalPath()
+        found_in_files_to_check = not files_to_check
+        if files_to_check:
+            if isinstance(files_to_check, str):
+                raise TypeError(
+                    'files_to_check should be an iterable of strings')
+            for pattern in files_to_check:
+                compiled_pattern = re.compile(pattern)
+                if compiled_pattern.search(local_path):
+                    found_in_files_to_check = True
+                    break
+        if files_to_skip:
+            if isinstance(files_to_skip, str):
+                raise TypeError(
+                    'files_to_skip should be an iterable of strings')
+            for pattern in files_to_skip:
+                compiled_pattern = re.compile(pattern)
+                if compiled_pattern.search(local_path):
+                    return False
+        return found_in_files_to_check
+
+    def LocalPaths(self):
+        return [file.LocalPath() for file in self.files]  # pylint: disable=redefined-builtin
+
+    def PresubmitLocalPath(self):
+        return self.presubmit_local_path
+
+    def ReadFile(self, filename, mode='rU'):
+        if hasattr(filename, 'AbsoluteLocalPath'):
+            filename = filename.AbsoluteLocalPath()
+        for file_ in self.files:
+            if file_.LocalPath() == filename:
+                return '\n'.join(file_.NewContents())
+        # Otherwise, file is not in our mock API.
+        raise IOError("No such file or directory: '%s'" % filename)
 
 
 class MockOutputApi(object):
-  """Mock class for the OutputApi class.
+    """Mock class for the OutputApi class.
 
   An instance of this class can be passed to presubmit unittests for outputting
   various types of results.
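 
   For instance, a check under test might be handed this mock and return
   results like the following (the message text is illustrative):
 
     output_api = MockOutputApi()
     result = output_api.PresubmitError('Found tab characters.')
     # result.type == 'error'; result.message == 'Found tab characters.'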
""" - - class PresubmitResult(object): - def __init__(self, message, items=None, long_text=''): - self.message = message - self.items = items - self.long_text = long_text - - def __repr__(self): - return self.message - - class PresubmitError(PresubmitResult): - def __init__(self, message, items=None, long_text=''): - MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) - self.type = 'error' - - class PresubmitPromptWarning(PresubmitResult): - def __init__(self, message, items=None, long_text=''): - MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) - self.type = 'warning' - - class PresubmitNotifyResult(PresubmitResult): - def __init__(self, message, items=None, long_text=''): - MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) - self.type = 'notify' - - class PresubmitPromptOrNotify(PresubmitResult): - def __init__(self, message, items=None, long_text=''): - MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) - self.type = 'promptOrNotify' - - def __init__(self): - self.more_cc = [] - - def AppendCC(self, more_cc): - self.more_cc.extend(more_cc) + class PresubmitResult(object): + def __init__(self, message, items=None, long_text=''): + self.message = message + self.items = items + self.long_text = long_text + + def __repr__(self): + return self.message + + class PresubmitError(PresubmitResult): + def __init__(self, message, items=None, long_text=''): + MockOutputApi.PresubmitResult.__init__(self, message, items, + long_text) + self.type = 'error' + + class PresubmitPromptWarning(PresubmitResult): + def __init__(self, message, items=None, long_text=''): + MockOutputApi.PresubmitResult.__init__(self, message, items, + long_text) + self.type = 'warning' + + class PresubmitNotifyResult(PresubmitResult): + def __init__(self, message, items=None, long_text=''): + MockOutputApi.PresubmitResult.__init__(self, message, items, + long_text) + self.type = 'notify' + + class PresubmitPromptOrNotify(PresubmitResult): + def __init__(self, message, items=None, long_text=''): + MockOutputApi.PresubmitResult.__init__(self, message, items, + long_text) + self.type = 'promptOrNotify' + + def __init__(self): + self.more_cc = [] + + def AppendCC(self, more_cc): + self.more_cc.extend(more_cc) class MockFile(object): - """Mock class for the File class. + """Mock class for the File class. This class can be used to form the mock list of changed files in MockInputApi for presubmit unittests. 
""" - - def __init__(self, local_path, new_contents, old_contents=None, action='A', - scm_diff=None): - self._local_path = local_path - self._new_contents = new_contents - self._changed_contents = [(i + 1, l) for i, l in enumerate(new_contents)] - self._action = action - if scm_diff: - self._scm_diff = scm_diff - else: - self._scm_diff = ( - "--- /dev/null\n+++ %s\n@@ -0,0 +1,%d @@\n" % - (local_path, len(new_contents))) - for l in new_contents: - self._scm_diff += "+%s\n" % l - self._old_contents = old_contents - - def Action(self): - return self._action - - def ChangedContents(self): - return self._changed_contents - - def NewContents(self): - return self._new_contents - - def LocalPath(self): - return self._local_path - - def AbsoluteLocalPath(self): - return self._local_path - - def GenerateScmDiff(self): - return self._scm_diff - - def OldContents(self): - return self._old_contents - - def rfind(self, p): - """os.path.basename is called on MockFile so we need an rfind method.""" - return self._local_path.rfind(p) - - def __getitem__(self, i): - """os.path.basename is called on MockFile so we need a get method.""" - return self._local_path[i] - - def __len__(self): - """os.path.basename is called on MockFile so we need a len method.""" - return len(self._local_path) - - def replace(self, altsep, sep): - """os.path.basename is called on MockFile so we need a replace method.""" - return self._local_path.replace(altsep, sep) + def __init__(self, + local_path, + new_contents, + old_contents=None, + action='A', + scm_diff=None): + self._local_path = local_path + self._new_contents = new_contents + self._changed_contents = [(i + 1, l) + for i, l in enumerate(new_contents)] + self._action = action + if scm_diff: + self._scm_diff = scm_diff + else: + self._scm_diff = ("--- /dev/null\n+++ %s\n@@ -0,0 +1,%d @@\n" % + (local_path, len(new_contents))) + for l in new_contents: + self._scm_diff += "+%s\n" % l + self._old_contents = old_contents + + def Action(self): + return self._action + + def ChangedContents(self): + return self._changed_contents + + def NewContents(self): + return self._new_contents + + def LocalPath(self): + return self._local_path + + def AbsoluteLocalPath(self): + return self._local_path + + def GenerateScmDiff(self): + return self._scm_diff + + def OldContents(self): + return self._old_contents + + def rfind(self, p): + """os.path.basename is used on MockFile so we need an rfind method.""" + return self._local_path.rfind(p) + + def __getitem__(self, i): + """os.path.basename is used on MockFile so we need a get method.""" + return self._local_path[i] + + def __len__(self): + """os.path.basename is used on MockFile so we need a len method.""" + return len(self._local_path) + + def replace(self, altsep, sep): + """os.path.basename is used on MockFile so we need a replace method.""" + return self._local_path.replace(altsep, sep) class MockAffectedFile(MockFile): - def AbsoluteLocalPath(self): - return self._local_path + def AbsoluteLocalPath(self): + return self._local_path class MockChange(object): - """Mock class for Change class. + """Mock class for Change class. This class can be used in presubmit unittests to mock the query of the current change. 
""" + def __init__(self, changed_files, description=''): + self._changed_files = changed_files + self.footers = defaultdict(list) + self._description = description - def __init__(self, changed_files, description=''): - self._changed_files = changed_files - self.footers = defaultdict(list) - self._description = description - - def LocalPaths(self): - return self._changed_files + def LocalPaths(self): + return self._changed_files - def AffectedFiles(self, include_dirs=False, include_deletes=True, - file_filter=None): - return self._changed_files + def AffectedFiles(self, + include_dirs=False, + include_deletes=True, + file_filter=None): + return self._changed_files - def GitFootersFromDescription(self): - return self.footers + def GitFootersFromDescription(self): + return self.footers - def DescriptionText(self): - return self._description + def DescriptionText(self): + return self._description diff --git a/testing_support/subprocess2_test_script.py b/testing_support/subprocess2_test_script.py index 13bf1286b..dc44e712d 100644 --- a/testing_support/subprocess2_test_script.py +++ b/testing_support/subprocess2_test_script.py @@ -2,7 +2,6 @@ # Copyright (c) 2019 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. - """Script used to test subprocess2.""" import optparse @@ -10,51 +9,52 @@ import os import sys import time - if sys.platform == 'win32': - # Annoying, make sure the output is not translated on Windows. - # pylint: disable=no-member,import-error - import msvcrt - msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) - msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY) + # Annoying, make sure the output is not translated on Windows. + # pylint: disable=no-member,import-error + import msvcrt + msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) + msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY) parser = optparse.OptionParser() -parser.add_option( - '--fail', - dest='return_value', - action='store_const', - default=0, - const=64) -parser.add_option( - '--crlf', action='store_const', const='\r\n', dest='eol', default='\n') -parser.add_option( - '--cr', action='store_const', const='\r', dest='eol') +parser.add_option('--fail', + dest='return_value', + action='store_const', + default=0, + const=64) +parser.add_option('--crlf', + action='store_const', + const='\r\n', + dest='eol', + default='\n') +parser.add_option('--cr', action='store_const', const='\r', dest='eol') parser.add_option('--stdout', action='store_true') parser.add_option('--stderr', action='store_true') parser.add_option('--read', action='store_true') options, args = parser.parse_args() if args: - parser.error('Internal error') + parser.error('Internal error') + def do(string): - if options.stdout: - sys.stdout.buffer.write(string.upper().encode('utf-8')) - sys.stdout.buffer.write(options.eol.encode('utf-8')) - if options.stderr: - sys.stderr.buffer.write(string.lower().encode('utf-8')) - sys.stderr.buffer.write(options.eol.encode('utf-8')) - sys.stderr.flush() + if options.stdout: + sys.stdout.buffer.write(string.upper().encode('utf-8')) + sys.stdout.buffer.write(options.eol.encode('utf-8')) + if options.stderr: + sys.stderr.buffer.write(string.lower().encode('utf-8')) + sys.stderr.buffer.write(options.eol.encode('utf-8')) + sys.stderr.flush() do('A') do('BB') do('CCC') if options.read: - assert options.return_value == 0 - try: - while sys.stdin.read(1): - options.return_value += 1 - except OSError: - pass + assert options.return_value == 0 + 
try: + while sys.stdin.read(1): + options.return_value += 1 + except OSError: + pass sys.exit(options.return_value) diff --git a/testing_support/test_case_utils.py b/testing_support/test_case_utils.py index 0cac13317..3747041c6 100644 --- a/testing_support/test_case_utils.py +++ b/testing_support/test_case_utils.py @@ -1,7 +1,6 @@ # Copyright (c) 2019 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. - """Simplify unit tests based on pymox.""" from __future__ import print_function @@ -12,51 +11,52 @@ import string class TestCaseUtils(object): - """Base class with some additional functionalities. People will usually want + """Base class with some additional functionalities. People will usually want to use SuperMoxTestBase instead.""" - # Backup the separator in case it gets mocked - _OS_SEP = os.sep - _RANDOM_CHOICE = random.choice - _RANDOM_RANDINT = random.randint - _STRING_LETTERS = string.ascii_letters - - ## Some utilities for generating arbitrary arguments. - def String(self, max_length): - return ''.join([self._RANDOM_CHOICE(self._STRING_LETTERS) - for _ in range(self._RANDOM_RANDINT(1, max_length))]) - - def Strings(self, max_arg_count, max_arg_length): - return [self.String(max_arg_length) for _ in range(max_arg_count)] - - def Args(self, max_arg_count=8, max_arg_length=16): - return self.Strings(max_arg_count, - self._RANDOM_RANDINT(1, max_arg_length)) - - def _DirElts(self, max_elt_count=4, max_elt_length=8): - return self._OS_SEP.join(self.Strings(max_elt_count, max_elt_length)) - - def Dir(self, max_elt_count=4, max_elt_length=8): - return (self._RANDOM_CHOICE((self._OS_SEP, '')) + - self._DirElts(max_elt_count, max_elt_length)) - - def RootDir(self, max_elt_count=4, max_elt_length=8): - return self._OS_SEP + self._DirElts(max_elt_count, max_elt_length) - - def compareMembers(self, obj, members): - """If you add a member, be sure to add the relevant test!""" - # Skip over members starting with '_' since they are usually not meant to - # be for public use. - actual_members = [x for x in sorted(dir(obj)) - if not x.startswith('_')] - expected_members = sorted(members) - if actual_members != expected_members: - diff = ([i for i in actual_members if i not in expected_members] + - [i for i in expected_members if i not in actual_members]) - print(diff, file=sys.stderr) - # pylint: disable=no-member - self.assertEqual(actual_members, expected_members) - - def setUp(self): - self.root_dir = self.Dir() - self.args = self.Args() - self.relpath = self.String(200) + # Backup the separator in case it gets mocked + _OS_SEP = os.sep + _RANDOM_CHOICE = random.choice + _RANDOM_RANDINT = random.randint + _STRING_LETTERS = string.ascii_letters + + ## Some utilities for generating arbitrary arguments. 
+ def String(self, max_length): + return ''.join([ + self._RANDOM_CHOICE(self._STRING_LETTERS) + for _ in range(self._RANDOM_RANDINT(1, max_length)) + ]) + + def Strings(self, max_arg_count, max_arg_length): + return [self.String(max_arg_length) for _ in range(max_arg_count)] + + def Args(self, max_arg_count=8, max_arg_length=16): + return self.Strings(max_arg_count, + self._RANDOM_RANDINT(1, max_arg_length)) + + def _DirElts(self, max_elt_count=4, max_elt_length=8): + return self._OS_SEP.join(self.Strings(max_elt_count, max_elt_length)) + + def Dir(self, max_elt_count=4, max_elt_length=8): + return (self._RANDOM_CHOICE( + (self._OS_SEP, '')) + self._DirElts(max_elt_count, max_elt_length)) + + def RootDir(self, max_elt_count=4, max_elt_length=8): + return self._OS_SEP + self._DirElts(max_elt_count, max_elt_length) + + def compareMembers(self, obj, members): + """If you add a member, be sure to add the relevant test!""" + # Skip over members starting with '_' since they are usually not meant + # to be for public use. + actual_members = [x for x in sorted(dir(obj)) if not x.startswith('_')] + expected_members = sorted(members) + if actual_members != expected_members: + diff = ([i for i in actual_members if i not in expected_members] + + [i for i in expected_members if i not in actual_members]) + print(diff, file=sys.stderr) + # pylint: disable=no-member + self.assertEqual(actual_members, expected_members) + + def setUp(self): + self.root_dir = self.Dir() + self.args = self.Args() + self.relpath = self.String(200) diff --git a/testing_support/trial_dir.py b/testing_support/trial_dir.py index 2c0a91892..72e1d131b 100644 --- a/testing_support/trial_dir.py +++ b/testing_support/trial_dir.py @@ -15,83 +15,84 @@ import gclient_utils class TrialDir(object): - """Manages a temporary directory. + """Manages a temporary directory. On first object creation, TrialDir.TRIAL_ROOT will be set to a new temporary directory created in /tmp or the equivalent. It will be deleted on process exit unless TrialDir.SHOULD_LEAK is set to True. """ - # When SHOULD_LEAK is set to True, temporary directories created while the - # tests are running aren't deleted at the end of the tests. Expect failures - # when running more than one test due to inter-test side-effects. Helps with - # debugging. - SHOULD_LEAK = False - - # Main root directory. - TRIAL_ROOT = None - - def __init__(self, subdir, leak=False): - self.leak = self.SHOULD_LEAK or leak - self.subdir = subdir - self.root_dir = None - - def set_up(self): - """All late initialization comes here.""" - # You can override self.TRIAL_ROOT. - if not self.TRIAL_ROOT: - # Was not yet initialized. 
- TrialDir.TRIAL_ROOT = os.path.realpath(tempfile.mkdtemp(prefix='trial')) - atexit.register(self._clean) - self.root_dir = os.path.join(TrialDir.TRIAL_ROOT, self.subdir) - gclient_utils.rmtree(self.root_dir) - os.makedirs(self.root_dir) - - def tear_down(self): - """Cleans the trial subdirectory for this instance.""" - if not self.leak: - logging.debug('Removing %s' % self.root_dir) - gclient_utils.rmtree(self.root_dir) - else: - logging.error('Leaking %s' % self.root_dir) - self.root_dir = None - - @staticmethod - def _clean(): - """Cleans the root trial directory.""" - if not TrialDir.SHOULD_LEAK: - logging.debug('Removing %s' % TrialDir.TRIAL_ROOT) - gclient_utils.rmtree(TrialDir.TRIAL_ROOT) - else: - logging.error('Leaking %s' % TrialDir.TRIAL_ROOT) + # When SHOULD_LEAK is set to True, temporary directories created while the + # tests are running aren't deleted at the end of the tests. Expect failures + # when running more than one test due to inter-test side-effects. Helps with + # debugging. + SHOULD_LEAK = False + + # Main root directory. + TRIAL_ROOT = None + + def __init__(self, subdir, leak=False): + self.leak = self.SHOULD_LEAK or leak + self.subdir = subdir + self.root_dir = None + + def set_up(self): + """All late initialization comes here.""" + # You can override self.TRIAL_ROOT. + if not self.TRIAL_ROOT: + # Was not yet initialized. + TrialDir.TRIAL_ROOT = os.path.realpath( + tempfile.mkdtemp(prefix='trial')) + atexit.register(self._clean) + self.root_dir = os.path.join(TrialDir.TRIAL_ROOT, self.subdir) + gclient_utils.rmtree(self.root_dir) + os.makedirs(self.root_dir) + + def tear_down(self): + """Cleans the trial subdirectory for this instance.""" + if not self.leak: + logging.debug('Removing %s' % self.root_dir) + gclient_utils.rmtree(self.root_dir) + else: + logging.error('Leaking %s' % self.root_dir) + self.root_dir = None + + @staticmethod + def _clean(): + """Cleans the root trial directory.""" + if not TrialDir.SHOULD_LEAK: + logging.debug('Removing %s' % TrialDir.TRIAL_ROOT) + gclient_utils.rmtree(TrialDir.TRIAL_ROOT) + else: + logging.error('Leaking %s' % TrialDir.TRIAL_ROOT) class TrialDirMixIn(object): - def setUp(self): - # Create a specific directory just for the test. - self.trial = TrialDir(self.id()) - self.trial.set_up() + def setUp(self): + # Create a specific directory just for the test. + self.trial = TrialDir(self.id()) + self.trial.set_up() - def tearDown(self): - self.trial.tear_down() + def tearDown(self): + self.trial.tear_down() - @property - def root_dir(self): - return self.trial.root_dir + @property + def root_dir(self): + return self.trial.root_dir class TestCase(unittest.TestCase, TrialDirMixIn): - """Base unittest class that cleans off a trial directory in tearDown().""" - def setUp(self): - unittest.TestCase.setUp(self) - TrialDirMixIn.setUp(self) + """Base unittest class that cleans off a trial directory in tearDown().""" + def setUp(self): + unittest.TestCase.setUp(self) + TrialDirMixIn.setUp(self) - def tearDown(self): - TrialDirMixIn.tearDown(self) - unittest.TestCase.tearDown(self) + def tearDown(self): + TrialDirMixIn.tearDown(self) + unittest.TestCase.tearDown(self) if '-l' in sys.argv: - # See SHOULD_LEAK definition in TrialDir for its purpose. - TrialDir.SHOULD_LEAK = True - print('Leaking!') - sys.argv.remove('-l') + # See SHOULD_LEAK definition in TrialDir for its purpose. + TrialDir.SHOULD_LEAK = True + print('Leaking!') + sys.argv.remove('-l')
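
A minimal sketch of how a test might build on trial_dir (the test body is
illustrative, and the import path assumes this file lives at
testing_support/trial_dir.py inside a depot_tools checkout):

    import os
    import unittest

    from testing_support import trial_dir


    class ScratchSpaceTest(trial_dir.TestCase):
        def test_writes_under_root(self):
            # self.root_dir is created by TrialDirMixIn.setUp() and removed
            # in tearDown() (unless -l / SHOULD_LEAK is in effect).
            path = os.path.join(self.root_dir, 'scratch.txt')
            with open(path, 'w') as f:
                f.write('hello')
            self.assertTrue(os.path.exists(path))


    if __name__ == '__main__':
        unittest.main()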