diff --git a/presubmit_support.py b/presubmit_support.py index d30240d5f..546e16ed5 100755 --- a/presubmit_support.py +++ b/presubmit_support.py @@ -35,6 +35,7 @@ import unittest # Exposed through the API. import urllib.parse as urlparse import urllib.request as urllib_request import urllib.error as urllib_error +from typing import Mapping from warnings import warn # Local imports. @@ -1798,12 +1799,7 @@ def DoPresubmitChecks(change, Return: 1 if presubmit checks failed or 0 otherwise. """ - old_environ = os.environ - try: - # Make sure python subprocesses won't generate .pyc files. - os.environ = os.environ.copy() - os.environ['PYTHONDONTWRITEBYTECODE'] = '1' - + with setup_environ({'PYTHONDONTWRITEBYTECODE': '1'}): python_version = 'Python %s' % sys.version_info.major if committing: sys.stdout.write('Running %s presubmit commit checks ...\n' % @@ -1910,8 +1906,6 @@ def DoPresubmitChecks(change, _ASKED_FOR_FEEDBACK = True return 1 if presubmits_failed else 0 - finally: - os.environ = old_environ def _scan_sub_dirs(mask, recursive): @@ -2020,6 +2014,27 @@ def _parse_gerrit_options(parser, options): return gerrit_obj +@contextlib.contextmanager +def setup_environ(kv: Mapping[str, str]): + """Update environment while in context, and reset back to original on exit. + + Example usage: + with setup_environ({"key": "value"}): + # os.environ now has key set to value. + pass + """ + old_kv = {} + for k, v in kv.items(): + old_kv[k] = os.environ.get(k, None) + os.environ[k] = v + yield + for k, v in old_kv.items(): + if v: + os.environ[k] = v + else: + os.environ.pop(k, None) + + @contextlib.contextmanager def canned_check_filter(method_names): filtered = {} diff --git a/tests/presubmit_support_test.py b/tests/presubmit_support_test.py new file mode 100755 index 000000000..59d0c18be --- /dev/null +++ b/tests/presubmit_support_test.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright 2024 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import os.path +import subprocess +import sys +import unittest + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, ROOT_DIR) + +import presubmit_support + + +class PresubmitSupportTest(unittest.TestCase): + def test_environ(self): + self.assertIsNone(os.environ.get('PRESUBMIT_FOO_ENV', None)) + kv = {'PRESUBMIT_FOO_ENV': 'FOOBAR'} + with presubmit_support.setup_environ(kv): + self.assertEqual(os.environ.get('PRESUBMIT_FOO_ENV', None), + 'FOOBAR') + self.assertIsNone(os.environ.get('PRESUBMIT_FOO_ENV', None)) + + +if __name__ == "__main__": + unittest.main()