testing_support: switch to 4 space indent

Reformat this dir by itself to help merging with conflicts with other CLs.

Reformatted using:
parallel ./yapf -i -- testing_support/*.py
~/chromiumos/chromite/contrib/reflow_overlong_comments testing_support/*.py

The files that still had strings that were too long were manually
reformatted.
testing_support/coverage_utils.py
testing_support/fake_repos.py
testing_support/git_test_utils.py
testing_support/presubmit_canned_checks_test_mocks.py

Change-Id: I4726a4bbd279a70bcf65d0987fcff0ff9a231386
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/4842593
Reviewed-by: Josip Sokcevic <sokcevic@chromium.org>
Commit-Queue: Josip Sokcevic <sokcevic@chromium.org>
Auto-Submit: Mike Frysinger <vapier@chromium.org>
changes/93/4842593/3
Mike Frysinger 2 years ago committed by LUCI CQ
parent 691128f836
commit f38dc929a8

@ -0,0 +1,3 @@
[style]
based_on_style = pep8
column_limit = 80

@ -10,22 +10,26 @@ import sys
import textwrap
import unittest
ROOT_PATH = os.path.abspath(os.path.join(
os.path.dirname(os.path.dirname(__file__))))
ROOT_PATH = os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(__file__))))
def native_error(msg, version):
print(textwrap.dedent("""\
print(
textwrap.dedent("""\
ERROR: Native python-coverage (version: %s) is required to be
installed on your PYTHONPATH to run this test. Recommendation:
sudo apt-get install pip
sudo pip install --upgrade coverage
%s""") % (version, msg))
sys.exit(1)
sys.exit(1)
def covered_main(includes, require_native=None, required_percentage=100.0,
def covered_main(includes,
require_native=None,
required_percentage=100.0,
disable_coverage=True):
"""Equivalent of unittest.main(), except that it gathers coverage data, and
"""Equivalent of unittest.main(), except that it gathers coverage data, and
asserts if the test is not at 100% coverage.
Args:
@ -37,43 +41,45 @@ def covered_main(includes, require_native=None, required_percentage=100.0,
disable_coverage (bool) - If True, just run unittest.main() without any
coverage tracking. Bug: crbug.com/662277
"""
if disable_coverage:
unittest.main()
return
if disable_coverage:
unittest.main()
return
try:
import coverage
if require_native is not None:
got_ver = coverage.__version__
if not getattr(coverage.collector, 'CTracer', None):
native_error((
"Native python-coverage module required.\n"
"Pure-python implementation (version: %s) found: %s"
) % (got_ver, coverage), require_native)
if got_ver < distutils.version.LooseVersion(require_native):
native_error("Wrong version (%s) found: %s" % (got_ver, coverage),
require_native)
except ImportError:
if require_native is None:
sys.path.insert(0, os.path.join(ROOT_PATH, 'third_party'))
import coverage
else:
print("ERROR: python-coverage (%s) is required to be installed on your "
"PYTHONPATH to run this test." % require_native)
sys.exit(1)
try:
import coverage
if require_native is not None:
got_ver = coverage.__version__
if not getattr(coverage.collector, 'CTracer', None):
native_error(
("Native python-coverage module required.\n"
"Pure-python implementation (version: %s) found: %s") %
(got_ver, coverage), require_native)
if got_ver < distutils.version.LooseVersion(require_native):
native_error(
"Wrong version (%s) found: %s" % (got_ver, coverage),
require_native)
except ImportError:
if require_native is None:
sys.path.insert(0, os.path.join(ROOT_PATH, 'third_party'))
import coverage
else:
print(
"ERROR: python-coverage (%s) is required to be installed on "
"your PYTHONPATH to run this test." % require_native)
sys.exit(1)
COVERAGE = coverage.coverage(include=includes)
COVERAGE.start()
COVERAGE = coverage.coverage(include=includes)
COVERAGE.start()
retcode = 0
try:
unittest.main()
except SystemExit as e:
retcode = e.code or retcode
retcode = 0
try:
unittest.main()
except SystemExit as e:
retcode = e.code or retcode
COVERAGE.stop()
if COVERAGE.report() < required_percentage:
print('FATAL: not at required %f%% coverage.' % required_percentage)
retcode = 2
COVERAGE.stop()
if COVERAGE.report() < required_percentage:
print('FATAL: not at required %f%% coverage.' % required_percentage)
retcode = 2
return retcode
return retcode

@ -74,129 +74,131 @@ DESCRIBE_JSON_TEMPLATE = """{
def parse_cipd(root, contents):
tree = {}
current_subdir = None
for line in contents:
line = line.strip()
match = re.match(CIPD_SUBDIR_RE, line)
if match:
print('match')
current_subdir = os.path.join(root, *match.group(1).split('/'))
if not root:
current_subdir = match.group(1)
elif line and current_subdir:
print('no match')
tree.setdefault(current_subdir, []).append(line)
return tree
tree = {}
current_subdir = None
for line in contents:
line = line.strip()
match = re.match(CIPD_SUBDIR_RE, line)
if match:
print('match')
current_subdir = os.path.join(root, *match.group(1).split('/'))
if not root:
current_subdir = match.group(1)
elif line and current_subdir:
print('no match')
tree.setdefault(current_subdir, []).append(line)
return tree
def expand_package_name_cmd(package_name):
package_split = package_name.split("/")
suffix = package_split[-1]
# Any use of var equality should return empty for testing.
if "=" in suffix:
if suffix != "${platform=fake-platform-ok}":
return ""
package_name = "/".join(package_split[:-1] + ["${platform}"])
for v in [ARCH_VAR, OS_VAR, PLATFORM_VAR]:
var = "${%s}" % v
if package_name.endswith(var):
package_name = package_name.replace(var, "%s-expanded-test-only" % v)
return package_name
package_split = package_name.split("/")
suffix = package_split[-1]
# Any use of var equality should return empty for testing.
if "=" in suffix:
if suffix != "${platform=fake-platform-ok}":
return ""
package_name = "/".join(package_split[:-1] + ["${platform}"])
for v in [ARCH_VAR, OS_VAR, PLATFORM_VAR]:
var = "${%s}" % v
if package_name.endswith(var):
package_name = package_name.replace(var,
"%s-expanded-test-only" % v)
return package_name
def ensure_file_resolve():
resolved = {"result": {}}
parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file', required=True)
parser.add_argument('-json-output')
args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd("", f.readlines())
for path, packages in new_content.items():
resolved_packages = []
for package in packages:
package_name = expand_package_name_cmd(package.split(" ")[0])
resolved_packages.append({
"package": package_name,
"pin": {
"package": package_name,
"instance_id": package_name + "-fake-resolved-id",
}
})
resolved["result"][path] = resolved_packages
with io.open(args.json_output, 'w', encoding='utf-8') as f:
f.write(json.dumps(resolved, indent=4))
resolved = {"result": {}}
parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file', required=True)
parser.add_argument('-json-output')
args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd("", f.readlines())
for path, packages in new_content.items():
resolved_packages = []
for package in packages:
package_name = expand_package_name_cmd(package.split(" ")[0])
resolved_packages.append({
"package": package_name,
"pin": {
"package": package_name,
"instance_id": package_name + "-fake-resolved-id",
}
})
resolved["result"][path] = resolved_packages
with io.open(args.json_output, 'w', encoding='utf-8') as f:
f.write(json.dumps(resolved, indent=4))
def describe_cmd(package_name):
parser = argparse.ArgumentParser()
parser.add_argument('-json-output')
parser.add_argument('-version', required=True)
args, _ = parser.parse_known_args()
json_template = Template(DESCRIBE_JSON_TEMPLATE).substitute(
package=package_name)
cli_out = Template(DESCRIBE_STDOUT_TEMPLATE).substitute(package=package_name)
json_out = json.loads(json_template)
found = False
for tag in json_out['result']['tags']:
if tag['tag'] == args.version:
found = True
break
for tag in json_out['result']['refs']:
if tag['ref'] == args.version:
found = True
break
if found:
if args.json_output:
with io.open(args.json_output, 'w', encoding='utf-8') as f:
f.write(json.dumps(json_out, indent=4))
sys.stdout.write(cli_out)
return 0
sys.stdout.write('Error: no such ref.\n')
return 1
parser = argparse.ArgumentParser()
parser.add_argument('-json-output')
parser.add_argument('-version', required=True)
args, _ = parser.parse_known_args()
json_template = Template(DESCRIBE_JSON_TEMPLATE).substitute(
package=package_name)
cli_out = Template(DESCRIBE_STDOUT_TEMPLATE).substitute(
package=package_name)
json_out = json.loads(json_template)
found = False
for tag in json_out['result']['tags']:
if tag['tag'] == args.version:
found = True
break
for tag in json_out['result']['refs']:
if tag['ref'] == args.version:
found = True
break
if found:
if args.json_output:
with io.open(args.json_output, 'w', encoding='utf-8') as f:
f.write(json.dumps(json_out, indent=4))
sys.stdout.write(cli_out)
return 0
sys.stdout.write('Error: no such ref.\n')
return 1
def main():
cmd = sys.argv[1]
assert cmd in [
CIPD_DESCRIBE, CIPD_ENSURE, CIPD_ENSURE_FILE_RESOLVE, CIPD_EXPAND_PKG,
CIPD_EXPORT
]
# Handle cipd expand-package-name
if cmd == CIPD_EXPAND_PKG:
# Expecting argument after cmd
assert len(sys.argv) == 3
# Write result to stdout
sys.stdout.write(expand_package_name_cmd(sys.argv[2]))
return 0
if cmd == CIPD_DESCRIBE:
# Expecting argument after cmd
assert len(sys.argv) >= 3
return describe_cmd(sys.argv[2])
if cmd == CIPD_ENSURE_FILE_RESOLVE:
return ensure_file_resolve()
parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file')
parser.add_argument('-root')
args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd(args.root, f.readlines())
# Install new packages
for path, packages in new_content.items():
if not os.path.exists(path):
os.makedirs(path)
with io.open(os.path.join(path, '_cipd'), 'w', encoding='utf-8') as f:
f.write('\n'.join(packages))
cmd = sys.argv[1]
assert cmd in [
CIPD_DESCRIBE, CIPD_ENSURE, CIPD_ENSURE_FILE_RESOLVE, CIPD_EXPAND_PKG,
CIPD_EXPORT
]
# Handle cipd expand-package-name
if cmd == CIPD_EXPAND_PKG:
# Expecting argument after cmd
assert len(sys.argv) == 3
# Write result to stdout
sys.stdout.write(expand_package_name_cmd(sys.argv[2]))
return 0
if cmd == CIPD_DESCRIBE:
# Expecting argument after cmd
assert len(sys.argv) >= 3
return describe_cmd(sys.argv[2])
if cmd == CIPD_ENSURE_FILE_RESOLVE:
return ensure_file_resolve()
parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file')
parser.add_argument('-root')
args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd(args.root, f.readlines())
# Install new packages
for path, packages in new_content.items():
if not os.path.exists(path):
os.makedirs(path)
with io.open(os.path.join(path, '_cipd'), 'w', encoding='utf-8') as f:
f.write('\n'.join(packages))
# Save the ensure file that we got
shutil.copy(args.ensure_file, os.path.join(args.root, '_cipd'))
# Save the ensure file that we got
shutil.copy(args.ensure_file, os.path.join(args.root, '_cipd'))
return 0
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

File diff suppressed because it is too large Load Diff

@ -10,98 +10,98 @@ from io import StringIO
def _RaiseNotFound(path):
raise IOError(errno.ENOENT, path, os.strerror(errno.ENOENT))
raise IOError(errno.ENOENT, path, os.strerror(errno.ENOENT))
class MockFileSystem(object):
"""Stripped-down version of WebKit's webkitpy.common.system.filesystem_mock
"""Stripped-down version of WebKit's webkitpy.common.system.filesystem_mock
Implements a filesystem-like interface on top of a dict of filenames ->
file contents. A file content value of None indicates that the file should
not exist (IOError will be raised if it is opened;
reading from a missing key raises a KeyError, not an IOError."""
def __init__(self, files=None):
self.files = files or {}
self.written_files = {}
self._sep = '/'
@property
def sep(self):
return self._sep
def abspath(self, path):
if path.endswith(self.sep):
return path[:-1]
return path
def basename(self, path):
if self.sep not in path:
return ''
return self.split(path)[-1] or self.sep
def dirname(self, path):
if self.sep not in path:
return ''
return self.split(path)[0] or self.sep
def exists(self, path):
return self.isfile(path) or self.isdir(path)
def isabs(self, path):
return path.startswith(self.sep)
def isfile(self, path):
return path in self.files and self.files[path] is not None
def isdir(self, path):
if path in self.files:
return False
if not path.endswith(self.sep):
path += self.sep
# We need to use a copy of the keys here in order to avoid switching
# to a different thread and potentially modifying the dict in
# mid-iteration.
files = list(self.files.keys())[:]
return any(f.startswith(path) for f in files)
def join(self, *comps):
# TODO: Might want tests for this and/or a better comment about how
# it works.
return re.sub(re.escape(os.path.sep), self.sep, os.path.join(*comps))
def glob(self, path):
return fnmatch.filter(self.files.keys(), path)
def open_for_reading(self, path):
return StringIO(self.read_binary_file(path))
def normpath(self, path):
# This is not a complete implementation of normpath. Only covers what we
# use in tests.
result = []
for part in path.split(self.sep):
if part == '..':
result.pop()
elif part == '.':
continue
else:
result.append(part)
return self.sep.join(result)
def read_binary_file(self, path):
# Intentionally raises KeyError if we don't recognize the path.
if self.files[path] is None:
_RaiseNotFound(path)
return self.files[path]
def relpath(self, path, base):
# This implementation is wrong in many ways; assert to check them for now.
if not base.endswith(self.sep):
base += self.sep
assert path.startswith(base)
return path[len(base):]
def split(self, path):
return path.rsplit(self.sep, 1)
def __init__(self, files=None):
self.files = files or {}
self.written_files = {}
self._sep = '/'
@property
def sep(self):
return self._sep
def abspath(self, path):
if path.endswith(self.sep):
return path[:-1]
return path
def basename(self, path):
if self.sep not in path:
return ''
return self.split(path)[-1] or self.sep
def dirname(self, path):
if self.sep not in path:
return ''
return self.split(path)[0] or self.sep
def exists(self, path):
return self.isfile(path) or self.isdir(path)
def isabs(self, path):
return path.startswith(self.sep)
def isfile(self, path):
return path in self.files and self.files[path] is not None
def isdir(self, path):
if path in self.files:
return False
if not path.endswith(self.sep):
path += self.sep
# We need to use a copy of the keys here in order to avoid switching
# to a different thread and potentially modifying the dict in
# mid-iteration.
files = list(self.files.keys())[:]
return any(f.startswith(path) for f in files)
def join(self, *comps):
# TODO: Might want tests for this and/or a better comment about how
# it works.
return re.sub(re.escape(os.path.sep), self.sep, os.path.join(*comps))
def glob(self, path):
return fnmatch.filter(self.files.keys(), path)
def open_for_reading(self, path):
return StringIO(self.read_binary_file(path))
def normpath(self, path):
# This is not a complete implementation of normpath. Only covers what we
# use in tests.
result = []
for part in path.split(self.sep):
if part == '..':
result.pop()
elif part == '.':
continue
else:
result.append(part)
return self.sep.join(result)
def read_binary_file(self, path):
# Intentionally raises KeyError if we don't recognize the path.
if self.files[path] is None:
_RaiseNotFound(path)
return self.files[path]
def relpath(self, path, base):
# This implementation is wrong in many ways; assert to check them for
# now.
if not base.endswith(self.sep):
base += self.sep
assert path.startswith(base)
return path[len(base):]
def split(self, path):
return path.rsplit(self.sep, 1)

@ -17,108 +17,107 @@ import unittest
import gclient_utils
DEFAULT_BRANCH = 'main'
def git_hash_data(data, typ='blob'):
"""Calculate the git-style SHA1 for some data.
"""Calculate the git-style SHA1 for some data.
Only supports 'blob' type data at the moment.
"""
assert typ == 'blob', 'Only support blobs for now'
return hashlib.sha1(b'blob %d\0%s' % (len(data), data)).hexdigest()
assert typ == 'blob', 'Only support blobs for now'
return hashlib.sha1(b'blob %d\0%s' % (len(data), data)).hexdigest()
class OrderedSet(collections.MutableSet):
# from http://code.activestate.com/recipes/576694/
def __init__(self, iterable=None):
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.data = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
def __contains__(self, key):
return key in self.data
def __eq__(self, other):
if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
def __ne__(self, other):
if isinstance(other, OrderedSet):
return len(self) != len(other) or list(self) != list(other)
return set(self) != set(other)
def __len__(self):
return len(self.data)
def __iter__(self):
end = self.end
curr = end[2]
while curr is not end:
yield curr[0]
curr = curr[2]
def __repr__(self):
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self))
def __reversed__(self):
end = self.end
curr = end[1]
while curr is not end:
yield curr[0]
curr = curr[1]
def add(self, key):
if key not in self.data:
end = self.end
curr = end[1]
curr[2] = end[1] = self.data[key] = [key, curr, end]
def difference_update(self, *others):
for other in others:
for i in other:
self.discard(i)
def discard(self, key):
if key in self.data:
key, prev, nxt = self.data.pop(key)
prev[2] = nxt
nxt[1] = prev
def pop(self, last=True): # pylint: disable=arguments-differ
if not self:
raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0]
self.discard(key)
return key
# from http://code.activestate.com/recipes/576694/
def __init__(self, iterable=None):
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.data = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
def __contains__(self, key):
return key in self.data
def __eq__(self, other):
if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
def __ne__(self, other):
if isinstance(other, OrderedSet):
return len(self) != len(other) or list(self) != list(other)
return set(self) != set(other)
def __len__(self):
return len(self.data)
def __iter__(self):
end = self.end
curr = end[2]
while curr is not end:
yield curr[0]
curr = curr[2]
def __repr__(self):
if not self:
return '%s()' % (self.__class__.__name__, )
return '%s(%r)' % (self.__class__.__name__, list(self))
def __reversed__(self):
end = self.end
curr = end[1]
while curr is not end:
yield curr[0]
curr = curr[1]
def add(self, key):
if key not in self.data:
end = self.end
curr = end[1]
curr[2] = end[1] = self.data[key] = [key, curr, end]
def difference_update(self, *others):
for other in others:
for i in other:
self.discard(i)
def discard(self, key):
if key in self.data:
key, prev, nxt = self.data.pop(key)
prev[2] = nxt
nxt[1] = prev
def pop(self, last=True): # pylint: disable=arguments-differ
if not self:
raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0]
self.discard(key)
return key
class UTC(datetime.tzinfo):
"""UTC time zone.
"""UTC time zone.
from https://docs.python.org/2/library/datetime.html#tzinfo-objects
"""
def utcoffset(self, dt):
return datetime.timedelta(0)
def utcoffset(self, dt):
return datetime.timedelta(0)
def tzname(self, dt):
return "UTC"
def tzname(self, dt):
return "UTC"
def dst(self, dt):
return datetime.timedelta(0)
def dst(self, dt):
return datetime.timedelta(0)
UTC = UTC()
class GitRepoSchema(object):
"""A declarative git testing repo.
"""A declarative git testing repo.
Pass a schema to __init__ in the form of:
A B C D
@ -141,11 +140,10 @@ class GitRepoSchema(object):
in the schema) get earlier timestamps. Stamps start at the Unix Epoch, and
increment by 1 day each.
"""
COMMIT = collections.namedtuple('COMMIT', 'name parents is_branch is_root')
COMMIT = collections.namedtuple('COMMIT', 'name parents is_branch is_root')
def __init__(self, repo_schema='',
content_fn=lambda v: {v: {'data': v}}):
"""Builds a new GitRepoSchema.
def __init__(self, repo_schema='', content_fn=lambda v: {v: {'data': v}}):
"""Builds a new GitRepoSchema.
Args:
repo_schema (str) - Initial schema for this repo. See class docstring for
@ -156,88 +154,88 @@ class GitRepoSchema(object):
commit_name). See the docstring on the GitRepo class for the format of
the data returned by this function.
"""
self.main = None
self.par_map = {}
self.data_cache = {}
self.content_fn = content_fn
self.add_commits(repo_schema)
self.main = None
self.par_map = {}
self.data_cache = {}
self.content_fn = content_fn
self.add_commits(repo_schema)
def walk(self):
"""(Generator) Walks the repo schema from roots to tips.
def walk(self):
"""(Generator) Walks the repo schema from roots to tips.
Generates GitRepoSchema.COMMIT objects for each commit.
Throws an AssertionError if it detects a cycle.
"""
is_root = True
par_map = copy.deepcopy(self.par_map)
while par_map:
empty_keys = set(k for k, v in par_map.items() if not v)
assert empty_keys, 'Cycle detected! %s' % par_map
for k in sorted(empty_keys):
yield self.COMMIT(k, self.par_map[k],
not any(k in v for v in self.par_map.values()),
is_root)
del par_map[k]
for v in par_map.values():
v.difference_update(empty_keys)
is_root = False
def add_partial(self, commit, parent=None):
if commit not in self.par_map:
self.par_map[commit] = OrderedSet()
if parent is not None:
self.par_map[commit].add(parent)
def add_commits(self, schema):
"""Adds more commits from a schema into the existing Schema.
is_root = True
par_map = copy.deepcopy(self.par_map)
while par_map:
empty_keys = set(k for k, v in par_map.items() if not v)
assert empty_keys, 'Cycle detected! %s' % par_map
for k in sorted(empty_keys):
yield self.COMMIT(
k, self.par_map[k],
not any(k in v for v in self.par_map.values()), is_root)
del par_map[k]
for v in par_map.values():
v.difference_update(empty_keys)
is_root = False
def add_partial(self, commit, parent=None):
if commit not in self.par_map:
self.par_map[commit] = OrderedSet()
if parent is not None:
self.par_map[commit].add(parent)
def add_commits(self, schema):
"""Adds more commits from a schema into the existing Schema.
Args:
schema (str) - See class docstring for info on schema format.
Throws an AssertionError if it detects a cycle.
"""
for commits in (l.split() for l in schema.splitlines() if l.strip()):
parent = None
for commit in commits:
self.add_partial(commit, parent)
parent = commit
if parent and not self.main:
self.main = parent
for _ in self.walk(): # This will throw if there are any cycles.
pass
def reify(self):
"""Returns a real GitRepo for this GitRepoSchema"""
return GitRepo(self)
def data_for(self, commit):
"""Obtains the data for |commit|.
for commits in (l.split() for l in schema.splitlines() if l.strip()):
parent = None
for commit in commits:
self.add_partial(commit, parent)
parent = commit
if parent and not self.main:
self.main = parent
for _ in self.walk(): # This will throw if there are any cycles.
pass
def reify(self):
"""Returns a real GitRepo for this GitRepoSchema"""
return GitRepo(self)
def data_for(self, commit):
"""Obtains the data for |commit|.
See the docstring on the GitRepo class for the format of the returned data.
Caches the result on this GitRepoSchema instance.
"""
if commit not in self.data_cache:
self.data_cache[commit] = self.content_fn(commit)
return self.data_cache[commit]
if commit not in self.data_cache:
self.data_cache[commit] = self.content_fn(commit)
return self.data_cache[commit]
def simple_graph(self):
"""Returns a dictionary of {commit_subject: {parent commit_subjects}}
def simple_graph(self):
"""Returns a dictionary of {commit_subject: {parent commit_subjects}}
This allows you to get a very simple connection graph over the whole repo
for comparison purposes. Only commit subjects (not ids, not content/data)
are considered
"""
ret = {}
for commit in self.walk():
ret.setdefault(commit.name, set()).update(commit.parents)
return ret
ret = {}
for commit in self.walk():
ret.setdefault(commit.name, set()).update(commit.parents)
return ret
class GitRepo(object):
"""Creates a real git repo for a GitRepoSchema.
"""Creates a real git repo for a GitRepoSchema.
Obtains schema and content information from the GitRepoSchema.
@ -260,26 +258,26 @@ class GitRepo(object):
For file content, if 'data' is None, then this commit will `git rm` that file.
"""
BASE_TEMP_DIR = tempfile.mkdtemp(suffix='base', prefix='git_repo')
atexit.register(gclient_utils.rmtree, BASE_TEMP_DIR)
BASE_TEMP_DIR = tempfile.mkdtemp(suffix='base', prefix='git_repo')
atexit.register(gclient_utils.rmtree, BASE_TEMP_DIR)
# Singleton objects to specify specific data in a commit dictionary.
AUTHOR_NAME = object()
AUTHOR_EMAIL = object()
AUTHOR_DATE = object()
COMMITTER_NAME = object()
COMMITTER_EMAIL = object()
COMMITTER_DATE = object()
# Singleton objects to specify specific data in a commit dictionary.
AUTHOR_NAME = object()
AUTHOR_EMAIL = object()
AUTHOR_DATE = object()
COMMITTER_NAME = object()
COMMITTER_EMAIL = object()
COMMITTER_DATE = object()
DEFAULT_AUTHOR_NAME = 'Author McAuthorly'
DEFAULT_AUTHOR_EMAIL = 'author@example.com'
DEFAULT_COMMITTER_NAME = 'Charles Committish'
DEFAULT_COMMITTER_EMAIL = 'commitish@example.com'
DEFAULT_AUTHOR_NAME = 'Author McAuthorly'
DEFAULT_AUTHOR_EMAIL = 'author@example.com'
DEFAULT_COMMITTER_NAME = 'Charles Committish'
DEFAULT_COMMITTER_EMAIL = 'commitish@example.com'
COMMAND_OUTPUT = collections.namedtuple('COMMAND_OUTPUT', 'retcode stdout')
COMMAND_OUTPUT = collections.namedtuple('COMMAND_OUTPUT', 'retcode stdout')
def __init__(self, schema):
"""Makes new GitRepo.
def __init__(self, schema):
"""Makes new GitRepo.
Automatically creates a temp folder under GitRepo.BASE_TEMP_DIR. It's
recommended that you clean this repo up by calling nuke() on it, but if not,
@ -289,194 +287,198 @@ class GitRepo(object):
Args:
schema - An instance of GitRepoSchema
"""
self.repo_path = os.path.realpath(tempfile.mkdtemp(dir=self.BASE_TEMP_DIR))
self.commit_map = {}
self._date = datetime.datetime(1970, 1, 1, tzinfo=UTC)
self.repo_path = os.path.realpath(
tempfile.mkdtemp(dir=self.BASE_TEMP_DIR))
self.commit_map = {}
self._date = datetime.datetime(1970, 1, 1, tzinfo=UTC)
self.to_schema_refs = ['--branches']
self.to_schema_refs = ['--branches']
self.git('init', '-b', DEFAULT_BRANCH)
self.git('config', 'user.name', 'testcase')
self.git('config', 'user.email', 'testcase@example.com')
for commit in schema.walk():
self._add_schema_commit(commit, schema.data_for(commit.name))
self.last_commit = self[commit.name]
if schema.main:
self.git('update-ref', 'refs/heads/main', self[schema.main])
self.git('init', '-b', DEFAULT_BRANCH)
self.git('config', 'user.name', 'testcase')
self.git('config', 'user.email', 'testcase@example.com')
for commit in schema.walk():
self._add_schema_commit(commit, schema.data_for(commit.name))
self.last_commit = self[commit.name]
if schema.main:
self.git('update-ref', 'refs/heads/main', self[schema.main])
def __getitem__(self, commit_name):
"""Gets the hash of a commit by its schema name.
def __getitem__(self, commit_name):
"""Gets the hash of a commit by its schema name.
>>> r = GitRepo(GitRepoSchema('A B C'))
>>> r['B']
'7381febe1da03b09da47f009963ab7998a974935'
"""
return self.commit_map[commit_name]
def _add_schema_commit(self, commit, commit_data):
commit_data = commit_data or {}
if commit.parents:
parents = list(commit.parents)
self.git('checkout', '--detach', '-q', self[parents[0]])
if len(parents) > 1:
self.git('merge', '--no-commit', '-q', *[self[x] for x in parents[1:]])
else:
self.git('checkout', '--orphan', 'root_%s' % commit.name)
self.git('rm', '-rf', '.')
env = self.get_git_commit_env(commit_data)
for fname, file_data in commit_data.items():
# If it isn't a string, it's one of the special keys.
if not isinstance(fname, str):
continue
deleted = False
if 'data' in file_data:
data = file_data.get('data')
if data is None:
deleted = True
self.git('rm', fname)
else:
path = os.path.join(self.repo_path, fname)
pardir = os.path.dirname(path)
if not os.path.exists(pardir):
os.makedirs(pardir)
with open(path, 'wb') as f:
f.write(data)
mode = file_data.get('mode')
if mode and not deleted:
os.chmod(path, mode)
self.git('add', fname)
rslt = self.git('commit', '--allow-empty', '-m', commit.name, env=env)
assert rslt.retcode == 0, 'Failed to commit %s' % str(commit)
self.commit_map[commit.name] = self.git('rev-parse', 'HEAD').stdout.strip()
self.git('tag', 'tag_%s' % commit.name, self[commit.name])
if commit.is_branch:
self.git('branch', '-f', 'branch_%s' % commit.name, self[commit.name])
def get_git_commit_env(self, commit_data=None):
commit_data = commit_data or {}
env = os.environ.copy()
for prefix in ('AUTHOR', 'COMMITTER'):
for suffix in ('NAME', 'EMAIL', 'DATE'):
singleton = '%s_%s' % (prefix, suffix)
key = getattr(self, singleton)
if key in commit_data:
val = commit_data[key]
elif suffix == 'DATE':
val = self._date
self._date += datetime.timedelta(days=1)
return self.commit_map[commit_name]
def _add_schema_commit(self, commit, commit_data):
commit_data = commit_data or {}
if commit.parents:
parents = list(commit.parents)
self.git('checkout', '--detach', '-q', self[parents[0]])
if len(parents) > 1:
self.git('merge', '--no-commit', '-q',
*[self[x] for x in parents[1:]])
else:
val = getattr(self, 'DEFAULT_%s' % singleton)
if not isinstance(val, str) and not isinstance(val, bytes):
val = str(val)
env['GIT_%s' % singleton] = val
return env
def git(self, *args, **kwargs):
"""Runs a git command specified by |args| in this repo."""
assert self.repo_path is not None
try:
with open(os.devnull, 'wb') as devnull:
shell = sys.platform == 'win32'
output = subprocess.check_output(
('git', ) + args,
shell=shell,
cwd=self.repo_path,
stderr=devnull,
**kwargs)
output = output.decode('utf-8')
return self.COMMAND_OUTPUT(0, output)
except subprocess.CalledProcessError as e:
return self.COMMAND_OUTPUT(e.returncode, e.output)
def show_commit(self, commit_name, format_string):
"""Shows a commit (by its schema name) with a given format string."""
return self.git('show', '-q', '--pretty=format:%s' % format_string,
self[commit_name]).stdout
def git_commit(self, message):
return self.git('commit', '-am', message, env=self.get_git_commit_env())
def nuke(self):
"""Obliterates the git repo on disk.
self.git('checkout', '--orphan', 'root_%s' % commit.name)
self.git('rm', '-rf', '.')
env = self.get_git_commit_env(commit_data)
for fname, file_data in commit_data.items():
# If it isn't a string, it's one of the special keys.
if not isinstance(fname, str):
continue
deleted = False
if 'data' in file_data:
data = file_data.get('data')
if data is None:
deleted = True
self.git('rm', fname)
else:
path = os.path.join(self.repo_path, fname)
pardir = os.path.dirname(path)
if not os.path.exists(pardir):
os.makedirs(pardir)
with open(path, 'wb') as f:
f.write(data)
mode = file_data.get('mode')
if mode and not deleted:
os.chmod(path, mode)
self.git('add', fname)
rslt = self.git('commit', '--allow-empty', '-m', commit.name, env=env)
assert rslt.retcode == 0, 'Failed to commit %s' % str(commit)
self.commit_map[commit.name] = self.git('rev-parse',
'HEAD').stdout.strip()
self.git('tag', 'tag_%s' % commit.name, self[commit.name])
if commit.is_branch:
self.git('branch', '-f', 'branch_%s' % commit.name,
self[commit.name])
def get_git_commit_env(self, commit_data=None):
commit_data = commit_data or {}
env = os.environ.copy()
for prefix in ('AUTHOR', 'COMMITTER'):
for suffix in ('NAME', 'EMAIL', 'DATE'):
singleton = '%s_%s' % (prefix, suffix)
key = getattr(self, singleton)
if key in commit_data:
val = commit_data[key]
elif suffix == 'DATE':
val = self._date
self._date += datetime.timedelta(days=1)
else:
val = getattr(self, 'DEFAULT_%s' % singleton)
if not isinstance(val, str) and not isinstance(val, bytes):
val = str(val)
env['GIT_%s' % singleton] = val
return env
def git(self, *args, **kwargs):
"""Runs a git command specified by |args| in this repo."""
assert self.repo_path is not None
try:
with open(os.devnull, 'wb') as devnull:
shell = sys.platform == 'win32'
output = subprocess.check_output(('git', ) + args,
shell=shell,
cwd=self.repo_path,
stderr=devnull,
**kwargs)
output = output.decode('utf-8')
return self.COMMAND_OUTPUT(0, output)
except subprocess.CalledProcessError as e:
return self.COMMAND_OUTPUT(e.returncode, e.output)
def show_commit(self, commit_name, format_string):
"""Shows a commit (by its schema name) with a given format string."""
return self.git('show', '-q', '--pretty=format:%s' % format_string,
self[commit_name]).stdout
def git_commit(self, message):
return self.git('commit', '-am', message, env=self.get_git_commit_env())
def nuke(self):
"""Obliterates the git repo on disk.
Causes this GitRepo to be unusable.
"""
gclient_utils.rmtree(self.repo_path)
self.repo_path = None
def run(self, fn, *args, **kwargs):
"""Run a python function with the given args and kwargs with the cwd set to
the git repo."""
assert self.repo_path is not None
curdir = os.getcwd()
try:
os.chdir(self.repo_path)
return fn(*args, **kwargs)
finally:
os.chdir(curdir)
def capture_stdio(self, fn, *args, **kwargs):
"""Run a python function with the given args and kwargs with the cwd set to
the git repo.
gclient_utils.rmtree(self.repo_path)
self.repo_path = None
def run(self, fn, *args, **kwargs):
"""Run a python function with the given args and kwargs with the cwd
set to the git repo."""
assert self.repo_path is not None
curdir = os.getcwd()
try:
os.chdir(self.repo_path)
return fn(*args, **kwargs)
finally:
os.chdir(curdir)
def capture_stdio(self, fn, *args, **kwargs):
"""Run a python function with the given args and kwargs with the cwd set
to the git repo.
Returns the (stdout, stderr) of whatever ran, instead of the what |fn|
returned.
"""
stdout = sys.stdout
stderr = sys.stderr
try:
with tempfile.TemporaryFile('w+') as out:
with tempfile.TemporaryFile('w+') as err:
sys.stdout = out
sys.stderr = err
try:
self.run(fn, *args, **kwargs)
except SystemExit:
pass
out.seek(0)
err.seek(0)
return out.read(), err.read()
finally:
sys.stdout = stdout
sys.stderr = stderr
def open(self, path, mode='rb'):
return open(os.path.join(self.repo_path, path), mode)
def to_schema(self):
lines = self.git('rev-list', '--parents', '--reverse', '--topo-order',
'--format=%s', *self.to_schema_refs).stdout.splitlines()
hash_to_msg = {}
ret = GitRepoSchema()
current = None
parents = []
for line in lines:
if line.startswith('commit'):
assert current is None
tokens = line.split()
current, parents = tokens[1], tokens[2:]
assert all(p in hash_to_msg for p in parents)
else:
assert current is not None
hash_to_msg[current] = line
ret.add_partial(line)
for parent in parents:
ret.add_partial(line, hash_to_msg[parent])
stdout = sys.stdout
stderr = sys.stderr
try:
with tempfile.TemporaryFile('w+') as out:
with tempfile.TemporaryFile('w+') as err:
sys.stdout = out
sys.stderr = err
try:
self.run(fn, *args, **kwargs)
except SystemExit:
pass
out.seek(0)
err.seek(0)
return out.read(), err.read()
finally:
sys.stdout = stdout
sys.stderr = stderr
def open(self, path, mode='rb'):
return open(os.path.join(self.repo_path, path), mode)
def to_schema(self):
lines = self.git('rev-list', '--parents', '--reverse', '--topo-order',
'--format=%s',
*self.to_schema_refs).stdout.splitlines()
hash_to_msg = {}
ret = GitRepoSchema()
current = None
parents = []
assert current is None
return ret
for line in lines:
if line.startswith('commit'):
assert current is None
tokens = line.split()
current, parents = tokens[1], tokens[2:]
assert all(p in hash_to_msg for p in parents)
else:
assert current is not None
hash_to_msg[current] = line
ret.add_partial(line)
for parent in parents:
ret.add_partial(line, hash_to_msg[parent])
current = None
parents = []
assert current is None
return ret
class GitRepoSchemaTestBase(unittest.TestCase):
"""A TestCase with a built-in GitRepoSchema.
"""A TestCase with a built-in GitRepoSchema.
Expects a class variable REPO_SCHEMA to be a GitRepoSchema string in the form
described by that class.
@ -487,61 +489,62 @@ class GitRepoSchemaTestBase(unittest.TestCase):
You probably will end up using either GitRepoReadOnlyTestBase or
GitRepoReadWriteTestBase for real tests.
"""
REPO_SCHEMA = None
REPO_SCHEMA = None
@classmethod
def getRepoContent(cls, commit):
commit = 'COMMIT_%s' % commit
return getattr(cls, commit, None)
@classmethod
def getRepoContent(cls, commit):
commit = 'COMMIT_%s' % commit
return getattr(cls, commit, None)
@classmethod
def setUpClass(cls):
super(GitRepoSchemaTestBase, cls).setUpClass()
assert cls.REPO_SCHEMA is not None
cls.r_schema = GitRepoSchema(cls.REPO_SCHEMA, cls.getRepoContent)
@classmethod
def setUpClass(cls):
super(GitRepoSchemaTestBase, cls).setUpClass()
assert cls.REPO_SCHEMA is not None
cls.r_schema = GitRepoSchema(cls.REPO_SCHEMA, cls.getRepoContent)
class GitRepoReadOnlyTestBase(GitRepoSchemaTestBase):
"""Injects a GitRepo object given the schema and content from
"""Injects a GitRepo object given the schema and content from
GitRepoSchemaTestBase into TestCase classes which subclass this.
This GitRepo will appear as self.repo, and will be deleted and recreated once
for the duration of all the tests in the subclass.
"""
REPO_SCHEMA = None
REPO_SCHEMA = None
@classmethod
def setUpClass(cls):
super(GitRepoReadOnlyTestBase, cls).setUpClass()
assert cls.REPO_SCHEMA is not None
cls.repo = cls.r_schema.reify()
@classmethod
def setUpClass(cls):
super(GitRepoReadOnlyTestBase, cls).setUpClass()
assert cls.REPO_SCHEMA is not None
cls.repo = cls.r_schema.reify()
def setUp(self):
self.repo.git('checkout', '-f', self.repo.last_commit)
def setUp(self):
self.repo.git('checkout', '-f', self.repo.last_commit)
@classmethod
def tearDownClass(cls):
cls.repo.nuke()
super(GitRepoReadOnlyTestBase, cls).tearDownClass()
@classmethod
def tearDownClass(cls):
cls.repo.nuke()
super(GitRepoReadOnlyTestBase, cls).tearDownClass()
class GitRepoReadWriteTestBase(GitRepoSchemaTestBase):
"""Injects a GitRepo object given the schema and content from
"""Injects a GitRepo object given the schema and content from
GitRepoSchemaTestBase into TestCase classes which subclass this.
This GitRepo will appear as self.repo, and will be deleted and recreated for
each test function in the subclass.
"""
REPO_SCHEMA = None
REPO_SCHEMA = None
def setUp(self):
super(GitRepoReadWriteTestBase, self).setUp()
self.repo = self.r_schema.reify()
def setUp(self):
super(GitRepoReadWriteTestBase, self).setUp()
self.repo = self.r_schema.reify()
def tearDown(self):
self.repo.nuke()
super(GitRepoReadWriteTestBase, self).tearDown()
def tearDown(self):
self.repo.nuke()
super(GitRepoReadWriteTestBase, self).tearDown()
def assertSchema(self, schema_string):
self.assertEqual(GitRepoSchema(schema_string).simple_graph(),
self.repo.to_schema().simple_graph())
def assertSchema(self, schema_string):
self.assertEqual(
GitRepoSchema(schema_string).simple_graph(),
self.repo.to_schema().simple_graph())

@ -15,10 +15,12 @@ from presubmit_canned_checks import _ReportErrorFileAndLine
class MockCannedChecks(object):
def _FindNewViolationsOfRule(self, callable_rule, input_api,
source_file_filter=None,
error_formatter=_ReportErrorFileAndLine):
"""Find all newly introduced violations of a per-line rule (a callable).
def _FindNewViolationsOfRule(self,
callable_rule,
input_api,
source_file_filter=None,
error_formatter=_ReportErrorFileAndLine):
"""Find all newly introduced violations of a per-line rule (a callable).
Arguments:
callable_rule: a callable taking a file extension and line of input and
@ -32,232 +34,246 @@ class MockCannedChecks(object):
Returns:
A list of the newly-introduced violations reported by the rule.
"""
errors = []
for f in input_api.AffectedFiles(include_deletes=False,
file_filter=source_file_filter):
# For speed, we do two passes, checking first the full file. Shelling out
# to the SCM to determine the changed region can be quite expensive on
# Win32. Assuming that most files will be kept problem-free, we can
# skip the SCM operations most of the time.
extension = str(f.LocalPath()).rsplit('.', 1)[-1]
if all(callable_rule(extension, line) for line in f.NewContents()):
continue # No violation found in full text: can skip considering diff.
for line_num, line in f.ChangedContents():
if not callable_rule(extension, line):
errors.append(error_formatter(f.LocalPath(), line_num, line))
return errors
errors = []
for f in input_api.AffectedFiles(include_deletes=False,
file_filter=source_file_filter):
# For speed, we do two passes, checking first the full file.
# Shelling out to the SCM to determine the changed region can be
# quite expensive on Win32. Assuming that most files will be kept
# problem-free, we can skip the SCM operations most of the time.
extension = str(f.LocalPath()).rsplit('.', 1)[-1]
if all(callable_rule(extension, line) for line in f.NewContents()):
# No violation found in full text: can skip considering diff.
continue
for line_num, line in f.ChangedContents():
if not callable_rule(extension, line):
errors.append(error_formatter(f.LocalPath(), line_num,
line))
return errors
class MockInputApi(object):
"""Mock class for the InputApi class.
"""Mock class for the InputApi class.
This class can be used for unittests for presubmit by initializing the files
attribute as the list of changed files.
"""
DEFAULT_FILES_TO_SKIP = ()
def __init__(self):
self.canned_checks = MockCannedChecks()
self.fnmatch = fnmatch
self.json = json
self.re = re
self.os_path = os.path
self.platform = sys.platform
self.python_executable = sys.executable
self.platform = sys.platform
self.subprocess = subprocess
self.sys = sys
self.files = []
self.is_committing = False
self.no_diffs = False
self.change = MockChange([])
self.presubmit_local_path = os.path.dirname(__file__)
self.logging = logging.getLogger('PRESUBMIT')
def CreateMockFileInPath(self, f_list):
self.os_path.exists = lambda x: x in f_list
def AffectedFiles(self, file_filter=None, include_deletes=True):
for file in self.files: # pylint: disable=redefined-builtin
if file_filter and not file_filter(file):
continue
if not include_deletes and file.Action() == 'D':
continue
yield file
def AffectedSourceFiles(self, file_filter=None):
return self.AffectedFiles(file_filter=file_filter)
def FilterSourceFile(self, file, # pylint: disable=redefined-builtin
files_to_check=(), files_to_skip=()):
local_path = file.LocalPath()
found_in_files_to_check = not files_to_check
if files_to_check:
if isinstance(files_to_check, str):
raise TypeError('files_to_check should be an iterable of strings')
for pattern in files_to_check:
compiled_pattern = re.compile(pattern)
if compiled_pattern.search(local_path):
found_in_files_to_check = True
break
if files_to_skip:
if isinstance(files_to_skip, str):
raise TypeError('files_to_skip should be an iterable of strings')
for pattern in files_to_skip:
compiled_pattern = re.compile(pattern)
if compiled_pattern.search(local_path):
return False
return found_in_files_to_check
def LocalPaths(self):
return [file.LocalPath() for file in self.files] # pylint: disable=redefined-builtin
def PresubmitLocalPath(self):
return self.presubmit_local_path
def ReadFile(self, filename, mode='rU'):
if hasattr(filename, 'AbsoluteLocalPath'):
filename = filename.AbsoluteLocalPath()
for file_ in self.files:
if file_.LocalPath() == filename:
return '\n'.join(file_.NewContents())
# Otherwise, file is not in our mock API.
raise IOError("No such file or directory: '%s'" % filename)
DEFAULT_FILES_TO_SKIP = ()
def __init__(self):
self.canned_checks = MockCannedChecks()
self.fnmatch = fnmatch
self.json = json
self.re = re
self.os_path = os.path
self.platform = sys.platform
self.python_executable = sys.executable
self.platform = sys.platform
self.subprocess = subprocess
self.sys = sys
self.files = []
self.is_committing = False
self.no_diffs = False
self.change = MockChange([])
self.presubmit_local_path = os.path.dirname(__file__)
self.logging = logging.getLogger('PRESUBMIT')
def CreateMockFileInPath(self, f_list):
self.os_path.exists = lambda x: x in f_list
def AffectedFiles(self, file_filter=None, include_deletes=True):
for file in self.files: # pylint: disable=redefined-builtin
if file_filter and not file_filter(file):
continue
if not include_deletes and file.Action() == 'D':
continue
yield file
def AffectedSourceFiles(self, file_filter=None):
return self.AffectedFiles(file_filter=file_filter)
def FilterSourceFile(
self,
file, # pylint: disable=redefined-builtin
files_to_check=(),
files_to_skip=()):
local_path = file.LocalPath()
found_in_files_to_check = not files_to_check
if files_to_check:
if isinstance(files_to_check, str):
raise TypeError(
'files_to_check should be an iterable of strings')
for pattern in files_to_check:
compiled_pattern = re.compile(pattern)
if compiled_pattern.search(local_path):
found_in_files_to_check = True
break
if files_to_skip:
if isinstance(files_to_skip, str):
raise TypeError(
'files_to_skip should be an iterable of strings')
for pattern in files_to_skip:
compiled_pattern = re.compile(pattern)
if compiled_pattern.search(local_path):
return False
return found_in_files_to_check
def LocalPaths(self):
return [file.LocalPath() for file in self.files] # pylint: disable=redefined-builtin
def PresubmitLocalPath(self):
return self.presubmit_local_path
def ReadFile(self, filename, mode='rU'):
if hasattr(filename, 'AbsoluteLocalPath'):
filename = filename.AbsoluteLocalPath()
for file_ in self.files:
if file_.LocalPath() == filename:
return '\n'.join(file_.NewContents())
# Otherwise, file is not in our mock API.
raise IOError("No such file or directory: '%s'" % filename)
class MockOutputApi(object):
"""Mock class for the OutputApi class.
"""Mock class for the OutputApi class.
An instance of this class can be passed to presubmit unittests for outputing
various types of results.
"""
class PresubmitResult(object):
def __init__(self, message, items=None, long_text=''):
self.message = message
self.items = items
self.long_text = long_text
def __repr__(self):
return self.message
class PresubmitError(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
self.type = 'error'
class PresubmitPromptWarning(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
self.type = 'warning'
class PresubmitNotifyResult(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
self.type = 'notify'
class PresubmitPromptOrNotify(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
self.type = 'promptOrNotify'
def __init__(self):
self.more_cc = []
def AppendCC(self, more_cc):
self.more_cc.extend(more_cc)
class PresubmitResult(object):
def __init__(self, message, items=None, long_text=''):
self.message = message
self.items = items
self.long_text = long_text
def __repr__(self):
return self.message
class PresubmitError(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items,
long_text)
self.type = 'error'
class PresubmitPromptWarning(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items,
long_text)
self.type = 'warning'
class PresubmitNotifyResult(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items,
long_text)
self.type = 'notify'
class PresubmitPromptOrNotify(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items,
long_text)
self.type = 'promptOrNotify'
def __init__(self):
self.more_cc = []
def AppendCC(self, more_cc):
self.more_cc.extend(more_cc)
class MockFile(object):
"""Mock class for the File class.
"""Mock class for the File class.
This class can be used to form the mock list of changed files in
MockInputApi for presubmit unittests.
"""
def __init__(self, local_path, new_contents, old_contents=None, action='A',
scm_diff=None):
self._local_path = local_path
self._new_contents = new_contents
self._changed_contents = [(i + 1, l) for i, l in enumerate(new_contents)]
self._action = action
if scm_diff:
self._scm_diff = scm_diff
else:
self._scm_diff = (
"--- /dev/null\n+++ %s\n@@ -0,0 +1,%d @@\n" %
(local_path, len(new_contents)))
for l in new_contents:
self._scm_diff += "+%s\n" % l
self._old_contents = old_contents
def Action(self):
return self._action
def ChangedContents(self):
return self._changed_contents
def NewContents(self):
return self._new_contents
def LocalPath(self):
return self._local_path
def AbsoluteLocalPath(self):
return self._local_path
def GenerateScmDiff(self):
return self._scm_diff
def OldContents(self):
return self._old_contents
def rfind(self, p):
"""os.path.basename is called on MockFile so we need an rfind method."""
return self._local_path.rfind(p)
def __getitem__(self, i):
"""os.path.basename is called on MockFile so we need a get method."""
return self._local_path[i]
def __len__(self):
"""os.path.basename is called on MockFile so we need a len method."""
return len(self._local_path)
def replace(self, altsep, sep):
"""os.path.basename is called on MockFile so we need a replace method."""
return self._local_path.replace(altsep, sep)
def __init__(self,
local_path,
new_contents,
old_contents=None,
action='A',
scm_diff=None):
self._local_path = local_path
self._new_contents = new_contents
self._changed_contents = [(i + 1, l)
for i, l in enumerate(new_contents)]
self._action = action
if scm_diff:
self._scm_diff = scm_diff
else:
self._scm_diff = ("--- /dev/null\n+++ %s\n@@ -0,0 +1,%d @@\n" %
(local_path, len(new_contents)))
for l in new_contents:
self._scm_diff += "+%s\n" % l
self._old_contents = old_contents
def Action(self):
return self._action
def ChangedContents(self):
return self._changed_contents
def NewContents(self):
return self._new_contents
def LocalPath(self):
return self._local_path
def AbsoluteLocalPath(self):
return self._local_path
def GenerateScmDiff(self):
return self._scm_diff
def OldContents(self):
return self._old_contents
def rfind(self, p):
"""os.path.basename is used on MockFile so we need an rfind method."""
return self._local_path.rfind(p)
def __getitem__(self, i):
"""os.path.basename is used on MockFile so we need a get method."""
return self._local_path[i]
def __len__(self):
"""os.path.basename is used on MockFile so we need a len method."""
return len(self._local_path)
def replace(self, altsep, sep):
"""os.path.basename is used on MockFile so we need a replace method."""
return self._local_path.replace(altsep, sep)
class MockAffectedFile(MockFile):
def AbsoluteLocalPath(self):
return self._local_path
def AbsoluteLocalPath(self):
return self._local_path
class MockChange(object):
"""Mock class for Change class.
"""Mock class for Change class.
This class can be used in presubmit unittests to mock the query of the
current change.
"""
def __init__(self, changed_files, description=''):
self._changed_files = changed_files
self.footers = defaultdict(list)
self._description = description
def __init__(self, changed_files, description=''):
self._changed_files = changed_files
self.footers = defaultdict(list)
self._description = description
def LocalPaths(self):
return self._changed_files
def LocalPaths(self):
return self._changed_files
def AffectedFiles(self, include_dirs=False, include_deletes=True,
file_filter=None):
return self._changed_files
def AffectedFiles(self,
include_dirs=False,
include_deletes=True,
file_filter=None):
return self._changed_files
def GitFootersFromDescription(self):
return self.footers
def GitFootersFromDescription(self):
return self.footers
def DescriptionText(self):
return self._description
def DescriptionText(self):
return self._description

@ -2,7 +2,6 @@
# Copyright (c) 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Script used to test subprocess2."""
import optparse
@ -10,51 +9,52 @@ import os
import sys
import time
if sys.platform == 'win32':
# Annoying, make sure the output is not translated on Windows.
# pylint: disable=no-member,import-error
import msvcrt
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY)
# Annoying, make sure the output is not translated on Windows.
# pylint: disable=no-member,import-error
import msvcrt
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY)
parser = optparse.OptionParser()
parser.add_option(
'--fail',
dest='return_value',
action='store_const',
default=0,
const=64)
parser.add_option(
'--crlf', action='store_const', const='\r\n', dest='eol', default='\n')
parser.add_option(
'--cr', action='store_const', const='\r', dest='eol')
parser.add_option('--fail',
dest='return_value',
action='store_const',
default=0,
const=64)
parser.add_option('--crlf',
action='store_const',
const='\r\n',
dest='eol',
default='\n')
parser.add_option('--cr', action='store_const', const='\r', dest='eol')
parser.add_option('--stdout', action='store_true')
parser.add_option('--stderr', action='store_true')
parser.add_option('--read', action='store_true')
options, args = parser.parse_args()
if args:
parser.error('Internal error')
parser.error('Internal error')
def do(string):
if options.stdout:
sys.stdout.buffer.write(string.upper().encode('utf-8'))
sys.stdout.buffer.write(options.eol.encode('utf-8'))
if options.stderr:
sys.stderr.buffer.write(string.lower().encode('utf-8'))
sys.stderr.buffer.write(options.eol.encode('utf-8'))
sys.stderr.flush()
if options.stdout:
sys.stdout.buffer.write(string.upper().encode('utf-8'))
sys.stdout.buffer.write(options.eol.encode('utf-8'))
if options.stderr:
sys.stderr.buffer.write(string.lower().encode('utf-8'))
sys.stderr.buffer.write(options.eol.encode('utf-8'))
sys.stderr.flush()
do('A')
do('BB')
do('CCC')
if options.read:
assert options.return_value == 0
try:
while sys.stdin.read(1):
options.return_value += 1
except OSError:
pass
assert options.return_value == 0
try:
while sys.stdin.read(1):
options.return_value += 1
except OSError:
pass
sys.exit(options.return_value)

@ -1,7 +1,6 @@
# Copyright (c) 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Simplify unit tests based on pymox."""
from __future__ import print_function
@ -12,51 +11,52 @@ import string
class TestCaseUtils(object):
"""Base class with some additional functionalities. People will usually want
"""Base class with some additional functionalities. People will usually want
to use SuperMoxTestBase instead."""
# Backup the separator in case it gets mocked
_OS_SEP = os.sep
_RANDOM_CHOICE = random.choice
_RANDOM_RANDINT = random.randint
_STRING_LETTERS = string.ascii_letters
## Some utilities for generating arbitrary arguments.
def String(self, max_length):
return ''.join([self._RANDOM_CHOICE(self._STRING_LETTERS)
for _ in range(self._RANDOM_RANDINT(1, max_length))])
def Strings(self, max_arg_count, max_arg_length):
return [self.String(max_arg_length) for _ in range(max_arg_count)]
def Args(self, max_arg_count=8, max_arg_length=16):
return self.Strings(max_arg_count,
self._RANDOM_RANDINT(1, max_arg_length))
def _DirElts(self, max_elt_count=4, max_elt_length=8):
return self._OS_SEP.join(self.Strings(max_elt_count, max_elt_length))
def Dir(self, max_elt_count=4, max_elt_length=8):
return (self._RANDOM_CHOICE((self._OS_SEP, '')) +
self._DirElts(max_elt_count, max_elt_length))
def RootDir(self, max_elt_count=4, max_elt_length=8):
return self._OS_SEP + self._DirElts(max_elt_count, max_elt_length)
def compareMembers(self, obj, members):
"""If you add a member, be sure to add the relevant test!"""
# Skip over members starting with '_' since they are usually not meant to
# be for public use.
actual_members = [x for x in sorted(dir(obj))
if not x.startswith('_')]
expected_members = sorted(members)
if actual_members != expected_members:
diff = ([i for i in actual_members if i not in expected_members] +
[i for i in expected_members if i not in actual_members])
print(diff, file=sys.stderr)
# pylint: disable=no-member
self.assertEqual(actual_members, expected_members)
def setUp(self):
self.root_dir = self.Dir()
self.args = self.Args()
self.relpath = self.String(200)
# Backup the separator in case it gets mocked
_OS_SEP = os.sep
_RANDOM_CHOICE = random.choice
_RANDOM_RANDINT = random.randint
_STRING_LETTERS = string.ascii_letters
## Some utilities for generating arbitrary arguments.
def String(self, max_length):
return ''.join([
self._RANDOM_CHOICE(self._STRING_LETTERS)
for _ in range(self._RANDOM_RANDINT(1, max_length))
])
def Strings(self, max_arg_count, max_arg_length):
return [self.String(max_arg_length) for _ in range(max_arg_count)]
def Args(self, max_arg_count=8, max_arg_length=16):
return self.Strings(max_arg_count,
self._RANDOM_RANDINT(1, max_arg_length))
def _DirElts(self, max_elt_count=4, max_elt_length=8):
return self._OS_SEP.join(self.Strings(max_elt_count, max_elt_length))
def Dir(self, max_elt_count=4, max_elt_length=8):
return (self._RANDOM_CHOICE(
(self._OS_SEP, '')) + self._DirElts(max_elt_count, max_elt_length))
def RootDir(self, max_elt_count=4, max_elt_length=8):
return self._OS_SEP + self._DirElts(max_elt_count, max_elt_length)
def compareMembers(self, obj, members):
"""If you add a member, be sure to add the relevant test!"""
# Skip over members starting with '_' since they are usually not meant
# to be for public use.
actual_members = [x for x in sorted(dir(obj)) if not x.startswith('_')]
expected_members = sorted(members)
if actual_members != expected_members:
diff = ([i for i in actual_members if i not in expected_members] +
[i for i in expected_members if i not in actual_members])
print(diff, file=sys.stderr)
# pylint: disable=no-member
self.assertEqual(actual_members, expected_members)
def setUp(self):
self.root_dir = self.Dir()
self.args = self.Args()
self.relpath = self.String(200)

@ -15,83 +15,84 @@ import gclient_utils
class TrialDir(object):
"""Manages a temporary directory.
"""Manages a temporary directory.
On first object creation, TrialDir.TRIAL_ROOT will be set to a new temporary
directory created in /tmp or the equivalent. It will be deleted on process
exit unless TrialDir.SHOULD_LEAK is set to True.
"""
# When SHOULD_LEAK is set to True, temporary directories created while the
# tests are running aren't deleted at the end of the tests. Expect failures
# when running more than one test due to inter-test side-effects. Helps with
# debugging.
SHOULD_LEAK = False
# Main root directory.
TRIAL_ROOT = None
def __init__(self, subdir, leak=False):
self.leak = self.SHOULD_LEAK or leak
self.subdir = subdir
self.root_dir = None
def set_up(self):
"""All late initialization comes here."""
# You can override self.TRIAL_ROOT.
if not self.TRIAL_ROOT:
# Was not yet initialized.
TrialDir.TRIAL_ROOT = os.path.realpath(tempfile.mkdtemp(prefix='trial'))
atexit.register(self._clean)
self.root_dir = os.path.join(TrialDir.TRIAL_ROOT, self.subdir)
gclient_utils.rmtree(self.root_dir)
os.makedirs(self.root_dir)
def tear_down(self):
"""Cleans the trial subdirectory for this instance."""
if not self.leak:
logging.debug('Removing %s' % self.root_dir)
gclient_utils.rmtree(self.root_dir)
else:
logging.error('Leaking %s' % self.root_dir)
self.root_dir = None
@staticmethod
def _clean():
"""Cleans the root trial directory."""
if not TrialDir.SHOULD_LEAK:
logging.debug('Removing %s' % TrialDir.TRIAL_ROOT)
gclient_utils.rmtree(TrialDir.TRIAL_ROOT)
else:
logging.error('Leaking %s' % TrialDir.TRIAL_ROOT)
# When SHOULD_LEAK is set to True, temporary directories created while the
# tests are running aren't deleted at the end of the tests. Expect failures
# when running more than one test due to inter-test side-effects. Helps with
# debugging.
SHOULD_LEAK = False
# Main root directory.
TRIAL_ROOT = None
def __init__(self, subdir, leak=False):
self.leak = self.SHOULD_LEAK or leak
self.subdir = subdir
self.root_dir = None
def set_up(self):
"""All late initialization comes here."""
# You can override self.TRIAL_ROOT.
if not self.TRIAL_ROOT:
# Was not yet initialized.
TrialDir.TRIAL_ROOT = os.path.realpath(
tempfile.mkdtemp(prefix='trial'))
atexit.register(self._clean)
self.root_dir = os.path.join(TrialDir.TRIAL_ROOT, self.subdir)
gclient_utils.rmtree(self.root_dir)
os.makedirs(self.root_dir)
def tear_down(self):
"""Cleans the trial subdirectory for this instance."""
if not self.leak:
logging.debug('Removing %s' % self.root_dir)
gclient_utils.rmtree(self.root_dir)
else:
logging.error('Leaking %s' % self.root_dir)
self.root_dir = None
@staticmethod
def _clean():
"""Cleans the root trial directory."""
if not TrialDir.SHOULD_LEAK:
logging.debug('Removing %s' % TrialDir.TRIAL_ROOT)
gclient_utils.rmtree(TrialDir.TRIAL_ROOT)
else:
logging.error('Leaking %s' % TrialDir.TRIAL_ROOT)
class TrialDirMixIn(object):
def setUp(self):
# Create a specific directory just for the test.
self.trial = TrialDir(self.id())
self.trial.set_up()
def setUp(self):
# Create a specific directory just for the test.
self.trial = TrialDir(self.id())
self.trial.set_up()
def tearDown(self):
self.trial.tear_down()
def tearDown(self):
self.trial.tear_down()
@property
def root_dir(self):
return self.trial.root_dir
@property
def root_dir(self):
return self.trial.root_dir
class TestCase(unittest.TestCase, TrialDirMixIn):
"""Base unittest class that cleans off a trial directory in tearDown()."""
def setUp(self):
unittest.TestCase.setUp(self)
TrialDirMixIn.setUp(self)
"""Base unittest class that cleans off a trial directory in tearDown()."""
def setUp(self):
unittest.TestCase.setUp(self)
TrialDirMixIn.setUp(self)
def tearDown(self):
TrialDirMixIn.tearDown(self)
unittest.TestCase.tearDown(self)
def tearDown(self):
TrialDirMixIn.tearDown(self)
unittest.TestCase.tearDown(self)
if '-l' in sys.argv:
# See SHOULD_LEAK definition in TrialDir for its purpose.
TrialDir.SHOULD_LEAK = True
print('Leaking!')
sys.argv.remove('-l')
# See SHOULD_LEAK definition in TrialDir for its purpose.
TrialDir.SHOULD_LEAK = True
print('Leaking!')
sys.argv.remove('-l')

Loading…
Cancel
Save