aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/token/persistence/backends/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/token/persistence/backends/sql.py')
-rw-r--r--keystone-moon/keystone/token/persistence/backends/sql.py279
1 files changed, 279 insertions, 0 deletions
diff --git a/keystone-moon/keystone/token/persistence/backends/sql.py b/keystone-moon/keystone/token/persistence/backends/sql.py
new file mode 100644
index 00000000..fc70fb92
--- /dev/null
+++ b/keystone-moon/keystone/token/persistence/backends/sql.py
@@ -0,0 +1,279 @@
+# Copyright 2012 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import copy
+import functools
+
+from oslo_config import cfg
+from oslo_log import log
+from oslo_utils import timeutils
+
+from keystone.common import sql
+from keystone import exception
+from keystone.i18n import _LI
+from keystone import token
+from keystone.token import provider
+
+
+CONF = cfg.CONF
+LOG = log.getLogger(__name__)
+
+
+class TokenModel(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'token'
+ attributes = ['id', 'expires', 'user_id', 'trust_id']
+ id = sql.Column(sql.String(64), primary_key=True)
+ expires = sql.Column(sql.DateTime(), default=None)
+ extra = sql.Column(sql.JsonBlob())
+ valid = sql.Column(sql.Boolean(), default=True, nullable=False)
+ user_id = sql.Column(sql.String(64))
+ trust_id = sql.Column(sql.String(64))
+ __table_args__ = (
+ sql.Index('ix_token_expires', 'expires'),
+ sql.Index('ix_token_expires_valid', 'expires', 'valid'),
+ sql.Index('ix_token_user_id', 'user_id'),
+ sql.Index('ix_token_trust_id', 'trust_id')
+ )
+
+
+def _expiry_range_batched(session, upper_bound_func, batch_size):
+ """Returns the stop point of the next batch for expiration.
+
+ Return the timestamp of the next token that is `batch_size` rows from
+ being the oldest expired token.
+ """
+
+ # This expiry strategy splits the tokens into roughly equal sized batches
+ # to be deleted. It does this by finding the timestamp of a token
+ # `batch_size` rows from the oldest token and yielding that to the caller.
+ # It's expected that the caller will then delete all rows with a timestamp
+ # equal to or older than the one yielded. This may delete slightly more
+ # tokens than the batch_size, but that should be ok in almost all cases.
+ LOG.debug('Token expiration batch size: %d', batch_size)
+ query = session.query(TokenModel.expires)
+ query = query.filter(TokenModel.expires < upper_bound_func())
+ query = query.order_by(TokenModel.expires)
+ query = query.offset(batch_size - 1)
+ query = query.limit(1)
+ while True:
+ try:
+ next_expiration = query.one()[0]
+ except sql.NotFound:
+ # There are less than `batch_size` rows remaining, so fall
+ # through to the normal delete
+ break
+ yield next_expiration
+ yield upper_bound_func()
+
+
+def _expiry_range_all(session, upper_bound_func):
+ """Expires all tokens in one pass."""
+
+ yield upper_bound_func()
+
+
+class Token(token.persistence.Driver):
+ # Public interface
+ def get_token(self, token_id):
+ if token_id is None:
+ raise exception.TokenNotFound(token_id=token_id)
+ session = sql.get_session()
+ token_ref = session.query(TokenModel).get(token_id)
+ if not token_ref or not token_ref.valid:
+ raise exception.TokenNotFound(token_id=token_id)
+ return token_ref.to_dict()
+
+ def create_token(self, token_id, data):
+ data_copy = copy.deepcopy(data)
+ if not data_copy.get('expires'):
+ data_copy['expires'] = provider.default_expire_time()
+ if not data_copy.get('user_id'):
+ data_copy['user_id'] = data_copy['user']['id']
+
+ token_ref = TokenModel.from_dict(data_copy)
+ token_ref.valid = True
+ session = sql.get_session()
+ with session.begin():
+ session.add(token_ref)
+ return token_ref.to_dict()
+
+ def delete_token(self, token_id):
+ session = sql.get_session()
+ with session.begin():
+ token_ref = session.query(TokenModel).get(token_id)
+ if not token_ref or not token_ref.valid:
+ raise exception.TokenNotFound(token_id=token_id)
+ token_ref.valid = False
+
+ def delete_tokens(self, user_id, tenant_id=None, trust_id=None,
+ consumer_id=None):
+ """Deletes all tokens in one session
+
+ The user_id will be ignored if the trust_id is specified. user_id
+ will always be specified.
+ If using a trust, the token's user_id is set to the trustee's user ID
+ or the trustor's user ID, so will use trust_id to query the tokens.
+
+ """
+ session = sql.get_session()
+ with session.begin():
+ now = timeutils.utcnow()
+ query = session.query(TokenModel)
+ query = query.filter_by(valid=True)
+ query = query.filter(TokenModel.expires > now)
+ if trust_id:
+ query = query.filter(TokenModel.trust_id == trust_id)
+ else:
+ query = query.filter(TokenModel.user_id == user_id)
+
+ for token_ref in query.all():
+ if tenant_id:
+ token_ref_dict = token_ref.to_dict()
+ if not self._tenant_matches(tenant_id, token_ref_dict):
+ continue
+ if consumer_id:
+ token_ref_dict = token_ref.to_dict()
+ if not self._consumer_matches(consumer_id, token_ref_dict):
+ continue
+
+ token_ref.valid = False
+
+ def _tenant_matches(self, tenant_id, token_ref_dict):
+ return ((tenant_id is None) or
+ (token_ref_dict.get('tenant') and
+ token_ref_dict['tenant'].get('id') == tenant_id))
+
+ def _consumer_matches(self, consumer_id, ref):
+ if consumer_id is None:
+ return True
+ else:
+ try:
+ oauth = ref['token_data']['token'].get('OS-OAUTH1', {})
+ return oauth and oauth['consumer_id'] == consumer_id
+ except KeyError:
+ return False
+
+ def _list_tokens_for_trust(self, trust_id):
+ session = sql.get_session()
+ tokens = []
+ now = timeutils.utcnow()
+ query = session.query(TokenModel)
+ query = query.filter(TokenModel.expires > now)
+ query = query.filter(TokenModel.trust_id == trust_id)
+
+ token_references = query.filter_by(valid=True)
+ for token_ref in token_references:
+ token_ref_dict = token_ref.to_dict()
+ tokens.append(token_ref_dict['id'])
+ return tokens
+
+ def _list_tokens_for_user(self, user_id, tenant_id=None):
+ session = sql.get_session()
+ tokens = []
+ now = timeutils.utcnow()
+ query = session.query(TokenModel)
+ query = query.filter(TokenModel.expires > now)
+ query = query.filter(TokenModel.user_id == user_id)
+
+ token_references = query.filter_by(valid=True)
+ for token_ref in token_references:
+ token_ref_dict = token_ref.to_dict()
+ if self._tenant_matches(tenant_id, token_ref_dict):
+ tokens.append(token_ref['id'])
+ return tokens
+
+ def _list_tokens_for_consumer(self, user_id, consumer_id):
+ tokens = []
+ session = sql.get_session()
+ with session.begin():
+ now = timeutils.utcnow()
+ query = session.query(TokenModel)
+ query = query.filter(TokenModel.expires > now)
+ query = query.filter(TokenModel.user_id == user_id)
+ token_references = query.filter_by(valid=True)
+
+ for token_ref in token_references:
+ token_ref_dict = token_ref.to_dict()
+ if self._consumer_matches(consumer_id, token_ref_dict):
+ tokens.append(token_ref_dict['id'])
+ return tokens
+
+ def _list_tokens(self, user_id, tenant_id=None, trust_id=None,
+ consumer_id=None):
+ if not CONF.token.revoke_by_id:
+ return []
+ if trust_id:
+ return self._list_tokens_for_trust(trust_id)
+ if consumer_id:
+ return self._list_tokens_for_consumer(user_id, consumer_id)
+ else:
+ return self._list_tokens_for_user(user_id, tenant_id)
+
+ def list_revoked_tokens(self):
+ session = sql.get_session()
+ tokens = []
+ now = timeutils.utcnow()
+ query = session.query(TokenModel.id, TokenModel.expires)
+ query = query.filter(TokenModel.expires > now)
+ token_references = query.filter_by(valid=False)
+ for token_ref in token_references:
+ record = {
+ 'id': token_ref[0],
+ 'expires': token_ref[1],
+ }
+ tokens.append(record)
+ return tokens
+
+ def _expiry_range_strategy(self, dialect):
+ """Choose a token range expiration strategy
+
+ Based on the DB dialect, select an expiry range callable that is
+ appropriate.
+ """
+
+ # DB2 and MySQL can both benefit from a batched strategy. On DB2 the
+ # transaction log can fill up and on MySQL w/Galera, large
+ # transactions can exceed the maximum write set size.
+ if dialect == 'ibm_db_sa':
+ # Limit of 100 is known to not fill a transaction log
+ # of default maximum size while not significantly
+ # impacting the performance of large token purges on
+ # systems where the maximum transaction log size has
+ # been increased beyond the default.
+ return functools.partial(_expiry_range_batched,
+ batch_size=100)
+ elif dialect == 'mysql':
+ # We want somewhat more than 100, since Galera replication delay is
+ # at least RTT*2. This can be a significant amount of time if
+ # doing replication across a WAN.
+ return functools.partial(_expiry_range_batched,
+ batch_size=1000)
+ return _expiry_range_all
+
+ def flush_expired_tokens(self):
+ session = sql.get_session()
+ dialect = session.bind.dialect.name
+ expiry_range_func = self._expiry_range_strategy(dialect)
+ query = session.query(TokenModel.expires)
+ total_removed = 0
+ upper_bound_func = timeutils.utcnow
+ for expiry_time in expiry_range_func(session, upper_bound_func):
+ delete_query = query.filter(TokenModel.expires <=
+ expiry_time)
+ row_count = delete_query.delete(synchronize_session=False)
+ total_removed += row_count
+ LOG.debug('Removed %d total expired tokens', total_removed)
+
+ session.flush()
+ LOG.info(_LI('Total expired tokens removed: %d'), total_removed)