diff --git a/gclient.py b/gclient.py index 92c477f1d6..2b4ada77bc 100755 --- a/gclient.py +++ b/gclient.py @@ -2971,7 +2971,10 @@ def CMDsetdep(parser, args): if not name or not value: raise gclient_utils.Error( 'Wrong var format: %s should be of the form name=value.' % var) - gclient_eval.SetVar(local_scope, name, value) + if name in local_scope['vars']: + gclient_eval.SetVar(local_scope, name, value) + else: + gclient_eval.AddVar(local_scope, name, value) for revision in options.revisions: name, _, value = revision.partition('@') diff --git a/gclient_eval.py b/gclient_eval.py index 92a165ff6b..703787501a 100644 --- a/gclient_eval.py +++ b/gclient_eval.py @@ -34,6 +34,21 @@ class _NodeDict(collections.MutableMapping): def __len__(self): return len(self.data) + def MoveTokens(self, origin, delta): + if self.tokens: + new_tokens = {} + for pos, token in self.tokens.iteritems(): + if pos[0] >= origin: + pos = (pos[0] + delta, pos[1]) + token = token[:2] + (pos,) + token[3:] + new_tokens[pos] = token + + for value, node in self.data.values(): + if node.lineno >= origin: + node.lineno += delta + if isinstance(value, _NodeDict): + value.MoveTokens(origin, delta) + def GetNode(self, key): return self.data[key][1] @@ -500,6 +515,69 @@ def _UpdateAstString(tokens, node, value): node.s = value +def _ShiftLinesInTokens(tokens, delta, start): + new_tokens = {} + for token in tokens.values(): + if token[2][0] >= start: + token[2] = token[2][0] + delta, token[2][1] + token[3] = token[3][0] + delta, token[3][1] + new_tokens[token[2]] = token + return new_tokens + + +def AddVar(gclient_dict, var_name, value): + if not isinstance(gclient_dict, _NodeDict) or gclient_dict.tokens is None: + raise ValueError( + "Can't use SetVar for the given gclient dict. It contains no " + "formatting information.") + + if 'vars' not in gclient_dict: + raise KeyError("vars dict is not defined.") + + if var_name in gclient_dict['vars']: + raise ValueError( + "%s has already been declared in the vars dict. Consider using SetVar " + "instead." % var_name) + + if not gclient_dict['vars']: + raise ValueError('vars dict is empty. This is not yet supported.') + + # We will attempt to add the var right before the first var. + node = gclient_dict.GetNode('vars').keys[0] + if node is None: + raise ValueError( + "The vars dict has no formatting information." % var_name) + line = node.lineno + col = node.col_offset + + # We use a minimal Python dictionary, so that ast can parse it. + var_content = '{\n%s"%s": "%s",\n}' % (' ' * col, var_name, value) + var_ast = ast.parse(var_content).body[0].value + + # Set the ast nodes for the key and value. + vars_node = gclient_dict.GetNode('vars') + + var_name_node = var_ast.keys[0] + var_name_node.lineno += line - 2 + vars_node.keys.insert(0, var_name_node) + + value_node = var_ast.values[0] + value_node.lineno += line - 2 + vars_node.values.insert(0, value_node) + + # Update the tokens. + var_tokens = list(tokenize.generate_tokens( + cStringIO.StringIO(var_content).readline)) + var_tokens = { + token[2]: list(token) + # Ignore the tokens corresponding to braces and new lines. + for token in var_tokens[2:-2] + } + + gclient_dict.tokens = _ShiftLinesInTokens(gclient_dict.tokens, 1, line) + gclient_dict.tokens.update(_ShiftLinesInTokens(var_tokens, line - 2, 0)) + + def SetVar(gclient_dict, var_name, value): if not isinstance(gclient_dict, _NodeDict) or gclient_dict.tokens is None: raise ValueError( @@ -507,9 +585,13 @@ def SetVar(gclient_dict, var_name, value): "formatting information.") tokens = gclient_dict.tokens - if 'vars' not in gclient_dict or var_name not in gclient_dict['vars']: + if 'vars' not in gclient_dict: + raise KeyError("vars dict is not defined.") + + if var_name not in gclient_dict['vars']: raise ValueError( - "Could not find any variable called %s." % var_name) + "%s has not been declared in the vars dict. Consider using AddVar " + "instead." % var_name) node = gclient_dict['vars'].GetNode(var_name) if node is None: @@ -528,7 +610,7 @@ def SetCIPD(gclient_dict, dep_name, package_name, new_version): tokens = gclient_dict.tokens if 'deps' not in gclient_dict or dep_name not in gclient_dict['deps']: - raise ValueError( + raise KeyError( "Could not find any dependency called %s." % dep_name) # Find the package with the given name @@ -593,7 +675,7 @@ def SetRevision(gclient_dict, dep_name, new_revision): tokens = gclient_dict.tokens if 'deps' not in gclient_dict or dep_name not in gclient_dict['deps']: - raise ValueError( + raise KeyError( "Could not find any dependency called %s." % dep_name) if isinstance(gclient_dict['deps'][dep_name], _NodeDict): diff --git a/tests/gclient_eval_unittest.py b/tests/gclient_eval_unittest.py index 93a07e0441..4fffbfeb09 100755 --- a/tests/gclient_eval_unittest.py +++ b/tests/gclient_eval_unittest.py @@ -236,6 +236,79 @@ class EvaluateConditionTest(unittest.TestCase): str(cm.exception)) +class AddVarTest(unittest.TestCase): + def test_adds_var(self): + local_scope = gclient_eval.Exec('\n'.join([ + 'vars = {', + ' "foo": "bar",', + '}', + ])) + + gclient_eval.AddVar(local_scope, 'baz', 'lemur') + result = gclient_eval.RenderDEPSFile(local_scope) + + self.assertEqual(result, '\n'.join([ + 'vars = {', + ' "baz": "lemur",', + ' "foo": "bar",', + '}', + ])) + + def test_adds_var_twice(self): + local_scope = gclient_eval.Exec('\n'.join([ + 'vars = {', + ' "foo": "bar",', + '}', + ])) + + gclient_eval.AddVar(local_scope, 'baz', 'lemur') + gclient_eval.AddVar(local_scope, 'v8_revision', 'deadbeef') + result = gclient_eval.RenderDEPSFile(local_scope) + + self.assertEqual(result, '\n'.join([ + 'vars = {', + ' "v8_revision": "deadbeef",', + ' "baz": "lemur",', + ' "foo": "bar",', + '}', + ])) + + def test_preserves_formatting(self): + local_scope = gclient_eval.Exec('\n'.join([ + '# Copyright stuff', + '# some initial comments', + '', + 'vars = { ', + ' "foo": "bar",', + ' # Some commets.', + ' # More comments.', + ' # Even more comments.', + ' "v8_revision": ', + ' "deadbeef",', + ' # Someone formatted this wrong', + '}', + ])) + + gclient_eval.AddVar(local_scope, 'baz', 'lemur') + result = gclient_eval.RenderDEPSFile(local_scope) + + self.assertEqual(result, '\n'.join([ + '# Copyright stuff', + '# some initial comments', + '', + 'vars = { ', + ' "baz": "lemur",', + ' "foo": "bar",', + ' # Some commets.', + ' # More comments.', + ' # Even more comments.', + ' "v8_revision": ', + ' "deadbeef",', + ' # Someone formatted this wrong', + '}', + ])) + + class SetVarTest(unittest.TestCase): def test_sets_var(self): local_scope = gclient_eval.Exec('\n'.join([ @@ -272,7 +345,6 @@ class SetVarTest(unittest.TestCase): ])) - class SetCipdTest(unittest.TestCase): def test_sets_cipd(self): local_scope = gclient_eval.Exec('\n'.join([