aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/identity/backends
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/identity/backends')
-rw-r--r--keystone-moon/keystone/identity/backends/__init__.py0
-rw-r--r--keystone-moon/keystone/identity/backends/ldap.py425
-rw-r--r--keystone-moon/keystone/identity/backends/sql.py402
3 files changed, 0 insertions, 827 deletions
diff --git a/keystone-moon/keystone/identity/backends/__init__.py b/keystone-moon/keystone/identity/backends/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/keystone-moon/keystone/identity/backends/__init__.py
+++ /dev/null
diff --git a/keystone-moon/keystone/identity/backends/ldap.py b/keystone-moon/keystone/identity/backends/ldap.py
deleted file mode 100644
index fe8e8477..00000000
--- a/keystone-moon/keystone/identity/backends/ldap.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# Copyright 2012 OpenStack Foundation
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-from __future__ import absolute_import
-import uuid
-
-import ldap.filter
-from oslo_config import cfg
-from oslo_log import log
-from oslo_log import versionutils
-import six
-
-from keystone.common import clean
-from keystone.common import driver_hints
-from keystone.common import ldap as common_ldap
-from keystone.common import models
-from keystone import exception
-from keystone.i18n import _
-from keystone import identity
-
-
-CONF = cfg.CONF
-LOG = log.getLogger(__name__)
-
-_DEPRECATION_MSG = _('%s for the LDAP identity backend has been deprecated in '
- 'the Mitaka release in favor of read-only identity LDAP '
- 'access. It will be removed in the "O" release.')
-
-
-class Identity(identity.IdentityDriverV8):
- def __init__(self, conf=None):
- super(Identity, self).__init__()
- if conf is None:
- self.conf = CONF
- else:
- self.conf = conf
- self.user = UserApi(self.conf)
- self.group = GroupApi(self.conf)
-
- def is_domain_aware(self):
- return False
-
- def generates_uuids(self):
- return False
-
- # Identity interface
-
- def authenticate(self, user_id, password):
- try:
- user_ref = self._get_user(user_id)
- except exception.UserNotFound:
- raise AssertionError(_('Invalid user / password'))
- if not user_id or not password:
- raise AssertionError(_('Invalid user / password'))
- conn = None
- try:
- conn = self.user.get_connection(user_ref['dn'],
- password, end_user_auth=True)
- if not conn:
- raise AssertionError(_('Invalid user / password'))
- except Exception:
- raise AssertionError(_('Invalid user / password'))
- finally:
- if conn:
- conn.unbind_s()
- return self.user.filter_attributes(user_ref)
-
- def _get_user(self, user_id):
- return self.user.get(user_id)
-
- def get_user(self, user_id):
- return self.user.get_filtered(user_id)
-
- def list_users(self, hints):
- return self.user.get_all_filtered(hints)
-
- def get_user_by_name(self, user_name, domain_id):
- # domain_id will already have been handled in the Manager layer,
- # parameter left in so this matches the Driver specification
- return self.user.filter_attributes(self.user.get_by_name(user_name))
-
- # CRUD
- def create_user(self, user_id, user):
- msg = _DEPRECATION_MSG % "create_user"
- versionutils.report_deprecated_feature(LOG, msg)
- self.user.check_allow_create()
- user_ref = self.user.create(user)
- return self.user.filter_attributes(user_ref)
-
- def update_user(self, user_id, user):
- msg = _DEPRECATION_MSG % "update_user"
- versionutils.report_deprecated_feature(LOG, msg)
- self.user.check_allow_update()
- old_obj = self.user.get(user_id)
- if 'name' in user and old_obj.get('name') != user['name']:
- raise exception.Conflict(_('Cannot change user name'))
-
- if self.user.enabled_mask:
- self.user.mask_enabled_attribute(user)
- elif self.user.enabled_invert and not self.user.enabled_emulation:
- # We need to invert the enabled value for the old model object
- # to prevent the LDAP update code from thinking that the enabled
- # values are already equal.
- user['enabled'] = not user['enabled']
- old_obj['enabled'] = not old_obj['enabled']
-
- self.user.update(user_id, user, old_obj)
- return self.user.get_filtered(user_id)
-
- def delete_user(self, user_id):
- msg = _DEPRECATION_MSG % "delete_user"
- versionutils.report_deprecated_feature(LOG, msg)
- self.user.check_allow_delete()
- user = self.user.get(user_id)
- user_dn = user['dn']
- groups = self.group.list_user_groups(user_dn)
- for group in groups:
- self.group.remove_user(user_dn, group['id'], user_id)
-
- if hasattr(user, 'tenant_id'):
- self.project.remove_user(user.tenant_id, user_dn)
- self.user.delete(user_id)
-
- def create_group(self, group_id, group):
- msg = _DEPRECATION_MSG % "create_group"
- versionutils.report_deprecated_feature(LOG, msg)
- self.group.check_allow_create()
- group['name'] = clean.group_name(group['name'])
- return common_ldap.filter_entity(self.group.create(group))
-
- def get_group(self, group_id):
- return self.group.get_filtered(group_id)
-
- def get_group_by_name(self, group_name, domain_id):
- # domain_id will already have been handled in the Manager layer,
- # parameter left in so this matches the Driver specification
- return self.group.get_filtered_by_name(group_name)
-
- def update_group(self, group_id, group):
- msg = _DEPRECATION_MSG % "update_group"
- versionutils.report_deprecated_feature(LOG, msg)
- self.group.check_allow_update()
- if 'name' in group:
- group['name'] = clean.group_name(group['name'])
- return common_ldap.filter_entity(self.group.update(group_id, group))
-
- def delete_group(self, group_id):
- msg = _DEPRECATION_MSG % "delete_group"
- versionutils.report_deprecated_feature(LOG, msg)
- self.group.check_allow_delete()
- return self.group.delete(group_id)
-
- def add_user_to_group(self, user_id, group_id):
- msg = _DEPRECATION_MSG % "add_user_to_group"
- versionutils.report_deprecated_feature(LOG, msg)
- user_ref = self._get_user(user_id)
- user_dn = user_ref['dn']
- self.group.add_user(user_dn, group_id, user_id)
-
- def remove_user_from_group(self, user_id, group_id):
- msg = _DEPRECATION_MSG % "remove_user_from_group"
- versionutils.report_deprecated_feature(LOG, msg)
- user_ref = self._get_user(user_id)
- user_dn = user_ref['dn']
- self.group.remove_user(user_dn, group_id, user_id)
-
- def list_groups_for_user(self, user_id, hints):
- user_ref = self._get_user(user_id)
- if self.conf.ldap.group_members_are_ids:
- user_dn = user_ref['id']
- else:
- user_dn = user_ref['dn']
- return self.group.list_user_groups_filtered(user_dn, hints)
-
- def list_groups(self, hints):
- return self.group.get_all_filtered(hints)
-
- def list_users_in_group(self, group_id, hints):
- users = []
- for user_key in self.group.list_group_users(group_id):
- if self.conf.ldap.group_members_are_ids:
- user_id = user_key
- else:
- user_id = self.user._dn_to_id(user_key)
-
- try:
- users.append(self.user.get_filtered(user_id))
- except exception.UserNotFound:
- LOG.debug(("Group member '%(user_key)s' not found in"
- " '%(group_id)s'. The user should be removed"
- " from the group. The user will be ignored."),
- dict(user_key=user_key, group_id=group_id))
- return users
-
- def check_user_in_group(self, user_id, group_id):
- user_refs = self.list_users_in_group(group_id, driver_hints.Hints())
- for x in user_refs:
- if x['id'] == user_id:
- break
- else:
- # Try to fetch the user to see if it even exists. This
- # will raise a more accurate exception.
- self.get_user(user_id)
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
-
-
-# TODO(termie): turn this into a data object and move logic to driver
-class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
- DEFAULT_OU = 'ou=Users'
- DEFAULT_STRUCTURAL_CLASSES = ['person']
- DEFAULT_ID_ATTR = 'cn'
- DEFAULT_OBJECTCLASS = 'inetOrgPerson'
- NotFound = exception.UserNotFound
- options_name = 'user'
- attribute_options_names = {'password': 'pass',
- 'email': 'mail',
- 'name': 'name',
- 'description': 'description',
- 'enabled': 'enabled',
- 'default_project_id': 'default_project_id'}
- immutable_attrs = ['id']
-
- model = models.User
-
- def __init__(self, conf):
- super(UserApi, self).__init__(conf)
- self.enabled_mask = conf.ldap.user_enabled_mask
- self.enabled_default = conf.ldap.user_enabled_default
- self.enabled_invert = conf.ldap.user_enabled_invert
- self.enabled_emulation = conf.ldap.user_enabled_emulation
-
- def _ldap_res_to_model(self, res):
- obj = super(UserApi, self)._ldap_res_to_model(res)
- if self.enabled_mask != 0:
- enabled = int(obj.get('enabled', self.enabled_default))
- obj['enabled'] = ((enabled & self.enabled_mask) !=
- self.enabled_mask)
- elif self.enabled_invert and not self.enabled_emulation:
- # This could be a bool or a string. If it's a string,
- # we need to convert it so we can invert it properly.
- enabled = obj.get('enabled', self.enabled_default)
- if isinstance(enabled, six.string_types):
- if enabled.lower() == 'true':
- enabled = True
- else:
- enabled = False
- obj['enabled'] = not enabled
- obj['dn'] = res[0]
-
- return obj
-
- def mask_enabled_attribute(self, values):
- value = values['enabled']
- values.setdefault('enabled_nomask', int(self.enabled_default))
- if value != ((values['enabled_nomask'] & self.enabled_mask) !=
- self.enabled_mask):
- values['enabled_nomask'] ^= self.enabled_mask
- values['enabled'] = values['enabled_nomask']
- del values['enabled_nomask']
-
- def create(self, values):
- if self.enabled_mask:
- orig_enabled = values['enabled']
- self.mask_enabled_attribute(values)
- elif self.enabled_invert and not self.enabled_emulation:
- orig_enabled = values['enabled']
- if orig_enabled is not None:
- values['enabled'] = not orig_enabled
- else:
- values['enabled'] = self.enabled_default
- values = super(UserApi, self).create(values)
- if self.enabled_mask or (self.enabled_invert and
- not self.enabled_emulation):
- values['enabled'] = orig_enabled
- return values
-
- def get_filtered(self, user_id):
- user = self.get(user_id)
- return self.filter_attributes(user)
-
- def get_all_filtered(self, hints):
- query = self.filter_query(hints, self.ldap_filter)
- return [self.filter_attributes(user)
- for user in self.get_all(query, hints)]
-
- def filter_attributes(self, user):
- return identity.filter_user(common_ldap.filter_entity(user))
-
- def is_user(self, dn):
- """Returns True if the entry is a user."""
- # NOTE(blk-u): It's easy to check if the DN is under the User tree,
- # but may not be accurate. A more accurate test would be to fetch the
- # entry to see if it's got the user objectclass, but this could be
- # really expensive considering how this is used.
-
- return common_ldap.dn_startswith(dn, self.tree_dn)
-
-
-class GroupApi(common_ldap.BaseLdap):
- DEFAULT_OU = 'ou=UserGroups'
- DEFAULT_STRUCTURAL_CLASSES = []
- DEFAULT_OBJECTCLASS = 'groupOfNames'
- DEFAULT_ID_ATTR = 'cn'
- DEFAULT_MEMBER_ATTRIBUTE = 'member'
- NotFound = exception.GroupNotFound
- options_name = 'group'
- attribute_options_names = {'description': 'desc',
- 'name': 'name'}
- immutable_attrs = ['name']
- model = models.Group
-
- def _ldap_res_to_model(self, res):
- model = super(GroupApi, self)._ldap_res_to_model(res)
- model['dn'] = res[0]
- return model
-
- def __init__(self, conf):
- super(GroupApi, self).__init__(conf)
- self.member_attribute = (conf.ldap.group_member_attribute
- or self.DEFAULT_MEMBER_ATTRIBUTE)
-
- def create(self, values):
- data = values.copy()
- if data.get('id') is None:
- data['id'] = uuid.uuid4().hex
- if 'description' in data and data['description'] in ['', None]:
- data.pop('description')
- return super(GroupApi, self).create(data)
-
- def delete(self, group_id):
- if self.subtree_delete_enabled:
- super(GroupApi, self).delete_tree(group_id)
- else:
- # TODO(spzala): this is only placeholder for group and domain
- # role support which will be added under bug 1101287
-
- group_ref = self.get(group_id)
- group_dn = group_ref['dn']
- if group_dn:
- self._delete_tree_nodes(group_dn, ldap.SCOPE_ONELEVEL)
- super(GroupApi, self).delete(group_id)
-
- def update(self, group_id, values):
- old_obj = self.get(group_id)
- return super(GroupApi, self).update(group_id, values, old_obj)
-
- def add_user(self, user_dn, group_id, user_id):
- group_ref = self.get(group_id)
- group_dn = group_ref['dn']
- try:
- super(GroupApi, self).add_member(user_dn, group_dn)
- except exception.Conflict:
- raise exception.Conflict(_(
- 'User %(user_id)s is already a member of group %(group_id)s') %
- {'user_id': user_id, 'group_id': group_id})
-
- def remove_user(self, user_dn, group_id, user_id):
- group_ref = self.get(group_id)
- group_dn = group_ref['dn']
- try:
- super(GroupApi, self).remove_member(user_dn, group_dn)
- except ldap.NO_SUCH_ATTRIBUTE:
- raise exception.UserNotFound(user_id=user_id)
-
- def list_user_groups(self, user_dn):
- """Return a list of groups for which the user is a member."""
- user_dn_esc = ldap.filter.escape_filter_chars(user_dn)
- query = '(%s=%s)%s' % (self.member_attribute,
- user_dn_esc,
- self.ldap_filter or '')
- return self.get_all(query)
-
- def list_user_groups_filtered(self, user_dn, hints):
- """Return a filtered list of groups for which the user is a member."""
- user_dn_esc = ldap.filter.escape_filter_chars(user_dn)
- query = '(%s=%s)%s' % (self.member_attribute,
- user_dn_esc,
- self.ldap_filter or '')
- return self.get_all_filtered(hints, query)
-
- def list_group_users(self, group_id):
- """Return a list of user dns which are members of a group."""
- group_ref = self.get(group_id)
- group_dn = group_ref['dn']
-
- try:
- attrs = self._ldap_get_list(group_dn, ldap.SCOPE_BASE,
- attrlist=[self.member_attribute])
- except ldap.NO_SUCH_OBJECT:
- raise self.NotFound(group_id=group_id)
-
- users = []
- for dn, member in attrs:
- user_dns = member.get(self.member_attribute, [])
- for user_dn in user_dns:
- if self._is_dumb_member(user_dn):
- continue
- users.append(user_dn)
- return users
-
- def get_filtered(self, group_id):
- group = self.get(group_id)
- return common_ldap.filter_entity(group)
-
- def get_filtered_by_name(self, group_name):
- group = self.get_by_name(group_name)
- return common_ldap.filter_entity(group)
-
- def get_all_filtered(self, hints, query=None):
- query = self.filter_query(hints, query)
- return [common_ldap.filter_entity(group)
- for group in self.get_all(query, hints)]
diff --git a/keystone-moon/keystone/identity/backends/sql.py b/keystone-moon/keystone/identity/backends/sql.py
deleted file mode 100644
index 5680a8a2..00000000
--- a/keystone-moon/keystone/identity/backends/sql.py
+++ /dev/null
@@ -1,402 +0,0 @@
-# Copyright 2012 OpenStack Foundation
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-import sqlalchemy
-from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy import orm
-
-from keystone.common import driver_hints
-from keystone.common import sql
-from keystone.common import utils
-from keystone import exception
-from keystone.i18n import _
-from keystone import identity
-
-
-class User(sql.ModelBase, sql.DictBase):
- __tablename__ = 'user'
- attributes = ['id', 'name', 'domain_id', 'password', 'enabled',
- 'default_project_id']
- id = sql.Column(sql.String(64), primary_key=True)
- enabled = sql.Column(sql.Boolean)
- extra = sql.Column(sql.JsonBlob())
- default_project_id = sql.Column(sql.String(64))
- local_user = orm.relationship('LocalUser', uselist=False,
- single_parent=True, lazy='subquery',
- cascade='all,delete-orphan', backref='user')
- federated_users = orm.relationship('FederatedUser',
- single_parent=True,
- lazy='subquery',
- cascade='all,delete-orphan',
- backref='user')
-
- # name property
- @hybrid_property
- def name(self):
- if self.local_user:
- return self.local_user.name
- elif self.federated_users:
- return self.federated_users[0].display_name
- else:
- return None
-
- @name.setter
- def name(self, value):
- if not self.local_user:
- self.local_user = LocalUser()
- self.local_user.name = value
-
- @name.expression
- def name(cls):
- return LocalUser.name
-
- # password property
- @hybrid_property
- def password(self):
- if self.local_user and self.local_user.passwords:
- return self.local_user.passwords[0].password
- else:
- return None
-
- @password.setter
- def password(self, value):
- if not value:
- if self.local_user and self.local_user.passwords:
- self.local_user.passwords = []
- else:
- if not self.local_user:
- self.local_user = LocalUser()
- if not self.local_user.passwords:
- self.local_user.passwords.append(Password())
- self.local_user.passwords[0].password = value
-
- @password.expression
- def password(cls):
- return Password.password
-
- # domain_id property
- @hybrid_property
- def domain_id(self):
- if self.local_user:
- return self.local_user.domain_id
- else:
- return None
-
- @domain_id.setter
- def domain_id(self, value):
- if not self.local_user:
- self.local_user = LocalUser()
- self.local_user.domain_id = value
-
- @domain_id.expression
- def domain_id(cls):
- return LocalUser.domain_id
-
- def to_dict(self, include_extra_dict=False):
- d = super(User, self).to_dict(include_extra_dict=include_extra_dict)
- if 'default_project_id' in d and d['default_project_id'] is None:
- del d['default_project_id']
- return d
-
-
-class LocalUser(sql.ModelBase, sql.DictBase):
- __tablename__ = 'local_user'
- attributes = ['id', 'user_id', 'domain_id', 'name']
- id = sql.Column(sql.Integer, primary_key=True)
- user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
- ondelete='CASCADE'), unique=True)
- domain_id = sql.Column(sql.String(64), nullable=False)
- name = sql.Column(sql.String(255), nullable=False)
- passwords = orm.relationship('Password', single_parent=True,
- cascade='all,delete-orphan',
- backref='local_user')
- __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
-
-
-class Password(sql.ModelBase, sql.DictBase):
- __tablename__ = 'password'
- attributes = ['id', 'local_user_id', 'password']
- id = sql.Column(sql.Integer, primary_key=True)
- local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id',
- ondelete='CASCADE'))
- password = sql.Column(sql.String(128))
-
-
-class FederatedUser(sql.ModelBase, sql.ModelDictMixin):
- __tablename__ = 'federated_user'
- attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id',
- 'display_name']
- id = sql.Column(sql.Integer, primary_key=True)
- user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
- ondelete='CASCADE'))
- idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
- ondelete='CASCADE'))
- protocol_id = sql.Column(sql.String(64), nullable=False)
- unique_id = sql.Column(sql.String(255), nullable=False)
- display_name = sql.Column(sql.String(255), nullable=True)
- __table_args__ = (
- sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'),
- sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'],
- ['federation_protocol.id',
- 'federation_protocol.idp_id'])
- )
-
-
-class Group(sql.ModelBase, sql.DictBase):
- __tablename__ = 'group'
- attributes = ['id', 'name', 'domain_id', 'description']
- id = sql.Column(sql.String(64), primary_key=True)
- name = sql.Column(sql.String(64), nullable=False)
- domain_id = sql.Column(sql.String(64), nullable=False)
- description = sql.Column(sql.Text())
- extra = sql.Column(sql.JsonBlob())
- # Unique constraint across two columns to create the separation
- # rather than just only 'name' being unique
- __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),)
-
-
-class UserGroupMembership(sql.ModelBase, sql.DictBase):
- """Group membership join table."""
-
- __tablename__ = 'user_group_membership'
- user_id = sql.Column(sql.String(64),
- sql.ForeignKey('user.id'),
- primary_key=True)
- group_id = sql.Column(sql.String(64),
- sql.ForeignKey('group.id'),
- primary_key=True)
-
-
-class Identity(identity.IdentityDriverV8):
- # NOTE(henry-nash): Override the __init__() method so as to take a
- # config parameter to enable sql to be used as a domain-specific driver.
- def __init__(self, conf=None):
- self.conf = conf
- super(Identity, self).__init__()
-
- @property
- def is_sql(self):
- return True
-
- def _check_password(self, password, user_ref):
- """Check the specified password against the data store.
-
- Note that we'll pass in the entire user_ref in case the subclass
- needs things like user_ref.get('name')
- For further justification, please see the follow up suggestion at
- https://blueprints.launchpad.net/keystone/+spec/sql-identiy-pam
-
- """
- return utils.check_password(password, user_ref.password)
-
- # Identity interface
- def authenticate(self, user_id, password):
- with sql.session_for_read() as session:
- user_ref = None
- try:
- user_ref = self._get_user(session, user_id)
- except exception.UserNotFound:
- raise AssertionError(_('Invalid user / password'))
- if not self._check_password(password, user_ref):
- raise AssertionError(_('Invalid user / password'))
- return identity.filter_user(user_ref.to_dict())
-
- # user crud
-
- @sql.handle_conflicts(conflict_type='user')
- def create_user(self, user_id, user):
- user = utils.hash_user_password(user)
- with sql.session_for_write() as session:
- user_ref = User.from_dict(user)
- session.add(user_ref)
- return identity.filter_user(user_ref.to_dict())
-
- @driver_hints.truncated
- def list_users(self, hints):
- with sql.session_for_read() as session:
- query = session.query(User).outerjoin(LocalUser)
- user_refs = sql.filter_limit_query(User, query, hints)
- return [identity.filter_user(x.to_dict()) for x in user_refs]
-
- def _get_user(self, session, user_id):
- user_ref = session.query(User).get(user_id)
- if not user_ref:
- raise exception.UserNotFound(user_id=user_id)
- return user_ref
-
- def get_user(self, user_id):
- with sql.session_for_read() as session:
- return identity.filter_user(
- self._get_user(session, user_id).to_dict())
-
- def get_user_by_name(self, user_name, domain_id):
- with sql.session_for_read() as session:
- query = session.query(User).join(LocalUser)
- query = query.filter(sqlalchemy.and_(LocalUser.name == user_name,
- LocalUser.domain_id == domain_id))
- try:
- user_ref = query.one()
- except sql.NotFound:
- raise exception.UserNotFound(user_id=user_name)
- return identity.filter_user(user_ref.to_dict())
-
- @sql.handle_conflicts(conflict_type='user')
- def update_user(self, user_id, user):
- with sql.session_for_write() as session:
- user_ref = self._get_user(session, user_id)
- old_user_dict = user_ref.to_dict()
- user = utils.hash_user_password(user)
- for k in user:
- old_user_dict[k] = user[k]
- new_user = User.from_dict(old_user_dict)
- for attr in User.attributes:
- if attr != 'id':
- setattr(user_ref, attr, getattr(new_user, attr))
- user_ref.extra = new_user.extra
- return identity.filter_user(
- user_ref.to_dict(include_extra_dict=True))
-
- def add_user_to_group(self, user_id, group_id):
- with sql.session_for_write() as session:
- self.get_group(group_id)
- self.get_user(user_id)
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- rv = query.first()
- if rv:
- return
-
- session.add(UserGroupMembership(user_id=user_id,
- group_id=group_id))
-
- def check_user_in_group(self, user_id, group_id):
- with sql.session_for_read() as session:
- self.get_group(group_id)
- self.get_user(user_id)
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- if not query.first():
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
-
- def remove_user_from_group(self, user_id, group_id):
- # We don't check if user or group are still valid and let the remove
- # be tried anyway - in case this is some kind of clean-up operation
- with sql.session_for_write() as session:
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- membership_ref = query.first()
- if membership_ref is None:
- # Check if the group and user exist to return descriptive
- # exceptions.
- self.get_group(group_id)
- self.get_user(user_id)
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
- session.delete(membership_ref)
-
- def list_groups_for_user(self, user_id, hints):
- with sql.session_for_read() as session:
- self.get_user(user_id)
- query = session.query(Group).join(UserGroupMembership)
- query = query.filter(UserGroupMembership.user_id == user_id)
- query = sql.filter_limit_query(Group, query, hints)
- return [g.to_dict() for g in query]
-
- def list_users_in_group(self, group_id, hints):
- with sql.session_for_read() as session:
- self.get_group(group_id)
- query = session.query(User).outerjoin(LocalUser)
- query = query.join(UserGroupMembership)
- query = query.filter(UserGroupMembership.group_id == group_id)
- query = sql.filter_limit_query(User, query, hints)
- return [identity.filter_user(u.to_dict()) for u in query]
-
- def delete_user(self, user_id):
- with sql.session_for_write() as session:
- ref = self._get_user(session, user_id)
-
- q = session.query(UserGroupMembership)
- q = q.filter_by(user_id=user_id)
- q.delete(False)
-
- session.delete(ref)
-
- # group crud
-
- @sql.handle_conflicts(conflict_type='group')
- def create_group(self, group_id, group):
- with sql.session_for_write() as session:
- ref = Group.from_dict(group)
- session.add(ref)
- return ref.to_dict()
-
- @driver_hints.truncated
- def list_groups(self, hints):
- with sql.session_for_read() as session:
- query = session.query(Group)
- refs = sql.filter_limit_query(Group, query, hints)
- return [ref.to_dict() for ref in refs]
-
- def _get_group(self, session, group_id):
- ref = session.query(Group).get(group_id)
- if not ref:
- raise exception.GroupNotFound(group_id=group_id)
- return ref
-
- def get_group(self, group_id):
- with sql.session_for_read() as session:
- return self._get_group(session, group_id).to_dict()
-
- def get_group_by_name(self, group_name, domain_id):
- with sql.session_for_read() as session:
- query = session.query(Group)
- query = query.filter_by(name=group_name)
- query = query.filter_by(domain_id=domain_id)
- try:
- group_ref = query.one()
- except sql.NotFound:
- raise exception.GroupNotFound(group_id=group_name)
- return group_ref.to_dict()
-
- @sql.handle_conflicts(conflict_type='group')
- def update_group(self, group_id, group):
- with sql.session_for_write() as session:
- ref = self._get_group(session, group_id)
- old_dict = ref.to_dict()
- for k in group:
- old_dict[k] = group[k]
- new_group = Group.from_dict(old_dict)
- for attr in Group.attributes:
- if attr != 'id':
- setattr(ref, attr, getattr(new_group, attr))
- ref.extra = new_group.extra
- return ref.to_dict()
-
- def delete_group(self, group_id):
- with sql.session_for_write() as session:
- ref = self._get_group(session, group_id)
-
- q = session.query(UserGroupMembership)
- q = q.filter_by(group_id=group_id)
- q.delete(False)
-
- session.delete(ref)