aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/identity/backends/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/identity/backends/sql.py')
-rw-r--r--keystone-moon/keystone/identity/backends/sql.py345
1 files changed, 221 insertions, 124 deletions
diff --git a/keystone-moon/keystone/identity/backends/sql.py b/keystone-moon/keystone/identity/backends/sql.py
index d37240eb..5680a8a2 100644
--- a/keystone-moon/keystone/identity/backends/sql.py
+++ b/keystone-moon/keystone/identity/backends/sql.py
@@ -12,8 +12,11 @@
# License for the specific language governing permissions and limitations
# under the License.
-from oslo_config import cfg
+import sqlalchemy
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy import orm
+from keystone.common import driver_hints
from keystone.common import sql
from keystone.common import utils
from keystone import exception
@@ -21,23 +24,84 @@ from keystone.i18n import _
from keystone import identity
-CONF = cfg.CONF
-
-
class User(sql.ModelBase, sql.DictBase):
__tablename__ = 'user'
attributes = ['id', 'name', 'domain_id', 'password', 'enabled',
'default_project_id']
id = sql.Column(sql.String(64), primary_key=True)
- name = sql.Column(sql.String(255), nullable=False)
- domain_id = sql.Column(sql.String(64), nullable=False)
- password = sql.Column(sql.String(128))
enabled = sql.Column(sql.Boolean)
extra = sql.Column(sql.JsonBlob())
default_project_id = sql.Column(sql.String(64))
- # Unique constraint across two columns to create the separation
- # rather than just only 'name' being unique
- __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+ local_user = orm.relationship('LocalUser', uselist=False,
+ single_parent=True, lazy='subquery',
+ cascade='all,delete-orphan', backref='user')
+ federated_users = orm.relationship('FederatedUser',
+ single_parent=True,
+ lazy='subquery',
+ cascade='all,delete-orphan',
+ backref='user')
+
+ # name property
+ @hybrid_property
+ def name(self):
+ if self.local_user:
+ return self.local_user.name
+ elif self.federated_users:
+ return self.federated_users[0].display_name
+ else:
+ return None
+
+ @name.setter
+ def name(self, value):
+ if not self.local_user:
+ self.local_user = LocalUser()
+ self.local_user.name = value
+
+ @name.expression
+ def name(cls):
+ return LocalUser.name
+
+ # password property
+ @hybrid_property
+ def password(self):
+ if self.local_user and self.local_user.passwords:
+ return self.local_user.passwords[0].password
+ else:
+ return None
+
+ @password.setter
+ def password(self, value):
+ if not value:
+ if self.local_user and self.local_user.passwords:
+ self.local_user.passwords = []
+ else:
+ if not self.local_user:
+ self.local_user = LocalUser()
+ if not self.local_user.passwords:
+ self.local_user.passwords.append(Password())
+ self.local_user.passwords[0].password = value
+
+ @password.expression
+ def password(cls):
+ return Password.password
+
+ # domain_id property
+ @hybrid_property
+ def domain_id(self):
+ if self.local_user:
+ return self.local_user.domain_id
+ else:
+ return None
+
+ @domain_id.setter
+ def domain_id(self, value):
+ if not self.local_user:
+ self.local_user = LocalUser()
+ self.local_user.domain_id = value
+
+ @domain_id.expression
+ def domain_id(cls):
+ return LocalUser.domain_id
def to_dict(self, include_extra_dict=False):
d = super(User, self).to_dict(include_extra_dict=include_extra_dict)
@@ -46,6 +110,49 @@ class User(sql.ModelBase, sql.DictBase):
return d
+class LocalUser(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'local_user'
+ attributes = ['id', 'user_id', 'domain_id', 'name']
+ id = sql.Column(sql.Integer, primary_key=True)
+ user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
+ ondelete='CASCADE'), unique=True)
+ domain_id = sql.Column(sql.String(64), nullable=False)
+ name = sql.Column(sql.String(255), nullable=False)
+ passwords = orm.relationship('Password', single_parent=True,
+ cascade='all,delete-orphan',
+ backref='local_user')
+ __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+
+
+class Password(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'password'
+ attributes = ['id', 'local_user_id', 'password']
+ id = sql.Column(sql.Integer, primary_key=True)
+ local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id',
+ ondelete='CASCADE'))
+ password = sql.Column(sql.String(128))
+
+
+class FederatedUser(sql.ModelBase, sql.ModelDictMixin):
+ __tablename__ = 'federated_user'
+ attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id',
+ 'display_name']
+ id = sql.Column(sql.Integer, primary_key=True)
+ user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id',
+ ondelete='CASCADE'))
+ idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
+ ondelete='CASCADE'))
+ protocol_id = sql.Column(sql.String(64), nullable=False)
+ unique_id = sql.Column(sql.String(255), nullable=False)
+ display_name = sql.Column(sql.String(255), nullable=True)
+ __table_args__ = (
+ sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'),
+ sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'],
+ ['federation_protocol.id',
+ 'federation_protocol.idp_id'])
+ )
+
+
class Group(sql.ModelBase, sql.DictBase):
__tablename__ = 'group'
attributes = ['id', 'name', 'domain_id', 'description']
@@ -56,11 +163,12 @@ class Group(sql.ModelBase, sql.DictBase):
extra = sql.Column(sql.JsonBlob())
# Unique constraint across two columns to create the separation
# rather than just only 'name' being unique
- __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+ __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),)
class UserGroupMembership(sql.ModelBase, sql.DictBase):
"""Group membership join table."""
+
__tablename__ = 'user_group_membership'
user_id = sql.Column(sql.String(64),
sql.ForeignKey('user.id'),
@@ -74,11 +182,9 @@ class Identity(identity.IdentityDriverV8):
# NOTE(henry-nash): Override the __init__() method so as to take a
# config parameter to enable sql to be used as a domain-specific driver.
def __init__(self, conf=None):
+ self.conf = conf
super(Identity, self).__init__()
- def default_assignment_driver(self):
- return 'sql'
-
@property
def is_sql(self):
return True
@@ -96,33 +202,32 @@ class Identity(identity.IdentityDriverV8):
# Identity interface
def authenticate(self, user_id, password):
- session = sql.get_session()
- user_ref = None
- try:
- user_ref = self._get_user(session, user_id)
- except exception.UserNotFound:
- raise AssertionError(_('Invalid user / password'))
- if not self._check_password(password, user_ref):
- raise AssertionError(_('Invalid user / password'))
- return identity.filter_user(user_ref.to_dict())
+ with sql.session_for_read() as session:
+ user_ref = None
+ try:
+ user_ref = self._get_user(session, user_id)
+ except exception.UserNotFound:
+ raise AssertionError(_('Invalid user / password'))
+ if not self._check_password(password, user_ref):
+ raise AssertionError(_('Invalid user / password'))
+ return identity.filter_user(user_ref.to_dict())
# user crud
@sql.handle_conflicts(conflict_type='user')
def create_user(self, user_id, user):
user = utils.hash_user_password(user)
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
user_ref = User.from_dict(user)
session.add(user_ref)
- return identity.filter_user(user_ref.to_dict())
+ return identity.filter_user(user_ref.to_dict())
- @sql.truncated
+ @driver_hints.truncated
def list_users(self, hints):
- session = sql.get_session()
- query = session.query(User)
- user_refs = sql.filter_limit_query(User, query, hints)
- return [identity.filter_user(x.to_dict()) for x in user_refs]
+ with sql.session_for_read() as session:
+ query = session.query(User).outerjoin(LocalUser)
+ user_refs = sql.filter_limit_query(User, query, hints)
+ return [identity.filter_user(x.to_dict()) for x in user_refs]
def _get_user(self, session, user_id):
user_ref = session.query(User).get(user_id)
@@ -131,25 +236,24 @@ class Identity(identity.IdentityDriverV8):
return user_ref
def get_user(self, user_id):
- session = sql.get_session()
- return identity.filter_user(self._get_user(session, user_id).to_dict())
+ with sql.session_for_read() as session:
+ return identity.filter_user(
+ self._get_user(session, user_id).to_dict())
def get_user_by_name(self, user_name, domain_id):
- session = sql.get_session()
- query = session.query(User)
- query = query.filter_by(name=user_name)
- query = query.filter_by(domain_id=domain_id)
- try:
- user_ref = query.one()
- except sql.NotFound:
- raise exception.UserNotFound(user_id=user_name)
- return identity.filter_user(user_ref.to_dict())
+ with sql.session_for_read() as session:
+ query = session.query(User).join(LocalUser)
+ query = query.filter(sqlalchemy.and_(LocalUser.name == user_name,
+ LocalUser.domain_id == domain_id))
+ try:
+ user_ref = query.one()
+ except sql.NotFound:
+ raise exception.UserNotFound(user_id=user_name)
+ return identity.filter_user(user_ref.to_dict())
@sql.handle_conflicts(conflict_type='user')
def update_user(self, user_id, user):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
user_ref = self._get_user(session, user_id)
old_user_dict = user_ref.to_dict()
user = utils.hash_user_password(user)
@@ -160,76 +264,74 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id':
setattr(user_ref, attr, getattr(new_user, attr))
user_ref.extra = new_user.extra
- return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
+ return identity.filter_user(
+ user_ref.to_dict(include_extra_dict=True))
def add_user_to_group(self, user_id, group_id):
- session = sql.get_session()
- self.get_group(group_id)
- self.get_user(user_id)
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- rv = query.first()
- if rv:
- return
-
- with session.begin():
+ with sql.session_for_write() as session:
+ self.get_group(group_id)
+ self.get_user(user_id)
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ rv = query.first()
+ if rv:
+ return
+
session.add(UserGroupMembership(user_id=user_id,
group_id=group_id))
def check_user_in_group(self, user_id, group_id):
- session = sql.get_session()
- self.get_group(group_id)
- self.get_user(user_id)
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- if not query.first():
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
+ with sql.session_for_read() as session:
+ self.get_group(group_id)
+ self.get_user(user_id)
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ if not query.first():
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
def remove_user_from_group(self, user_id, group_id):
- session = sql.get_session()
# We don't check if user or group are still valid and let the remove
# be tried anyway - in case this is some kind of clean-up operation
- query = session.query(UserGroupMembership)
- query = query.filter_by(user_id=user_id)
- query = query.filter_by(group_id=group_id)
- membership_ref = query.first()
- if membership_ref is None:
- # Check if the group and user exist to return descriptive
- # exceptions.
- self.get_group(group_id)
- self.get_user(user_id)
- raise exception.NotFound(_("User '%(user_id)s' not found in"
- " group '%(group_id)s'") %
- {'user_id': user_id,
- 'group_id': group_id})
- with session.begin():
+ with sql.session_for_write() as session:
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ membership_ref = query.first()
+ if membership_ref is None:
+ # Check if the group and user exist to return descriptive
+ # exceptions.
+ self.get_group(group_id)
+ self.get_user(user_id)
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
session.delete(membership_ref)
def list_groups_for_user(self, user_id, hints):
- session = sql.get_session()
- self.get_user(user_id)
- query = session.query(Group).join(UserGroupMembership)
- query = query.filter(UserGroupMembership.user_id == user_id)
- query = sql.filter_limit_query(Group, query, hints)
- return [g.to_dict() for g in query]
+ with sql.session_for_read() as session:
+ self.get_user(user_id)
+ query = session.query(Group).join(UserGroupMembership)
+ query = query.filter(UserGroupMembership.user_id == user_id)
+ query = sql.filter_limit_query(Group, query, hints)
+ return [g.to_dict() for g in query]
def list_users_in_group(self, group_id, hints):
- session = sql.get_session()
- self.get_group(group_id)
- query = session.query(User).join(UserGroupMembership)
- query = query.filter(UserGroupMembership.group_id == group_id)
- query = sql.filter_limit_query(User, query, hints)
- return [identity.filter_user(u.to_dict()) for u in query]
+ with sql.session_for_read() as session:
+ self.get_group(group_id)
+ query = session.query(User).outerjoin(LocalUser)
+ query = query.join(UserGroupMembership)
+ query = query.filter(UserGroupMembership.group_id == group_id)
+ query = sql.filter_limit_query(User, query, hints)
+ return [identity.filter_user(u.to_dict()) for u in query]
def delete_user(self, user_id):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
ref = self._get_user(session, user_id)
q = session.query(UserGroupMembership)
@@ -242,18 +344,17 @@ class Identity(identity.IdentityDriverV8):
@sql.handle_conflicts(conflict_type='group')
def create_group(self, group_id, group):
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
ref = Group.from_dict(group)
session.add(ref)
- return ref.to_dict()
+ return ref.to_dict()
- @sql.truncated
+ @driver_hints.truncated
def list_groups(self, hints):
- session = sql.get_session()
- query = session.query(Group)
- refs = sql.filter_limit_query(Group, query, hints)
- return [ref.to_dict() for ref in refs]
+ with sql.session_for_read() as session:
+ query = session.query(Group)
+ refs = sql.filter_limit_query(Group, query, hints)
+ return [ref.to_dict() for ref in refs]
def _get_group(self, session, group_id):
ref = session.query(Group).get(group_id)
@@ -262,25 +363,23 @@ class Identity(identity.IdentityDriverV8):
return ref
def get_group(self, group_id):
- session = sql.get_session()
- return self._get_group(session, group_id).to_dict()
+ with sql.session_for_read() as session:
+ return self._get_group(session, group_id).to_dict()
def get_group_by_name(self, group_name, domain_id):
- session = sql.get_session()
- query = session.query(Group)
- query = query.filter_by(name=group_name)
- query = query.filter_by(domain_id=domain_id)
- try:
- group_ref = query.one()
- except sql.NotFound:
- raise exception.GroupNotFound(group_id=group_name)
- return group_ref.to_dict()
+ with sql.session_for_read() as session:
+ query = session.query(Group)
+ query = query.filter_by(name=group_name)
+ query = query.filter_by(domain_id=domain_id)
+ try:
+ group_ref = query.one()
+ except sql.NotFound:
+ raise exception.GroupNotFound(group_id=group_name)
+ return group_ref.to_dict()
@sql.handle_conflicts(conflict_type='group')
def update_group(self, group_id, group):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
ref = self._get_group(session, group_id)
old_dict = ref.to_dict()
for k in group:
@@ -290,12 +389,10 @@ class Identity(identity.IdentityDriverV8):
if attr != 'id':
setattr(ref, attr, getattr(new_group, attr))
ref.extra = new_group.extra
- return ref.to_dict()
+ return ref.to_dict()
def delete_group(self, group_id):
- session = sql.get_session()
-
- with session.begin():
+ with sql.session_for_write() as session:
ref = self._get_group(session, group_id)
q = session.query(UserGroupMembership)