diff py/params.py @ 293:d15dda1b1f76

merge
author Matt Johnston <matt@ucc.asn.au>
date Sat, 06 Jul 2019 18:29:45 +0800
parents ef3a75128116
children
line wrap: on
line diff
--- a/py/params.py	Thu Mar 19 21:50:52 2015 +0800
+++ b/py/params.py	Sat Jul 06 18:29:45 2019 +0800
@@ -2,12 +2,13 @@
 import collections
 import json
 import signal
-import StringIO
-
-import gevent
+import tempfile
+import os
+import binascii
 
 import config
 from utils import W,L,E,EX
+import utils
 
 _FIELD_DEFAULTS = {
     'fridge_setpoint': 16,
@@ -26,6 +27,7 @@
 
     def __init__(self):
         self.update(_FIELD_DEFAULTS)
+        self._set_epoch(None)
 
     def __getattr__(self, k):
         return self[k]
@@ -35,16 +37,14 @@
         self[k]
         self[k] = v
 
-    def load(self, f = None):
-        if not f:
-            try:
-                f = file(config.PARAMS_FILE, 'r')
-            except IOError, e:
-                W("Missing parameter file, using defaults. %s", e)
-                return
+    def _set_epoch(self, epoch):
+        # since __setattr__ is overridden
+        object.__setattr__(self, '_epoch', epoch)
+
+    def _do_load(self, f):
         try:
-            u = json.load(f)
-        except Exception, e:
+            u = utils.json_load_round_float(f.read())
+        except Exception as e:
             raise self.Error(e)
 
         for k in u:
@@ -53,19 +53,77 @@
             if k not in self:
                 raise self.Error("Unknown parameter %s=%s in file '%s'" % (str(k), str(u[k]), getattr(f, 'name', '???')))
         self.update(u)
+        # new epoch, 120 random bits
+        self._set_epoch(binascii.hexlify(os.urandom(15)).decode())
 
         L("Loaded parameters")
         L(self.save_string())
 
+    def load(self, f = None):
+        if f:
+            return self._do_load(f)
+        else:
+            with open(config.PARAMS_FILE, 'r') as f:
+                try:
+                    return self._do_load(f)
+                except IOError as e:
+                    W("Missing parameter file, using defaults. %s" % str(e))
+                    return
 
-    def save(self, f = None):
-        if not f:
-            f = file(config.PARAMS_FILE, 'w')
-        json.dump(self, f, sort_keys=True, indent=4)
-        f.write('\n')
-        f.flush()
+    def get_epoch(self):
+        return self._epoch
+
+    def receive(self, params, epoch):
+        """ updates parameters from the server. does some validation,
+        writes config file to disk.
+        Returns True on success, False failure 
+        """
+
+        if epoch != self._epoch:
+            return
+
+        def same_type(a, b):
+            ta = type(a)
+            tb = type(b)
+
+            if ta == int:
+                ta = float
+            if tb == int:
+                tb = float
+
+            return ta == tb
+
+        if self.keys() != params.keys():
+            diff = self.keys() ^ params.keys()
+            E("Mismatching params, %s" % str(diff))
+            return False
+
+        for k, v in params.items():
+            if not same_type(v, self[k]):
+                E("Bad type for %s" % k)
+                return False
+
+        dir = os.path.dirname(config.PARAMS_FILE)
+        try:
+            t = tempfile.NamedTemporaryFile(prefix='config',
+                mode='w+t', # NamedTemporaryFile is binary by default
+                dir = dir,
+                delete = False)
+
+            out = json.dumps(params, sort_keys=True, indent=4)+'\n'
+            t.write(out)
+            name = t.name
+            t.close()
+
+            os.rename(name, config.PARAMS_FILE)
+        except Exception as e:
+            EX("Problem: %s" % e)
+            return False
+
+        self.update(params)
+        L("Received parameters")
+        L(self.save_string())
+        return True
 
     def save_string(self):
-        s = StringIO.StringIO()
-        self.save(s)
-        return s.getvalue()
+        return json.dumps(self, sort_keys=True, indent=4)