diff options
author | Lars Wirzenius <liw@liw.fi> | 2016-09-25 19:04:50 +0300 |
---|---|---|
committer | Lars Wirzenius <liw@liw.fi> | 2016-09-25 19:04:50 +0300 |
commit | 7b94ed647e5e0454d120fc0502e661525f6a3f52 (patch) | |
tree | dd3ea69f2a506374897ff91bce6b853e1635d879 | |
parent | 9f8dfa9d9b3cf3315bfab27274f4b11715d952c4 (diff) | |
download | server-yarns-7b94ed647e5e0454d120fc0502e661525f6a3f52.tar.gz |
Implement persistent saving, loading of variables
-rw-r--r-- | yarnhelper.py | 28 | ||||
-rw-r--r-- | yarnhelper_tests.py | 17 |
2 files changed, 44 insertions, 1 deletions
diff --git a/yarnhelper.py b/yarnhelper.py index b5c45d0..ae0c30b 100644 --- a/yarnhelper.py +++ b/yarnhelper.py @@ -18,12 +18,18 @@ import os +import yaml + + +variables_filename = 'vars.yaml' + class YarnHelper(object): def __init__(self): self._env = dict(os.environ) self._next_match = 1 + self._variables = None # None means not loaded, otherwise dict def set_environment(self, env): self._env = dict(env) @@ -36,7 +42,27 @@ class YarnHelper(object): return self._env[name] def get_variable(self, name): - raise Error('no variable {}'.format(name)) + if self._variables is None: + self._variables = self._load_variables() + if name not in self._variables: + raise Error('no variable {}'.format(name)) + return self._variables[name] + + def _load_variables(self): + if os.path.exists(variables_filename): + with open(variables_filename, 'r') as f: + return yaml.safe_load(f) + return {} + + def set_variable(self, name, value): + if self._variables is None: + self._variables = {} + self._variables[name] = value + self._save_variables(self._variables) + + def _save_variables(self, variables): + with open(variables_filename, 'w') as f: + yaml.safe_dump(variables, f) class Error(Exception): diff --git a/yarnhelper_tests.py b/yarnhelper_tests.py index bdc0cd9..aca3432 100644 --- a/yarnhelper_tests.py +++ b/yarnhelper_tests.py @@ -16,6 +16,7 @@ # =*= License: GPL-3+ =*= +import os import unittest import yarnhelper @@ -57,8 +58,24 @@ class GetNextMatchTests(unittest.TestCase): class PersistentVariableTests(unittest.TestCase): + def setUp(self): + # We need this so that tearDown works + pass + + def tearDown(self): + if os.path.exists(yarnhelper.variables_filename): + os.remove(yarnhelper.variables_filename) + def test_raises_error_if_no_such_variable(self): h = yarnhelper.YarnHelper() with self.assertRaises(yarnhelper.Error): h.get_variable('FOO') + print + print 'variables:', h._variables + + def test_sets_variable_persistently(self): + h = yarnhelper.YarnHelper() + h.set_variable('FOO', 'bar') + h2 = yarnhelper.YarnHelper() + self.assertEqual(h2.get_variable('FOO'), 'bar') |