src/pyams_ldap/plugin.py
changeset 0 94ee60dd51e1
child 2 68423cd701bb
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/pyams_ldap/plugin.py	Sat Feb 28 15:20:14 2015 +0100
@@ -0,0 +1,330 @@
+#
+# Copyright (c) 2008-2015 Thierry Florac <tflorac AT ulthar.net>
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+
+__docformat__ = 'restructuredtext'
+
+
+# import standard library
+import ldap3
+import logging
+logger = logging.getLogger('pyams_ldap')
+
+import re
+
+# import interfaces
+from pyams_ldap.interfaces import ILDAPPlugin
+from zope.intid.interfaces import IIntIds
+
+# import packages
+from beaker.cache import cache_region
+from persistent import Persistent
+from pyams_ldap.query import LDAPQuery
+from pyams_security.principal import PrincipalInfo
+from pyams_utils.registry import query_utility
+from zope.container.contained import Contained
+from zope.interface import implementer
+from zope.schema.fieldproperty import FieldProperty
+
+
+managers = {}
+
+
+FORMAT_ATTRIBUTES = re.compile("\{(\w+)\[?\d*\]?\}")
+
+
+class ConnectionManager(object):
+    """LDAP connections manager"""
+
+    def __init__(self, plugin):
+        self.server = ldap3.Server(plugin.host,
+                                   port=plugin.port,
+                                   use_ssl=plugin.use_ssl,
+                                   tls=plugin.use_tls)
+        self.bind_dn = plugin.bind_dn
+        self.password = plugin.bind_password
+        if plugin.use_pool:
+            self.strategy = ldap3.STRATEGY_REUSABLE_THREADED
+            self.pool_name = 'pyams_ldap:{prefix}'.format(prefix=plugin.prefix)
+            self.pool_size = plugin.pool_size
+            self.pool_lifetime = plugin.pool_lifetime
+        else:
+            self.strategy = ldap3.STRATEGY_ASYNC_THREADED
+            self.pool_name = None
+            self.pool_size = None
+            self.pool_lifetime = None
+
+    def get_connection(self, user=None, password=None):
+        if user:
+            conn = ldap3.Connection(self.server,
+                                    user=user, password=password,
+                                    client_strategy=ldap3.STRATEGY_SYNC,
+                                    auto_bind=True, lazy=False, read_only=True)
+        else:
+            conn = ldap3.Connection(self.server,
+                                    user=self.bind_dn, password=self.password,
+                                    client_strategy=self.strategy,
+                                    pool_name=self.pool_name,
+                                    pool_size=self.pool_size,
+                                    pool_lifetime=self.pool_lifetime,
+                                    auto_bind=True, lazy=False, read_only=True)
+        return conn
+
+@implementer(ILDAPPlugin)
+class LDAPPlugin(Persistent, Contained):
+    """LDAP authentication plug-in"""
+
+    prefix = FieldProperty(ILDAPPlugin['prefix'])
+    title = FieldProperty(ILDAPPlugin['title'])
+    enabled = FieldProperty(ILDAPPlugin['enabled'])
+
+    _scheme = None
+    _host = None
+    _port = None
+    _use_ssl = False
+
+    _server_uri = FieldProperty(ILDAPPlugin['server_uri'])
+    bind_dn = FieldProperty(ILDAPPlugin['bind_dn'])
+    bind_password = FieldProperty(ILDAPPlugin['bind_password'])
+    use_tls = FieldProperty(ILDAPPlugin['use_tls'])
+    use_pool = FieldProperty(ILDAPPlugin['use_pool'])
+    pool_size = FieldProperty(ILDAPPlugin['pool_size'])
+    pool_lifetime = FieldProperty(ILDAPPlugin['pool_lifetime'])
+    base_dn = FieldProperty(ILDAPPlugin['base_dn'])
+    search_scope = FieldProperty(ILDAPPlugin['search_scope'])
+    login_attribute = FieldProperty(ILDAPPlugin['login_attribute'])
+    login_query = FieldProperty(ILDAPPlugin['login_query'])
+    uid_attribute = FieldProperty(ILDAPPlugin['uid_attribute'])
+    uid_query = FieldProperty(ILDAPPlugin['uid_query'])
+    title_format = FieldProperty(ILDAPPlugin['title_format'])
+    groups_base_dn = FieldProperty(ILDAPPlugin['groups_base_dn'])
+    groups_search_scope = FieldProperty(ILDAPPlugin['groups_search_scope'])
+    groups_query = FieldProperty(ILDAPPlugin['groups_query'])
+    group_prefix = FieldProperty(ILDAPPlugin['group_prefix'])
+    group_uid_attribute = FieldProperty(ILDAPPlugin['group_uid_attribute'])
+    group_title_format = FieldProperty(ILDAPPlugin['group_title_format'])
+
+    users_select_query = FieldProperty(ILDAPPlugin['users_select_query'])
+    users_search_query = FieldProperty(ILDAPPlugin['users_search_query'])
+    groups_select_query = FieldProperty(ILDAPPlugin['groups_select_query'])
+    groups_search_query = FieldProperty(ILDAPPlugin['groups_search_query'])
+
+    @property
+    def server_uri(self):
+        return self._server_uri
+
+    @server_uri.setter
+    def server_uri(self, value):
+        self._server_uri = value
+        try:
+            scheme, host = value.split('://', 1)
+        except ValueError:
+            scheme = 'ldap'
+            host = value
+        self._use_ssl = scheme == 'ldaps'
+        self._scheme = scheme
+        try:
+            host, port = host.split(':', 1)
+            port = int(port)
+        except ValueError:
+            port = 636 if self._use_ssl else 389
+        self._host = host
+        self._port = port
+
+    @property
+    def scheme(self):
+        return self._scheme
+
+    @property
+    def host(self):
+        return self._host
+
+    @property
+    def port(self):
+        return self._port
+
+    @property
+    def use_ssl(self):
+        return self._use_ssl
+
+    def _get_id(self):
+        intids = query_utility(IIntIds)
+        return intids.register(self)
+
+    def clear(self):
+        self_id = self._get_id()
+        if self_id in managers:
+            del managers[self_id]
+
+    def get_connection(self, user=None, password=None):
+        self_id = self._get_id()
+        if self_id not in managers:
+            managers[self_id] = ConnectionManager(self)
+        return managers[self_id].get_connection(user, password)
+
+    def authenticate(self, credentials, request):
+        if not self.enabled:
+            return None
+        attrs = credentials.attributes
+        login = attrs.get('login')
+        password = attrs.get('password')
+        conn = self.get_connection()
+        search = LDAPQuery(self.base_dn, self.login_query, self.search_scope, (self.login_attribute,
+                                                                               self.uid_attribute))
+        result = search.execute(conn, login=login, password=password)
+        if not result or len(result) > 1:
+            return None
+        result = result[0]
+        login_dn = result[0]
+        try:
+            login_conn = self.get_connection(user=login_dn, password=password)
+            login_conn.unbind()
+        except ldap3.LDAPException:
+            logger.debug("LDAP authentication exception with login %r", login, exc_info=True)
+            return None
+        else:
+            if self.uid_attribute == 'dn':
+                return "{prefix}:{dn}".format(prefix=self.prefix,
+                                              dn=login_dn)
+            else:
+                attrs = result[1]
+                if self.login_attribute in attrs:
+                    return "{prefix}:{attr}".format(prefix=self.prefix,
+                                                    attr=attrs[self.uid_attribute][0])
+
+    def _get_group(self, group_id):
+        if not self.enabled:
+            return None
+
+    def get_principal(self, principal_id):
+        if not self.enabled:
+            return None
+        if not principal_id.startswith(self.prefix + ':'):
+            return None
+        prefix, login = principal_id.split(':', 1)
+        conn = self.get_connection()
+        if login.startswith(self.group_prefix + ':'):
+            group_prefix, group_id = login.split(':', 1)
+            attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format)
+            if self.group_uid_attribute == 'dn':
+                search = LDAPQuery(group_id, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
+            else:
+                search = LDAPQuery(self.base_dn, self.uid_query, self.search_scope, attributes)
+            result = search.execute(conn, login=group_id)
+            if not result or len(result) > 1:
+                return None
+            group_dn, attrs = result[0]
+            return PrincipalInfo(id='{prefix}:{group_prefix}:{group_id}'.format(prefix=self.prefix,
+                                                                                group_prefix=self.group_prefix,
+                                                                                group_id=group_id),
+                                 title=self.group_title_format.format(**attrs),
+                                 dn=group_dn)
+        else:
+            attributes = FORMAT_ATTRIBUTES.findall(self.title_format)
+            if self.uid_attribute == 'dn':
+                search = LDAPQuery(login, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
+            else:
+                search = LDAPQuery(self.base_dn, self.uid_query, self.search_scope, attributes)
+            result = search.execute(conn, login=login)
+            if not result or len(result) > 1:
+                return None
+            user_dn, attrs = result[0]
+            return PrincipalInfo(id='{prefix}:{login}'.format(prefix=self.prefix,
+                                                              login=login),
+                                 title=self.title_format.format(**attrs),
+                                 dn=user_dn)
+
+    def _get_groups(self, principal):
+        if not self.groups_base_dn:
+            raise StopIteration
+        principal_dn = principal.attributes.get('dn')
+        if principal_dn is None:
+            raise StopIteration
+        conn = self.get_connection()
+        attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format)
+        search = LDAPQuery(self.groups_base_dn, self.groups_query, self.groups_search_scope, attributes)
+        for group_dn, group_attrs in search.execute(conn, dn=principal_dn):
+            if self.group_uid_attribute == 'dn':
+                yield '{prefix}:{group_prefix}:{dn}'.format(prefix=self.prefix,
+                                                            group_prefix=self.group_prefix,
+                                                            dn=group_dn)
+            else:
+                yield '{prefix}:{group_prefix}:{attr}'.format(prefix=self.prefix,
+                                                              group_prefix=self.group_prefix,
+                                                              attr=group_attrs[self.group_uid_attribute])
+
+    @cache_region('short')
+    def get_all_principals(self, principal_id):
+        if not self.enabled:
+            return set()
+        principal = self.get_principal(principal_id)
+        if principal is not None:
+            result = {principal_id}
+            if self.groups_query:
+                result |= set(self._get_groups(principal))
+            return result
+        return set()
+
+    def find_principals(self, query):
+        if not self.enabled:
+            raise StopIteration
+        if not query:
+            return None
+        conn = self.get_connection()
+        # users search
+        attributes = FORMAT_ATTRIBUTES.findall(self.title_format) + [self.uid_attribute, ]
+        search = LDAPQuery(self.base_dn, self.users_select_query, self.search_scope, attributes)
+        for user_dn, user_attrs in search.execute(conn, query=query):
+            if self.uid_attribute == 'dn':
+                yield PrincipalInfo(id='{prefix}:{dn}'.format(prefix=self.prefix,
+                                                              dn=user_dn),
+                                    title=self.title_format.format(**user_attrs),
+                                    dn=user_dn)
+            else:
+                yield PrincipalInfo(id='{prefix}:{attr}'.format(prefix=self.prefix,
+                                                                attr=user_attrs[self.uid_attribute][0]),
+                                    title=self.title_format.format(**user_attrs),
+                                    dn=user_dn)
+        # groups search
+        if self.groups_base_dn:
+            attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format) + [self.group_uid_attribute, ]
+            search = LDAPQuery(self.groups_base_dn, self.groups_select_query, self.groups_search_scope, attributes)
+            for group_dn, group_attrs in search.execute(conn, query=query):
+                if self.group_uid_attribute == 'dn':
+                    yield PrincipalInfo(id='{prefix}:{group_prefix}:{dn}'.format(prefix=self.prefix,
+                                                                                 group_prefix=self.group_prefix,
+                                                                                 dn=group_dn),
+                                        title=self.group_title_format.format(**group_attrs),
+                                        dn=group_dn)
+                else:
+                    yield PrincipalInfo(id='{prefix}:{group_prefix}:{attr}'.format(prefix=self.prefix,
+                                                                                   group_prefix=self.group_prefix,
+                                                                                   attr=group_attrs[
+                                                                                       self.group_uid_attribute][0]),
+                                        title=self.group_title_format.format(**group_attrs),
+                                        dn=group_dn)
+
+    def get_search_results(self, data):
+        # LDAP search results are made of tuples containing DN and all
+        # entries attributes
+        query = data.get('query')
+        if not query:
+            return ()
+        conn = self.get_connection()
+        # users search
+        search = LDAPQuery(self.base_dn, self.users_search_query, self.search_scope, ldap3.ALL_ATTRIBUTES)
+        for user_dn, user_attrs in search.execute(conn, query=query):
+            yield user_dn, user_attrs
+        # groups search
+        if self.groups_base_dn:
+            search = LDAPQuery(self.groups_base_dn, self.groups_search_query, self.groups_search_scope, ldap3.ALL_ATTRIBUTES)
+            for group_dn, group_attrs in search.execute(conn, query=query):
+                yield group_dn, group_attrs