summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLars Wirzenius <liw@liw.fi>2017-03-18 19:42:28 +0200
committerLars Wirzenius <liw@liw.fi>2017-03-18 19:42:28 +0200
commit8f2cf6cd84f54a1db8bdd51c93253cd8e41d6d9d (patch)
treeb1e626ed3f1369f90ed93ecd88b417e3bff06d38
parentfc757194e1c4e2596678ba8d3b5269280fa12b40 (diff)
downloadserver-yarns-8f2cf6cd84f54a1db8bdd51c93253cd8e41d6d9d.tar.gz
Add default value to .get; fix vars.yaml location
-rw-r--r--yarnhelper.py20
-rw-r--r--yarnhelper_tests.py7
2 files changed, 14 insertions, 13 deletions
diff --git a/yarnhelper.py b/yarnhelper.py
index 8f9da19..4ed2c03 100644
--- a/yarnhelper.py
+++ b/yarnhelper.py
@@ -26,6 +26,7 @@ import requests
import yaml
+datadir = os.environ.get('DATADIR', '.')
variables_filename = os.environ.get('VARIABLES', 'vars.yaml')
@@ -34,6 +35,7 @@ class YarnHelper(object):
def __init__(self):
self._env = dict(os.environ)
self._next_match = 1
+ self._filename = os.path.join(datadir, variables_filename)
self._variables = None # None means not loaded, otherwise dict
def set_environment(self, env):
@@ -46,17 +48,19 @@ class YarnHelper(object):
self._next_match += 1
return self._env[name]
- def get_variable(self, name):
+ def get_variable(self, name, default=None):
if self._variables is None:
self._variables = self._load_variables()
- if name not in self._variables:
- raise Error('no variable {}'.format(name))
- return self._variables[name]
+ assert self._variables is not None
+ return self._variables.get(name, default)
def _load_variables(self):
- if os.path.exists(variables_filename):
- with open(variables_filename, 'r') as f:
- return yaml.safe_load(f)
+ if os.path.exists(self._filename):
+ with open(self._filename, 'r') as f:
+ data = f.read()
+ if data:
+ f.seek(0)
+ return yaml.safe_load(f)
return {}
def set_variable(self, name, value):
@@ -66,7 +70,7 @@ class YarnHelper(object):
self._save_variables(self._variables)
def _save_variables(self, variables):
- with open(variables_filename, 'w') as f:
+ with open(self._filename, 'w') as f:
yaml.safe_dump(variables, f)
def construct_aliased_http_request(
diff --git a/yarnhelper_tests.py b/yarnhelper_tests.py
index 756a165..858bcdc 100644
--- a/yarnhelper_tests.py
+++ b/yarnhelper_tests.py
@@ -66,12 +66,9 @@ class PersistentVariableTests(unittest.TestCase):
if os.path.exists(yarnhelper.variables_filename):
os.remove(yarnhelper.variables_filename)
- def test_raises_error_if_no_such_variable(self):
+ def test_returns_default_if_no_such_variable(self):
h = yarnhelper.YarnHelper()
- with self.assertRaises(yarnhelper.Error):
- h.get_variable('FOO')
- print
- print 'variables:', h._variables
+ self.assertEqual(h.get_variable('foo', default=42), 42)
def test_sets_variable_persistently(self):
h = yarnhelper.YarnHelper()