summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLars Wirzenius <liw@liw.fi>2016-09-25 19:04:50 +0300
committerLars Wirzenius <liw@liw.fi>2016-09-25 19:04:50 +0300
commit7b94ed647e5e0454d120fc0502e661525f6a3f52 (patch)
treedd3ea69f2a506374897ff91bce6b853e1635d879
parent9f8dfa9d9b3cf3315bfab27274f4b11715d952c4 (diff)
downloadserver-yarns-7b94ed647e5e0454d120fc0502e661525f6a3f52.tar.gz
Implement persistent saving, loading of variables
-rw-r--r--yarnhelper.py28
-rw-r--r--yarnhelper_tests.py17
2 files changed, 44 insertions, 1 deletions
diff --git a/yarnhelper.py b/yarnhelper.py
index b5c45d0..ae0c30b 100644
--- a/yarnhelper.py
+++ b/yarnhelper.py
@@ -18,12 +18,18 @@
import os
+import yaml
+
+
+variables_filename = 'vars.yaml'
+
class YarnHelper(object):
def __init__(self):
self._env = dict(os.environ)
self._next_match = 1
+ self._variables = None # None means not loaded, otherwise dict
def set_environment(self, env):
self._env = dict(env)
@@ -36,7 +42,27 @@ class YarnHelper(object):
return self._env[name]
def get_variable(self, name):
- raise Error('no variable {}'.format(name))
+ if self._variables is None:
+ self._variables = self._load_variables()
+ if name not in self._variables:
+ raise Error('no variable {}'.format(name))
+ return self._variables[name]
+
+ def _load_variables(self):
+ if os.path.exists(variables_filename):
+ with open(variables_filename, 'r') as f:
+ return yaml.safe_load(f)
+ return {}
+
+ def set_variable(self, name, value):
+ if self._variables is None:
+ self._variables = {}
+ self._variables[name] = value
+ self._save_variables(self._variables)
+
+ def _save_variables(self, variables):
+ with open(variables_filename, 'w') as f:
+ yaml.safe_dump(variables, f)
class Error(Exception):
diff --git a/yarnhelper_tests.py b/yarnhelper_tests.py
index bdc0cd9..aca3432 100644
--- a/yarnhelper_tests.py
+++ b/yarnhelper_tests.py
@@ -16,6 +16,7 @@
# =*= License: GPL-3+ =*=
+import os
import unittest
import yarnhelper
@@ -57,8 +58,24 @@ class GetNextMatchTests(unittest.TestCase):
class PersistentVariableTests(unittest.TestCase):
+ def setUp(self):
+ # We need this so that tearDown works
+ pass
+
+ def tearDown(self):
+ if os.path.exists(yarnhelper.variables_filename):
+ os.remove(yarnhelper.variables_filename)
+
def test_raises_error_if_no_such_variable(self):
h = yarnhelper.YarnHelper()
with self.assertRaises(yarnhelper.Error):
h.get_variable('FOO')
+ print
+ print 'variables:', h._variables
+
+ def test_sets_variable_persistently(self):
+ h = yarnhelper.YarnHelper()
+ h.set_variable('FOO', 'bar')
+ h2 = yarnhelper.YarnHelper()
+ self.assertEqual(h2.get_variable('FOO'), 'bar')