src/pyams_ldap/plugin.py
changeset 16 59c957423fe8
parent 8 bbc24a6bd6f0
child 20 68b5251b9687
--- a/src/pyams_ldap/plugin.py	Mon Jan 18 18:02:35 2016 +0100
+++ b/src/pyams_ldap/plugin.py	Thu Jun 02 15:38:05 2016 +0200
@@ -21,7 +21,8 @@
 import re
 
 # import interfaces
-from pyams_ldap.interfaces import ILDAPPlugin
+from pyams_ldap.interfaces import ILDAPPlugin, ILDAPUserInfo, ILDAPGroupInfo
+from pyams_mail.interfaces import IPrincipalMailInfo
 from zope.intid.interfaces import IIntIds
 
 # import packages
@@ -29,6 +30,7 @@
 from persistent import Persistent
 from pyams_ldap.query import LDAPQuery
 from pyams_security.principal import PrincipalInfo
+from pyams_utils.adapter import adapter_config, ContextAdapter
 from pyams_utils.registry import query_utility
 from zope.container.contained import Contained
 from zope.interface import implementer
@@ -78,6 +80,87 @@
                                     auto_bind=True, lazy=False, read_only=True)
         return conn
 
+
+@implementer(ILDAPUserInfo)
+class LDAPUserInfo(object):
+    """LDAP user info"""
+
+    def __init__(self, dn, attributes, plugin=None):
+        self.dn = dn
+        self.attributes = attributes
+        self.plugin = plugin
+
+
+@adapter_config(context=ILDAPUserInfo, provides=IPrincipalMailInfo)
+class LDAPUserMailInfoAdapter(ContextAdapter):
+    """LDAP user mail adapter"""
+
+    def get_addresses(self):
+        """Get mail address of given user"""
+
+        user = self.context
+        plugin = user.plugin
+
+        mail = user.attributes.get(plugin.mail_attribute)
+        if mail:
+            return {(plugin.title_format.format(**user.attributes), mail[0])}
+        else:
+            return set()
+
+
+@implementer(ILDAPGroupInfo)
+class LDAPGroupInfo(object):
+    """LDAP group info"""
+
+    def __init__(self, dn, attributes, plugin=None):
+        self.dn = dn
+        self.attributes = attributes
+        self.plugin = plugin
+
+    def get_members(self, info=True):
+        return self.plugin.get_members(self, info=info)
+
+
+@adapter_config(context=ILDAPGroupInfo, provides=IPrincipalMailInfo)
+class LDAPGroupMailInfoAdapter(ContextAdapter):
+    """LDAP group mail adapter"""
+
+    def get_addresses(self):
+        """Get mail address of given group"""
+
+        group = self.context
+        plugin = group.plugin
+
+        if plugin.group_mail_mode == 'none':
+            # use members address
+            for member in plugin.get_members(group, info=False):
+                mail_info = IPrincipalMailInfo(member, None)
+                if mail_info is not None:
+                    for address in mail_info.get_addresses():
+                        yield address
+
+        elif plugin.group_mail_mode == 'internal':
+            # use group internal attribute
+            mail = group.attributes.get(plugin.group_mail_attribute)
+            if mail:
+                yield plugin.group_title_format(**group.attributes), mail[0]
+
+        else:
+            # redirect: use internal attribute of another group
+            source, target = plugin.group_replace_expression.split('|')
+            target_dn = group.dn.replace(source, target)
+            conn = plugin.get_connection()
+            attributes = FORMAT_ATTRIBUTES.findall(plugin.group_title_format) + [plugin.group_mail_attribute]
+            search = LDAPQuery(target_dn, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
+            result = search.execute(conn)
+            if not result or len(result) > 1:
+                raise StopIteration
+            target_dn, attrs = result[0]
+            mail = attrs.get(plugin.group_mail_attribute)
+            if mail:
+                yield plugin.group_title_format(**attrs), mail[0]
+
+
 @implementer(ILDAPPlugin)
 class LDAPPlugin(Persistent, Contained):
     """LDAP authentication plug-in"""
@@ -100,17 +183,28 @@
     pool_lifetime = FieldProperty(ILDAPPlugin['pool_lifetime'])
     base_dn = FieldProperty(ILDAPPlugin['base_dn'])
     search_scope = FieldProperty(ILDAPPlugin['search_scope'])
+
     login_attribute = FieldProperty(ILDAPPlugin['login_attribute'])
     login_query = FieldProperty(ILDAPPlugin['login_query'])
     uid_attribute = FieldProperty(ILDAPPlugin['uid_attribute'])
     uid_query = FieldProperty(ILDAPPlugin['uid_query'])
     title_format = FieldProperty(ILDAPPlugin['title_format'])
+    mail_attribute = FieldProperty(ILDAPPlugin['mail_attribute'])
+    user_extra_attributes = FieldProperty(ILDAPPlugin['user_extra_attributes'])
+
     groups_base_dn = FieldProperty(ILDAPPlugin['groups_base_dn'])
     groups_search_scope = FieldProperty(ILDAPPlugin['groups_search_scope'])
-    groups_query = FieldProperty(ILDAPPlugin['groups_query'])
     group_prefix = FieldProperty(ILDAPPlugin['group_prefix'])
     group_uid_attribute = FieldProperty(ILDAPPlugin['group_uid_attribute'])
     group_title_format = FieldProperty(ILDAPPlugin['group_title_format'])
+    group_members_query_mode = FieldProperty(ILDAPPlugin['group_members_query_mode'])
+    groups_query = FieldProperty(ILDAPPlugin['groups_query'])
+    group_members_attribute = FieldProperty(ILDAPPlugin['group_members_attribute'])
+    user_groups_attribute = FieldProperty(ILDAPPlugin['user_groups_attribute'])
+    group_mail_mode = FieldProperty(ILDAPPlugin['group_mail_mode'])
+    group_replace_expression = FieldProperty(ILDAPPlugin['group_replace_expression'])
+    group_mail_attribute = FieldProperty(ILDAPPlugin['group_mail_attribute'])
+    group_extra_attributes = FieldProperty(ILDAPPlugin['group_extra_attributes'])
 
     users_select_query = FieldProperty(ILDAPPlugin['users_select_query'])
     users_search_query = FieldProperty(ILDAPPlugin['users_search_query'])
@@ -213,7 +307,9 @@
         conn = self.get_connection()
         if login.startswith(self.group_prefix + ':'):
             group_prefix, group_id = login.split(':', 1)
-            attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format)
+            attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format) + [self.group_mail_attribute]
+            if self.group_extra_attributes:
+                attributes += self.group_extra_attributes.split(',')
             if self.group_uid_attribute == 'dn':
                 search = LDAPQuery(group_id, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
             else:
@@ -222,13 +318,21 @@
             if not result or len(result) > 1:
                 return None
             group_dn, attrs = result[0]
-            return PrincipalInfo(id='{prefix}:{group_prefix}:{group_id}'.format(prefix=self.prefix,
-                                                                                group_prefix=self.group_prefix,
-                                                                                group_id=group_id),
-                                 title=self.group_title_format.format(**attrs),
-                                 dn=group_dn)
+            if info:
+                return PrincipalInfo(id='{prefix}:{group_prefix}:{group_id}'.format(prefix=self.prefix,
+                                                                                    group_prefix=self.group_prefix,
+                                                                                    group_id=group_id),
+                                     title=self.group_title_format.format(**attrs),
+                                     dn=group_dn)
+            else:
+                attrs.update({'principal_id': '{prefix}:{group_prefix}:{group_id}'.format(prefix=self.prefix,
+                                                                                          group_prefix=self.group_prefix,
+                                                                                          group_id=group_id)})
+                return LDAPGroupInfo(group_dn, attrs, self)
         else:
-            attributes = FORMAT_ATTRIBUTES.findall(self.title_format)
+            attributes = FORMAT_ATTRIBUTES.findall(self.title_format) + [self.mail_attribute]
+            if self.user_extra_attributes:
+                attributes += self.user_extra_attributes.split(',')
             if self.uid_attribute == 'dn':
                 search = LDAPQuery(login, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
             else:
@@ -237,32 +341,61 @@
             if not result or len(result) > 1:
                 return None
             user_dn, attrs = result[0]
-            return PrincipalInfo(id='{prefix}:{login}'.format(prefix=self.prefix,
-                                                              login=login),
-                                 title=self.title_format.format(**attrs),
-                                 dn=user_dn)
+            if info:
+                return PrincipalInfo(id='{prefix}:{login}'.format(prefix=self.prefix,
+                                                                  login=login),
+                                     title=self.title_format.format(**attrs),
+                                     dn=user_dn)
+            else:
+                attrs.update({'principal_id': '{prefix}:{login}'.format(prefix=self.prefix, login=login)})
+                return LDAPUserInfo(user_dn, attrs, self)
 
     def _get_groups(self, principal):
-        if not self.groups_base_dn:
-            raise StopIteration
+        """Get principal groups"""
         principal_dn = principal.attributes.get('dn')
         if principal_dn is None:
             raise StopIteration
-        conn = self.get_connection()
-        attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format)
-        search = LDAPQuery(self.groups_base_dn, self.groups_query, self.groups_search_scope, attributes)
-        for group_dn, group_attrs in search.execute(conn, dn=principal_dn):
-            if self.group_uid_attribute == 'dn':
-                yield '{prefix}:{group_prefix}:{dn}'.format(prefix=self.prefix,
-                                                            group_prefix=self.group_prefix,
-                                                            dn=group_dn)
-            else:
-                yield '{prefix}:{group_prefix}:{attr}'.format(prefix=self.prefix,
-                                                              group_prefix=self.group_prefix,
-                                                              attr=group_attrs[self.group_uid_attribute])
+        if self.group_members_query_mode == 'group':
+            # group members are defined inside group
+            if not self.groups_base_dn:
+                raise StopIteration
+            conn = self.get_connection()
+            attributes = FORMAT_ATTRIBUTES.findall(self.group_title_format)
+            search = LDAPQuery(self.groups_base_dn, self.groups_query, self.groups_search_scope, attributes)
+            for group_dn, group_attrs in search.execute(conn, dn=principal_dn):
+                if self.group_uid_attribute == 'dn':
+                    yield '{prefix}:{group_prefix}:{dn}'.format(prefix=self.prefix,
+                                                                group_prefix=self.group_prefix,
+                                                                dn=group_dn)
+                else:
+                    yield '{prefix}:{group_prefix}:{attr}'.format(prefix=self.prefix,
+                                                                  group_prefix=self.group_prefix,
+                                                                  attr=group_attrs[self.group_uid_attribute])
+        else:
+            # a member defines it's groups
+            conn = self.get_connection()
+            attributes = [self.user_groups_attribute]
+            user_search = LDAPQuery(principal_dn, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
+            for user_dn, user_attrs in user_search.execute(conn):
+                if self.group_uid_attribute == 'dn':
+                    for group_dn in user_attrs.get(self.user_groups_attribute, ()):
+                        yield '{prefix}:{group_prefix}:{dn}'.format(prefix=self.prefix,
+                                                                    group_prefix=self.group_prefix,
+                                                                    dn=group_dn)
+                else:
+                    attributes = [self.group_uid_attribute]
+                    for group_dn in user_attrs.get(self.user_groups_attribute, ()):
+                        group_search = LDAPQuery(group_dn, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT,
+                                                 attributes)
+                        for group_search_dn, group_search_attrs in group_search.execute(conn):
+                            yield '{prefix}:{group_prefix}:{attr}'.format(prefix=self.prefix,
+                                                                          group_prefix=self.group_prefix,
+                                                                          attr=group_search_attrs[
+                                                                              self.group_uid_attribute])
 
     @cache_region('short')
     def get_all_principals(self, principal_id):
+        """Get all principals (including groups) for given principal ID"""
         if not self.enabled:
             return set()
         principal = self.get_principal(principal_id)
@@ -273,6 +406,47 @@
             return result
         return set()
 
+    def get_members(self, group, info=True):
+        """Get all members of given LDAP group as LDAP users"""
+        if not self.enabled:
+            return set()
+        conn = self.get_connection()
+        if self.group_members_query_mode == 'group':
+            # group members are defined into group attribute
+            attributes = [self.group_members_attribute]
+            user_attributes = FORMAT_ATTRIBUTES.findall(self.title_format) + [self.mail_attribute]
+            search = LDAPQuery(group.dn, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, attributes)
+            for group_dn, attrs in search.execute(conn):
+                for user_dn in attrs.get(self.group_members_attribute):
+                    user_search = LDAPQuery(user_dn, '(objectClass=*)', ldap3.SEARCH_SCOPE_BASE_OBJECT, user_attributes)
+                    for user_search_dn, user_search_attrs in user_search.execute(conn):
+                        if info:
+                            yield PrincipalInfo(id='{prefix}:{dn}'.format(prefix=self.prefix,
+                                                                          dn=user_search_dn),
+                                                title=self.title_format.format(**user_search_attrs),
+                                                dn=user_search_dn)
+                        else:
+                            yield LDAPUserInfo(dn=user_search_dn, attributes=user_search_attrs, plugin=self)
+        else:
+            # member groups are defined into member attribute
+            attributes = FORMAT_ATTRIBUTES.findall(self.title_format) + [self.uid_attribute, self.mail_attribute]
+            search = LDAPQuery(self.base_dn, '({attribute}={{group_dn}})'.format(attribute=self.user_groups_attribute),
+                               self.search_scope, attributes)
+            for user_dn, user_attrs in search.execute(conn, group_dn=group.dn):
+                if info:
+                    if self.uid_attribute == 'dn':
+                        yield PrincipalInfo(id='{prefix}:{dn}'.format(prefix=self.prefix,
+                                                                      dn=user_dn),
+                                            title=self.title_format.format(**user_attrs),
+                                            dn=user_dn)
+                    else:
+                        yield PrincipalInfo(id='{prefix}:{attr}'.format(prefix=self.prefix,
+                                                                        attr=user_attrs[self.uid_attribute][0]),
+                                            title=self.title_format.format(**user_attrs),
+                                            dn=user_dn)
+                else:
+                    yield LDAPUserInfo(dn=user_dn, attributes=user_attrs, plugin=self)
+
     def find_principals(self, query):
         if not self.enabled:
             raise StopIteration