src/pyams_zmq/process.py
changeset 13 839b61e1531a
parent 4 d624312bfc2b
child 18 cd5a88ba2223
--- a/src/pyams_zmq/process.py	Sat Jan 27 00:40:59 2018 +0100
+++ b/src/pyams_zmq/process.py	Mon Mar 05 12:27:53 2018 +0100
@@ -23,6 +23,7 @@
 from pyams_zmq.interfaces import IZMQProcess
 
 # import packages
+from zmq.auth.thread import ThreadAuthenticator
 from zmq.eventloop import ioloop, zmqstream
 from zope.interface import implementer
 
@@ -35,8 +36,9 @@
     """
 
     socket_type = zmq.REP
+    auth_thread = None
 
-    def __init__(self, bind_addr, handler):
+    def __init__(self, bind_addr, handler, auth=None, clients=None):
         super(ZMQProcess, self).__init__()
 
         self.context = None
@@ -48,10 +50,18 @@
         self.bind_addr = bind_addr
         self.rep_stream = None
         self.handler = handler
+        self.passwords = dict([auth.split(':', 1)]) if auth else None
+        self.clients = clients.split() if clients else None
 
     def setup(self):
         """Creates a :attr:`context` and an event :attr:`loop` for the process."""
-        self.context = zmq.Context()
+        ctx = self.context = zmq.Context()
+        auth = self.auth_thread = ThreadAuthenticator(ctx)
+        auth.start()
+        if self.clients:
+            auth.allow(*self.clients)
+        if self.passwords:
+            auth.configure_plain(domain='*', passwords=self.passwords)
         self.loop = ioloop.IOLoop.instance()
         self.rep_stream, _ = self.stream(self.socket_type, self.bind_addr, bind=True)
         self.initStream()
@@ -71,6 +81,7 @@
         if self.loop is not None:
             self.loop.stop()
             self.loop = None
+        self.auth_thread.stop()
 
     def exit(self, num, frame):
         self.stop()
@@ -108,6 +119,10 @@
         """
         sock = self.context.socket(sock_type)
 
+        # add server authenticator
+        if self.passwords:
+            sock.plain_server = True
+
         # addr may be 'host:port' or ('host', port)
         if isinstance(addr, str):
             addr = addr.split(':')