From 8f2cf6cd84f54a1db8bdd51c93253cd8e41d6d9d Mon Sep 17 00:00:00 2001 From: Lars Wirzenius Date: Sat, 18 Mar 2017 19:42:28 +0200 Subject: Add default value to .get; fix vars.yaml location --- yarnhelper.py | 20 ++++++++++++-------- yarnhelper_tests.py | 7 ++----- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/yarnhelper.py b/yarnhelper.py index 8f9da19..4ed2c03 100644 --- a/yarnhelper.py +++ b/yarnhelper.py @@ -26,6 +26,7 @@ import requests import yaml +datadir = os.environ.get('DATADIR', '.') variables_filename = os.environ.get('VARIABLES', 'vars.yaml') @@ -34,6 +35,7 @@ class YarnHelper(object): def __init__(self): self._env = dict(os.environ) self._next_match = 1 + self._filename = os.path.join(datadir, variables_filename) self._variables = None # None means not loaded, otherwise dict def set_environment(self, env): @@ -46,17 +48,19 @@ class YarnHelper(object): self._next_match += 1 return self._env[name] - def get_variable(self, name): + def get_variable(self, name, default=None): if self._variables is None: self._variables = self._load_variables() - if name not in self._variables: - raise Error('no variable {}'.format(name)) - return self._variables[name] + assert self._variables is not None + return self._variables.get(name, default) def _load_variables(self): - if os.path.exists(variables_filename): - with open(variables_filename, 'r') as f: - return yaml.safe_load(f) + if os.path.exists(self._filename): + with open(self._filename, 'r') as f: + data = f.read() + if data: + f.seek(0) + return yaml.safe_load(f) return {} def set_variable(self, name, value): @@ -66,7 +70,7 @@ class YarnHelper(object): self._save_variables(self._variables) def _save_variables(self, variables): - with open(variables_filename, 'w') as f: + with open(self._filename, 'w') as f: yaml.safe_dump(variables, f) def construct_aliased_http_request( diff --git a/yarnhelper_tests.py b/yarnhelper_tests.py index 756a165..858bcdc 100644 --- a/yarnhelper_tests.py +++ b/yarnhelper_tests.py @@ -66,12 +66,9 @@ class PersistentVariableTests(unittest.TestCase): if os.path.exists(yarnhelper.variables_filename): os.remove(yarnhelper.variables_filename) - def test_raises_error_if_no_such_variable(self): + def test_returns_default_if_no_such_variable(self): h = yarnhelper.YarnHelper() - with self.assertRaises(yarnhelper.Error): - h.get_variable('FOO') - print - print 'variables:', h._variables + self.assertEqual(h.get_variable('foo', default=42), 42) def test_sets_variable_persistently(self): h = yarnhelper.YarnHelper() -- cgit v1.2.1