aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/identity/backends
diff options
context:
space:
mode:
authorDUVAL Thomas <thomas.duval@orange.com>2016-06-09 09:11:50 +0200
committerDUVAL Thomas <thomas.duval@orange.com>2016-06-09 09:11:50 +0200
commit2e7b4f2027a1147ca28301e4f88adf8274b39a1f (patch)
tree8b8d94001ebe6cc34106cf813b538911a8d66d9a /keystone-moon/keystone/identity/backends
parenta33bdcb627102a01244630a54cb4b5066b385a6a (diff)
Update Keystone core to Mitaka.
Change-Id: Ia10d6add16f4a9d25d1f42d420661c46332e69db
Diffstat (limited to 'keystone-moon/keystone/identity/backends')
-rw-r--r--keystone-moon/keystone/identity/backends/ldap.py62
-rw-r--r--keystone-moon/keystone/identity/backends/sql.py345
2 files changed, 265 insertions, 142 deletions
diff --git a/keystone-moon/keystone/identity/backends/ldap.py b/keystone-moon/keystone/identity/backends/ldap.py
index 1f33bacb..fe8e8477 100644
--- a/keystone-moon/keystone/identity/backends/ldap.py
+++ b/keystone-moon/keystone/identity/backends/ldap.py
@@ -17,6 +17,7 @@ import uuid
import ldap.filter
from oslo_config import cfg
from oslo_log import log
+from oslo_log import versionutils
import six
from keystone.common import clean
@@ -31,17 +32,20 @@ from keystone import identity
CONF = cfg.CONF
LOG = log.getLogger(__name__)
+_DEPRECATION_MSG = _('%s for the LDAP identity backend has been deprecated in '
+ 'the Mitaka release in favor of read-only identity LDAP '
+ 'access. It will be removed in the "O" release.')
+
class Identity(identity.IdentityDriverV8):
def __init__(self, conf=None):
super(Identity, self).__init__()
if conf is None:
- conf = CONF
- self.user = UserApi(conf)
- self.group = GroupApi(conf)
-
- def default_assignment_driver(self):
- return 'ldap'
+ self.conf = CONF
+ else:
+ self.conf = conf
+ self.user = UserApi(self.conf)
+ self.group = GroupApi(self.conf)
def is_domain_aware(self):
return False
@@ -87,11 +91,15 @@ class Identity(identity.IdentityDriverV8):
# CRUD
def create_user(self, user_id, user):
+ msg = _DEPRECATION_MSG % "create_user"
+ versionutils.report_deprecated_feature(LOG, msg)
self.user.check_allow_create()
user_ref = self.user.create(user)
return self.user.filter_attributes(user_ref)
def update_user(self, user_id, user):
+ msg = _DEPRECATION_MSG % "update_user"
+ versionutils.report_deprecated_feature(LOG, msg)
self.user.check_allow_update()
old_obj = self.user.get(user_id)
if 'name' in user and old_obj.get('name') != user['name']:
@@ -110,6 +118,8 @@ class Identity(identity.IdentityDriverV8):
return self.user.get_filtered(user_id)
def delete_user(self, user_id):
+ msg = _DEPRECATION_MSG % "delete_user"
+ versionutils.report_deprecated_feature(LOG, msg)
self.user.check_allow_delete()
user = self.user.get(user_id)
user_dn = user['dn']
@@ -122,6 +132,8 @@ class Identity(identity.IdentityDriverV8):
self.user.delete(user_id)
def create_group(self, group_id, group):
+ msg = _DEPRECATION_MSG % "create_group"
+ versionutils.report_deprecated_feature(LOG, msg)
self.group.check_allow_create()
group['name'] = clean.group_name(group['name'])
return common_ldap.filter_entity(self.group.create(group))
@@ -135,28 +147,39 @@ class Identity(identity.IdentityDriverV8):
return self.group.get_filtered_by_name(group_name)
def update_group(self, group_id, group):
+ msg = _DEPRECATION_MSG % "update_group"
+ versionutils.report_deprecated_feature(LOG, msg)
self.group.check_allow_update()
if 'name' in group:
group['name'] = clean.group_name(group['name'])
return common_ldap.filter_entity(self.group.update(group_id, group))
def delete_group(self, group_id):
+ msg = _DEPRECATION_MSG % "delete_group"
+ versionutils.report_deprecated_feature(LOG, msg)
self.group.check_allow_delete()
return self.group.delete(group_id)
def add_user_to_group(self, user_id, group_id):
+ msg = _DEPRECATION_MSG % "add_user_to_group"
+ versionutils.report_deprecated_feature(LOG, msg)
user_ref = self._get_user(user_id)
user_dn = user_ref['dn']
self.group.add_user(user_dn, group_id, user_id)
def remove_user_from_group(self, user_id, group_id):
+ msg = _DEPRECATION_MSG % "remove_user_from_group"
+ versionutils.report_deprecated_feature(LOG, msg)
user_ref = self._get_user(user_id)
user_dn = user_ref['dn']
self.group.remove_user(user_dn, group_id, user_id)
def list_groups_for_user(self, user_id, hints):
user_ref = self._get_user(user_id)
- user_dn = user_ref['dn']
+ if self.conf.ldap.group_members_are_ids:
+ user_dn = user_ref['id']
+ else:
+ user_dn = user_ref['dn']
return self.group.list_user_groups_filtered(user_dn, hints)
def list_groups(self, hints):
@@ -164,15 +187,19 @@ class Identity(identity.IdentityDriverV8):
def list_users_in_group(self, group_id, hints):
users = []
- for user_dn in self.group.list_group_users(group_id):
- user_id = self.user._dn_to_id(user_dn)
+ for user_key in self.group.list_group_users(group_id):
+ if self.conf.ldap.group_members_are_ids:
+ user_id = user_key
+ else:
+ user_id = self.user._dn_to_id(user_key)
+
try:
users.append(self.user.get_filtered(user_id))
except exception.UserNotFound:
- LOG.debug(("Group member '%(user_dn)s' not found in"
+ LOG.debug(("Group member '%(user_key)s' not found in"
" '%(group_id)s'. The user should be removed"
" from the group. The user will be ignored."),
- dict(user_dn=user_dn, group_id=group_id))
+ dict(user_key=user_key, group_id=group_id))
return users
def check_user_in_group(self, user_id, group_id):
@@ -201,6 +228,7 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
attribute_options_names = {'password': 'pass',
'email': 'mail',
'name': 'name',
+ 'description': 'description',
'enabled': 'enabled',
'default_project_id': 'default_project_id'}
immutable_attrs = ['id']
@@ -264,15 +292,15 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
return self.filter_attributes(user)
def get_all_filtered(self, hints):
- query = self.filter_query(hints)
- return [self.filter_attributes(user) for user in self.get_all(query)]
+ query = self.filter_query(hints, self.ldap_filter)
+ return [self.filter_attributes(user)
+ for user in self.get_all(query, hints)]
def filter_attributes(self, user):
return identity.filter_user(common_ldap.filter_entity(user))
def is_user(self, dn):
"""Returns True if the entry is a user."""
-
# NOTE(blk-u): It's easy to check if the DN is under the User tree,
# but may not be accurate. A more accurate test would be to fetch the
# entry to see if it's got the user objectclass, but this could be
@@ -314,7 +342,7 @@ class GroupApi(common_ldap.BaseLdap):
def delete(self, group_id):
if self.subtree_delete_enabled:
- super(GroupApi, self).deleteTree(group_id)
+ super(GroupApi, self).delete_tree(group_id)
else:
# TODO(spzala): this is only placeholder for group and domain
# role support which will be added under bug 1101287
@@ -349,7 +377,6 @@ class GroupApi(common_ldap.BaseLdap):
def list_user_groups(self, user_dn):
"""Return a list of groups for which the user is a member."""
-
user_dn_esc = ldap.filter.escape_filter_chars(user_dn)
query = '(%s=%s)%s' % (self.member_attribute,
user_dn_esc,
@@ -358,7 +385,6 @@ class GroupApi(common_ldap.BaseLdap):
def list_user_groups_filtered(self, user_dn, hints):
"""Return a filtered list of groups for which the user is a member."""
-
user_dn_esc = ldap.filter.escape_filter_chars(user_dn)
query = '(%s=%s)%s' % (self.member_attribute,
user_dn_esc,
@@ -396,4 +422,4 @@ class GroupApi(common_ldap.BaseLdap):
def get_all_filtered(self, hints, query=None):
query = self.filter_query(hints, query)
return [common_ldap.filter_entity(group)
- for group in self.get_all(query)]
+ for group in self.get_all(query, hints)]
diff --git a/keystone-moon/keystone/identity/backends/sql.py b/keystone-moon/keystone/identity/backends/sql.py
index d37240eb..5680a8a2 100644
--- a/keystone-moon/keystone/identity/backends/sql.py
+++ b/keystone-moon/keystone/identity/backends/sql.py
@@ -12,8 +12,11 @@
# License for the specific language governing permissions and limitations
# under the License.
-from oslo_config import cfg
+import sqlalchemy
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy import orm
+from keystone.common import driver_hints
from keystone.common import sql
from keystone.common import utils
from keystone import exception
@@ -21,23 +24,84 @@ from keystone.i18n import _
from keystone import identity
-CONF = cfg.CONF
-
-
class User(sql.ModelBase, sql.DictBase):
__tablename__ = 'user'
attributes = ['id', 'name', 'domain_id', 'password', 'enabled',
'default_project_id']
id = sql.Column(sql.String(64), primary_key=True)
- name = sql.Column(sql.String(255), nullable=False)
- domain_id = sql.Column(sql.String(64), nullable=False)
- password = sql.Column(sql.String(128))
enabled = sql.Column(sql.Boolean)
extra = sql.Column(sql.JsonBlob())
default_project_id = sql.Column(sql.String(64))
- # Unique constraint across two columns to create the separation
- # rather than just only 'name' being unique
- __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+ local_user = orm.relationship('LocalUser', uselist=False,
+ single_parent=True, lazy='subquery',
+ cascade='all,delete-orphan', backref='user')
+ federated_users = orm.relationship('FederatedUser',
+ single_parent=True,
+ lazy='subquery',
+ cascade='all,delete-orphan',
+ backref='user')
+
+ # name property
+ @hybrid_property
+ def name(self):
+ if self.local_user:
+ return self.local_user.name
+ elif self.federated_users:
+ return self.federated_users[0].display_name
+ else:
+ return None
+
+ @name.setter
+ def name(self, value):
+ if not self.local_user:
+ self.local_user = LocalUser()
+ self.local_user.name = value
+
+ @name.expression
+ def name(cls):
+ return LocalUser.name
+
+ # password property
+ @hybrid_property
+ def password(self):
+ if self.local_user and self.local_user.passwords:
+ return self.local_user.passwords[0].password
+ else:
+ return None
+
+ @password.setter
+ def password(self, value):
+ if not value:
+ if self.local_user and self.local_user.passwords:
+ self.local_user.passwords = []
+ else:
+ if not self.local_user:
+ self.local_user = LocalUser()
+ if not self.local_user.passwords:
+ self.local_user.passwords.append(Password())
+ self.local_user.passwords[0].password = value
+
+ @password.expression
+ def password(cls):
+ return Password.password
+
+ # domain_id property
+ @hybrid_property
+ def domain_id(self):
+ if self.local_user:
+ return self.local_user.domain_id
+ else:
+ return None
+
+ @domain_id.setter
+ def domain_id(self, value):
+ if not self.local_user:
+ self.local_user = LocalUser()
+ self.local_user.domain_id = value
+
+ @domain_id.expression
+ def domain_id(cls):
+ return LocalUser.domain_id
def to_dict(self, include_extra_dict=False):
d = super(User, self).to_dict(include_extra_dict=include_extra_dict)
@@ -46,6 +110,49 @@ class User(sql.ModelBase, sql.DictBase):
return d
+class LocalUser(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'local_user'
+ attributes = ['id', 'user_id', 'domain_id', 'name']
+ id = sql.Column(sql.Integer, primary_key=True)
+ user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
+ ondelete='CASCADE'), unique=True)
+ domain_id = sql.Column(sql.String(64), nullable=False)
+ name = sql.Column(sql.String(255), nullable=False)
+ passwords = orm.relationship('Password', single_parent=True,
+ cascade='all,delete-orphan',
+ backref='local_user')
+ __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+
+
+class Password(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'password'
+ attributes = ['id', 'local_user_id', 'password']
+ id = sql.Column(sql.Integer, primary_key=True)
+ local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id',
+ ondelete='CASCADE'))
+ password = sql.Column(sql.String(128))
+
+
+class FederatedUser(sql.ModelBase, sql.ModelDictMixin):
+ __tablename__ = 'federated_user'
+ attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id',
+ 'display_name']
+ id = sql.Column(sql.Integer, primary_key=True)
+ user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
+ ondelete='CASCADE'))
+ idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
+ ondelete='CASCADE'))
+ protocol_id = sql.Column(sql.String(64), nullable=False)
+ unique_id = sql.Column(sql.String(255), nullable=False)
+ display_name = sql.Column(sql.String(255), nullable=True)
+ __table_args__ = (
+ sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'),
+ sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'],
+ ['federation_protocol.id',
+ 'federation_protocol.idp_id'])
+ )
+
+
class Group(sql.ModelBase, sql.DictBase):
__tablename__ = 'group'
attributes = ['id', 'name', 'domain_id', 'description']
@@ -56,11 +163,12 @@ class Group(sql.ModelBase, sql.DictBase):
extra = sql.Column(sql.JsonBlob())
# Unique constraint across two columns to create the separation
# rather than just only 'name' being unique
- __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+ __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),)
class UserGroupMembership(sql.ModelBase, sql.DictBase):
"""Group membership join table."""
+
__tablename__ = 'user_group_membership'
user_id = sql.Column(sql.String(64),
sql.ForeignKey('user.id'),
@@ -74,11 +182,9 @@ class Identity(identity.IdentityDriverV8):
# NOTE(henry-nash): Override the __init__() method so as to take a
# config parameter to enable sql to be used as a domain-specific driver.
def __init__(self, conf=None):
+ self.conf = conf
super(Identity, self).__init__()
- def default_assignment_driver(self):
- return 'sql'
-
@property
def is_sql(self):
return True
@@ -96,33 +202,32 @@ class Identity(identity.IdentityDriverV8):
# Identity interface
def authenticate(self, user_id, password):
- session = sql.get_session()
- user_ref = None
- try:
- user_ref = self._get_user(session, user_id)
- except exception.UserNotFound:
- raise AssertionError(_('Invalid user / password'))
- if not self._check_password(password, user_ref):
- raise AssertionError(_('Invalid user / password'))
- return identity.filter_user(user_ref.to_dict())
+ with sql.session_for_read() as session:
+ user_ref = None
+ try:
+ user_ref = self._get_user(session, user_id)
+ except exception.UserNotFound:
+ raise AssertionError(_('Invalid user / password'))
+ if not self._check_password(password, user_ref):
+ raise AssertionError(_('Invalid user / password'))
+ return identity.filter_user(user_ref.to_dict())
# user crud
@sql.handle_conflicts(conflict_type='user')
def create_user(self, user_id, user):
user = utils.hash_user_password(user)
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
user_ref = User.from_dict(user)
session.add(user_ref)
- return identity.filter_user(user_ref.to_dict())
+ return identity.filter_user(user_ref.to_dict())
- @sql.truncated
+ @driver_hints.truncated
def list_users(self, hints):
- session = sql.get_session()
- query = session.query(User)
- user_refs = sql.filter_limit_query(User, query, hints)
- return [identity.filter_user(x.to_dict()) for x in user_refs]
+ with sql.session_for_read() as session:
+ query = session.query(User).outerjoin(LocalUser)
+ user_refs = sql.filter_limit_query(User, query, hints)
+ return [identity.filter_user(x.to_dict()) for x in user_refs]
def _get_user(self, session, user_id):
user_ref = session.query(User).get(user_id)
@@ -131,25 +236,24 @@ class Identity(identity.IdentityDriverV8):
return user_ref
def get_user(self, user_id):
- session = sql.get_session()
- return identity.filter_user(self._get_user(session, user_id).to_dict())
+ with sql.session_for_read() as session:
+ return identity.filter_user(
+ self._get_user(session, user_id).to_dict())
def get_user_by_name(self, user_name, domain_id):
- session = sql.get_session()
- query = session.query(User)
- query = query.filter_by(name=user_name)
- query = query.filter_by(domain_id=domain_id)
- try:
- user_ref = query.one()
- except sql.NotFound:
- raise exception.UserNotFound(user_id=user_name)
- return identity.filter_user(user_ref.to_dict())
+ with sql.session_for_read() as session:
+ query = session.query(User).join(LocalUser)
+ query = query.filter(sqlalchemy.and_(LocalUser.name == user_name,
+ LocalUser.domain_id == domain_id))
+ try:
+ user_ref = query.one()
+ except sql.NotFound:
+ raise exception.UserNotFound(user_id=user_name)
+ return identity.filter_user(user_ref.to_dict())
@sql.handle_conflicts(conflict_type='user')
def update_user(self, user_id, user):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
user_ref = self._get_user(session, user_id)
old_user_dict = user_ref.to_dict()
user = utils.hash_user_password(user)
@@ -160,76 +264,74 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id':
setattr(user_ref, attr, getattr(new_user, attr))
user_ref.extra = new_user.extra
- return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
+ return identity.filter_user(
+ user_ref.to_dict(include_extra_dict=True))
def add_user_to_group(self, user_id, group_id):
- session = sql.get_session()
- self.get_group(group_id)
- self.get_user(user_id)
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- rv = query.first()
- if rv:
- return
-
- with session.begin():
+ with sql.session_for_write() as session:
+ self.get_group(group_id)
+ self.get_user(user_id)
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ rv = query.first()
+ if rv:
+ return
+
session.add(UserGroupMembership(user_id=user_id,
group_id=group_id))
def check_user_in_group(self, user_id, group_id):
- session = sql.get_session()
- self.get_group(group_id)
- self.get_user(user_id)
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- if not query.first():
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
+ with sql.session_for_read() as session:
+ self.get_group(group_id)
+ self.get_user(user_id)
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ if not query.first():
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
def remove_user_from_group(self, user_id, group_id):
- session = sql.get_session()
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- membership_ref = query.first()
- if membership_ref is None:
- # Check if the group and user exist to return descriptive
- # exceptions.
- self.get_group(group_id)
- self.get_user(user_id)
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
- with session.begin():
+ with sql.session_for_write() as session:
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ membership_ref = query.first()
+ if membership_ref is None:
+ # Check if the group and user exist to return descriptive
+ # exceptions.
+ self.get_group(group_id)
+ self.get_user(user_id)
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
session.delete(membership_ref)
def list_groups_for_user(self, user_id, hints):
- session = sql.get_session()
- self.get_user(user_id)
- query = session.query(Group).join(UserGroupMembership)
- query = query.filter(UserGroupMembership.user_id == user_id)
- query = sql.filter_limit_query(Group, query, hints)
- return [g.to_dict() for g in query]
+ with sql.session_for_read() as session:
+ self.get_user(user_id)
+ query = session.query(Group).join(UserGroupMembership)
+ query = query.filter(UserGroupMembership.user_id == user_id)
+ query = sql.filter_limit_query(Group, query, hints)
+ return [g.to_dict() for g in query]
def list_users_in_group(self, group_id, hints):
- session = sql.get_session()
- self.get_group(group_id)
- query = session.query(User).join(UserGroupMembership)
- query = query.filter(UserGroupMembership.group_id == group_id)
- query = sql.filter_limit_query(User, query, hints)
- return [identity.filter_user(u.to_dict()) for u in query]
+ with sql.session_for_read() as session:
+ self.get_group(group_id)
+ query = session.query(User).outerjoin(LocalUser)
+ query = query.join(UserGroupMembership)
+ query = query.filter(UserGroupMembership.group_id == group_id)
+ query = sql.filter_limit_query(User, query, hints)
+ return [identity.filter_user(u.to_dict()) for u in query]
def delete_user(self, user_id):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
ref = self._get_user(session, user_id)
q = session.query(UserGroupMembership)
@@ -242,18 +344,17 @@ class Identity(identity.IdentityDriverV8):
@sql.handle_conflicts(conflict_type='group')
def create_group(self, group_id, group):
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
ref = Group.from_dict(group)
session.add(ref)
- return ref.to_dict()
+ return ref.to_dict()
- @sql.truncated
+ @driver_hints.truncated
def list_groups(self, hints):
- session = sql.get_session()
- query = session.query(Group)
- refs = sql.filter_limit_query(Group, query, hints)
- return [ref.to_dict() for ref in refs]
+ with sql.session_for_read() as session:
+ query = session.query(Group)
+ refs = sql.filter_limit_query(Group, query, hints)
+ return [ref.to_dict() for ref in refs]
def _get_group(self, session, group_id):
ref = session.query(Group).get(group_id)
@@ -262,25 +363,23 @@ class Identity(identity.IdentityDriverV8):
return ref
def get_group(self, group_id):
- session = sql.get_session()
- return self._get_group(session, group_id).to_dict()
+ with sql.session_for_read() as session:
+ return self._get_group(session, group_id).to_dict()
def get_group_by_name(self, group_name, domain_id):
- session = sql.get_session()
- query = session.query(Group)
- query = query.filter_by(name=group_name)
- query = query.filter_by(domain_id=domain_id)
- try:
- group_ref = query.one()
- except sql.NotFound:
- raise exception.GroupNotFound(group_id=group_name)
- return group_ref.to_dict()
+ with sql.session_for_read() as session:
+ query = session.query(Group)
+ query = query.filter_by(name=group_name)
+ query = query.filter_by(domain_id=domain_id)
+ try:
+ group_ref = query.one()
+ except sql.NotFound:
+ raise exception.GroupNotFound(group_id=group_name)
+ return group_ref.to_dict()
@sql.handle_conflicts(conflict_type='group')
def update_group(self, group_id, group):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
ref = self._get_group(session, group_id)
old_dict = ref.to_dict()
for k in group:
@@ -290,12 +389,10 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_group, attr))
ref.extra = new_group.extra
- return ref.to_dict()
+ return ref.to_dict()
def delete_group(self, group_id):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
ref = self._get_group(session, group_id)
q = session.query(UserGroupMembership)