diff options
Diffstat (limited to 'keystone-moon/keystone/identity/backends/sql.py')
-rw-r--r-- | keystone-moon/keystone/identity/backends/sql.py | 345 |
1 files changed, 221 insertions, 124 deletions
diff --git a/keystone-moon/keystone/identity/backends/sql.py b/keystone-moon/keystone/identity/backends/sql.py index d37240eb..5680a8a2 100644 --- a/keystone-moon/keystone/identity/backends/sql.py +++ b/keystone-moon/keystone/identity/backends/sql.py @@ -12,8 +12,11 @@ # License for the specific language governing permissions and limitations # under the License. -from oslo_config import cfg +import sqlalchemy +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy import orm +from keystone.common import driver_hints from keystone.common import sql from keystone.common import utils from keystone import exception @@ -21,23 +24,84 @@ from keystone.i18n import _ from keystone import identity -CONF = cfg.CONF - - class User(sql.ModelBase, sql.DictBase): __tablename__ = 'user' attributes = ['id', 'name', 'domain_id', 'password', 'enabled', 'default_project_id'] id = sql.Column(sql.String(64), primary_key=True) - name = sql.Column(sql.String(255), nullable=False) - domain_id = sql.Column(sql.String(64), nullable=False) - password = sql.Column(sql.String(128)) enabled = sql.Column(sql.Boolean) extra = sql.Column(sql.JsonBlob()) default_project_id = sql.Column(sql.String(64)) - # Unique constraint across two columns to create the separation - # rather than just only 'name' being unique - __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + local_user = orm.relationship('LocalUser', uselist=False, + single_parent=True, lazy='subquery', + cascade='all,delete-orphan', backref='user') + federated_users = orm.relationship('FederatedUser', + single_parent=True, + lazy='subquery', + cascade='all,delete-orphan', + backref='user') + + # name property + @hybrid_property + def name(self): + if self.local_user: + return self.local_user.name + elif self.federated_users: + return self.federated_users[0].display_name + else: + return None + + @name.setter + def name(self, value): + if not self.local_user: + self.local_user = LocalUser() + self.local_user.name = value + + @name.expression + def name(cls): + return LocalUser.name + + # password property + @hybrid_property + def password(self): + if self.local_user and self.local_user.passwords: + return self.local_user.passwords[0].password + else: + return None + + @password.setter + def password(self, value): + if not value: + if self.local_user and self.local_user.passwords: + self.local_user.passwords = [] + else: + if not self.local_user: + self.local_user = LocalUser() + if not self.local_user.passwords: + self.local_user.passwords.append(Password()) + self.local_user.passwords[0].password = value + + @password.expression + def password(cls): + return Password.password + + # domain_id property + @hybrid_property + def domain_id(self): + if self.local_user: + return self.local_user.domain_id + else: + return None + + @domain_id.setter + def domain_id(self, value): + if not self.local_user: + self.local_user = LocalUser() + self.local_user.domain_id = value + + @domain_id.expression + def domain_id(cls): + return LocalUser.domain_id def to_dict(self, include_extra_dict=False): d = super(User, self).to_dict(include_extra_dict=include_extra_dict) @@ -46,6 +110,49 @@ class User(sql.ModelBase, sql.DictBase): return d +class LocalUser(sql.ModelBase, sql.DictBase): + __tablename__ = 'local_user' + attributes = ['id', 'user_id', 'domain_id', 'name'] + id = sql.Column(sql.Integer, primary_key=True) + user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', + ondelete='CASCADE'), unique=True) + domain_id = sql.Column(sql.String(64), nullable=False) + name = sql.Column(sql.String(255), nullable=False) + passwords = orm.relationship('Password', single_parent=True, + cascade='all,delete-orphan', + backref='local_user') + __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + + +class Password(sql.ModelBase, sql.DictBase): + __tablename__ = 'password' + attributes = ['id', 'local_user_id', 'password'] + id = sql.Column(sql.Integer, primary_key=True) + local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id', + ondelete='CASCADE')) + password = sql.Column(sql.String(128)) + + +class FederatedUser(sql.ModelBase, sql.ModelDictMixin): + __tablename__ = 'federated_user' + attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id', + 'display_name'] + id = sql.Column(sql.Integer, primary_key=True) + user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', + ondelete='CASCADE')) + idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id', + ondelete='CASCADE')) + protocol_id = sql.Column(sql.String(64), nullable=False) + unique_id = sql.Column(sql.String(255), nullable=False) + display_name = sql.Column(sql.String(255), nullable=True) + __table_args__ = ( + sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'), + sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'], + ['federation_protocol.id', + 'federation_protocol.idp_id']) + ) + + class Group(sql.ModelBase, sql.DictBase): __tablename__ = 'group' attributes = ['id', 'name', 'domain_id', 'description'] @@ -56,11 +163,12 @@ class Group(sql.ModelBase, sql.DictBase): extra = sql.Column(sql.JsonBlob()) # Unique constraint across two columns to create the separation # rather than just only 'name' being unique - __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),) class UserGroupMembership(sql.ModelBase, sql.DictBase): """Group membership join table.""" + __tablename__ = 'user_group_membership' user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id'), @@ -74,11 +182,9 @@ class Identity(identity.IdentityDriverV8): # NOTE(henry-nash): Override the __init__() method so as to take a # config parameter to enable sql to be used as a domain-specific driver. def __init__(self, conf=None): + self.conf = conf super(Identity, self).__init__() - def default_assignment_driver(self): - return 'sql' - @property def is_sql(self): return True @@ -96,33 +202,32 @@ class Identity(identity.IdentityDriverV8): # Identity interface def authenticate(self, user_id, password): - session = sql.get_session() - user_ref = None - try: - user_ref = self._get_user(session, user_id) - except exception.UserNotFound: - raise AssertionError(_('Invalid user / password')) - if not self._check_password(password, user_ref): - raise AssertionError(_('Invalid user / password')) - return identity.filter_user(user_ref.to_dict()) + with sql.session_for_read() as session: + user_ref = None + try: + user_ref = self._get_user(session, user_id) + except exception.UserNotFound: + raise AssertionError(_('Invalid user / password')) + if not self._check_password(password, user_ref): + raise AssertionError(_('Invalid user / password')) + return identity.filter_user(user_ref.to_dict()) # user crud @sql.handle_conflicts(conflict_type='user') def create_user(self, user_id, user): user = utils.hash_user_password(user) - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: user_ref = User.from_dict(user) session.add(user_ref) - return identity.filter_user(user_ref.to_dict()) + return identity.filter_user(user_ref.to_dict()) - @sql.truncated + @driver_hints.truncated def list_users(self, hints): - session = sql.get_session() - query = session.query(User) - user_refs = sql.filter_limit_query(User, query, hints) - return [identity.filter_user(x.to_dict()) for x in user_refs] + with sql.session_for_read() as session: + query = session.query(User).outerjoin(LocalUser) + user_refs = sql.filter_limit_query(User, query, hints) + return [identity.filter_user(x.to_dict()) for x in user_refs] def _get_user(self, session, user_id): user_ref = session.query(User).get(user_id) @@ -131,25 +236,24 @@ class Identity(identity.IdentityDriverV8): return user_ref def get_user(self, user_id): - session = sql.get_session() - return identity.filter_user(self._get_user(session, user_id).to_dict()) + with sql.session_for_read() as session: + return identity.filter_user( + self._get_user(session, user_id).to_dict()) def get_user_by_name(self, user_name, domain_id): - session = sql.get_session() - query = session.query(User) - query = query.filter_by(name=user_name) - query = query.filter_by(domain_id=domain_id) - try: - user_ref = query.one() - except sql.NotFound: - raise exception.UserNotFound(user_id=user_name) - return identity.filter_user(user_ref.to_dict()) + with sql.session_for_read() as session: + query = session.query(User).join(LocalUser) + query = query.filter(sqlalchemy.and_(LocalUser.name == user_name, + LocalUser.domain_id == domain_id)) + try: + user_ref = query.one() + except sql.NotFound: + raise exception.UserNotFound(user_id=user_name) + return identity.filter_user(user_ref.to_dict()) @sql.handle_conflicts(conflict_type='user') def update_user(self, user_id, user): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: user_ref = self._get_user(session, user_id) old_user_dict = user_ref.to_dict() user = utils.hash_user_password(user) @@ -160,76 +264,74 @@ class Identity(identity.IdentityDriverV8): if attr != 'id': setattr(user_ref, attr, getattr(new_user, attr)) user_ref.extra = new_user.extra - return identity.filter_user(user_ref.to_dict(include_extra_dict=True)) + return identity.filter_user( + user_ref.to_dict(include_extra_dict=True)) def add_user_to_group(self, user_id, group_id): - session = sql.get_session() - self.get_group(group_id) - self.get_user(user_id) - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - rv = query.first() - if rv: - return - - with session.begin(): + with sql.session_for_write() as session: + self.get_group(group_id) + self.get_user(user_id) + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + rv = query.first() + if rv: + return + session.add(UserGroupMembership(user_id=user_id, group_id=group_id)) def check_user_in_group(self, user_id, group_id): - session = sql.get_session() - self.get_group(group_id) - self.get_user(user_id) - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - if not query.first(): - raise exception.NotFound(_("User '%(user_id)s' not found in" - " group '%(group_id)s'") % - {'user_id': user_id, - 'group_id': group_id}) + with sql.session_for_read() as session: + self.get_group(group_id) + self.get_user(user_id) + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + if not query.first(): + raise exception.NotFound(_("User '%(user_id)s' not found in" + " group '%(group_id)s'") % + {'user_id': user_id, + 'group_id': group_id}) def remove_user_from_group(self, user_id, group_id): - session = sql.get_session() # We don't check if user or group are still valid and let the remove # be tried anyway - in case this is some kind of clean-up operation - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - membership_ref = query.first() - if membership_ref is None: - # Check if the group and user exist to return descriptive - # exceptions. - self.get_group(group_id) - self.get_user(user_id) - raise exception.NotFound(_("User '%(user_id)s' not found in" - " group '%(group_id)s'") % - {'user_id': user_id, - 'group_id': group_id}) - with session.begin(): + with sql.session_for_write() as session: + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + membership_ref = query.first() + if membership_ref is None: + # Check if the group and user exist to return descriptive + # exceptions. + self.get_group(group_id) + self.get_user(user_id) + raise exception.NotFound(_("User '%(user_id)s' not found in" + " group '%(group_id)s'") % + {'user_id': user_id, + 'group_id': group_id}) session.delete(membership_ref) def list_groups_for_user(self, user_id, hints): - session = sql.get_session() - self.get_user(user_id) - query = session.query(Group).join(UserGroupMembership) - query = query.filter(UserGroupMembership.user_id == user_id) - query = sql.filter_limit_query(Group, query, hints) - return [g.to_dict() for g in query] + with sql.session_for_read() as session: + self.get_user(user_id) + query = session.query(Group).join(UserGroupMembership) + query = query.filter(UserGroupMembership.user_id == user_id) + query = sql.filter_limit_query(Group, query, hints) + return [g.to_dict() for g in query] def list_users_in_group(self, group_id, hints): - session = sql.get_session() - self.get_group(group_id) - query = session.query(User).join(UserGroupMembership) - query = query.filter(UserGroupMembership.group_id == group_id) - query = sql.filter_limit_query(User, query, hints) - return [identity.filter_user(u.to_dict()) for u in query] + with sql.session_for_read() as session: + self.get_group(group_id) + query = session.query(User).outerjoin(LocalUser) + query = query.join(UserGroupMembership) + query = query.filter(UserGroupMembership.group_id == group_id) + query = sql.filter_limit_query(User, query, hints) + return [identity.filter_user(u.to_dict()) for u in query] def delete_user(self, user_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_user(session, user_id) q = session.query(UserGroupMembership) @@ -242,18 +344,17 @@ class Identity(identity.IdentityDriverV8): @sql.handle_conflicts(conflict_type='group') def create_group(self, group_id, group): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = Group.from_dict(group) session.add(ref) - return ref.to_dict() + return ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_groups(self, hints): - session = sql.get_session() - query = session.query(Group) - refs = sql.filter_limit_query(Group, query, hints) - return [ref.to_dict() for ref in refs] + with sql.session_for_read() as session: + query = session.query(Group) + refs = sql.filter_limit_query(Group, query, hints) + return [ref.to_dict() for ref in refs] def _get_group(self, session, group_id): ref = session.query(Group).get(group_id) @@ -262,25 +363,23 @@ class Identity(identity.IdentityDriverV8): return ref def get_group(self, group_id): - session = sql.get_session() - return self._get_group(session, group_id).to_dict() + with sql.session_for_read() as session: + return self._get_group(session, group_id).to_dict() def get_group_by_name(self, group_name, domain_id): - session = sql.get_session() - query = session.query(Group) - query = query.filter_by(name=group_name) - query = query.filter_by(domain_id=domain_id) - try: - group_ref = query.one() - except sql.NotFound: - raise exception.GroupNotFound(group_id=group_name) - return group_ref.to_dict() + with sql.session_for_read() as session: + query = session.query(Group) + query = query.filter_by(name=group_name) + query = query.filter_by(domain_id=domain_id) + try: + group_ref = query.one() + except sql.NotFound: + raise exception.GroupNotFound(group_id=group_name) + return group_ref.to_dict() @sql.handle_conflicts(conflict_type='group') def update_group(self, group_id, group): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_group(session, group_id) old_dict = ref.to_dict() for k in group: @@ -290,12 +389,10 @@ class Identity(identity.IdentityDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_group, attr)) ref.extra = new_group.extra - return ref.to_dict() + return ref.to_dict() def delete_group(self, group_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_group(session, group_id) q = session.query(UserGroupMembership) |