aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/token/persistence/backends/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/token/persistence/backends/sql.py')
-rw-r--r--keystone-moon/keystone/token/persistence/backends/sql.py141
1 files changed, 72 insertions, 69 deletions
diff --git a/keystone-moon/keystone/token/persistence/backends/sql.py b/keystone-moon/keystone/token/persistence/backends/sql.py
index 6fc1d223..4b3439a1 100644
--- a/keystone-moon/keystone/token/persistence/backends/sql.py
+++ b/keystone-moon/keystone/token/persistence/backends/sql.py
@@ -53,7 +53,6 @@ def _expiry_range_batched(session, upper_bound_func, batch_size):
Return the timestamp of the next token that is `batch_size` rows from
being the oldest expired token.
"""
-
# This expiry strategy splits the tokens into roughly equal sized batches
# to be deleted. It does this by finding the timestamp of a token
# `batch_size` rows from the oldest token and yielding that to the caller.
@@ -79,7 +78,6 @@ def _expiry_range_batched(session, upper_bound_func, batch_size):
def _expiry_range_all(session, upper_bound_func):
"""Expires all tokens in one pass."""
-
yield upper_bound_func()
@@ -88,11 +86,11 @@ class Token(token.persistence.TokenDriverV8):
def get_token(self, token_id):
if token_id is None:
raise exception.TokenNotFound(token_id=token_id)
- session = sql.get_session()
- token_ref = session.query(TokenModel).get(token_id)
- if not token_ref or not token_ref.valid:
- raise exception.TokenNotFound(token_id=token_id)
- return token_ref.to_dict()
+ with sql.session_for_read() as session:
+ token_ref = session.query(TokenModel).get(token_id)
+ if not token_ref or not token_ref.valid:
+ raise exception.TokenNotFound(token_id=token_id)
+ return token_ref.to_dict()
def create_token(self, token_id, data):
data_copy = copy.deepcopy(data)
@@ -103,14 +101,12 @@ class Token(token.persistence.TokenDriverV8):
token_ref = TokenModel.from_dict(data_copy)
token_ref.valid = True
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
session.add(token_ref)
return token_ref.to_dict()
def delete_token(self, token_id):
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
token_ref = session.query(TokenModel).get(token_id)
if not token_ref or not token_ref.valid:
raise exception.TokenNotFound(token_id=token_id)
@@ -126,9 +122,8 @@ class Token(token.persistence.TokenDriverV8):
or the trustor's user ID, so will use trust_id to query the tokens.
"""
- session = sql.get_session()
token_list = []
- with session.begin():
+ with sql.session_for_write() as session:
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter_by(valid=True)
@@ -169,38 +164,37 @@ class Token(token.persistence.TokenDriverV8):
return False
def _list_tokens_for_trust(self, trust_id):
- session = sql.get_session()
- tokens = []
- now = timeutils.utcnow()
- query = session.query(TokenModel)
- query = query.filter(TokenModel.expires > now)
- query = query.filter(TokenModel.trust_id == trust_id)
-
- token_references = query.filter_by(valid=True)
- for token_ref in token_references:
- token_ref_dict = token_ref.to_dict()
- tokens.append(token_ref_dict['id'])
- return tokens
+ with sql.session_for_read() as session:
+ tokens = []
+ now = timeutils.utcnow()
+ query = session.query(TokenModel)
+ query = query.filter(TokenModel.expires > now)
+ query = query.filter(TokenModel.trust_id == trust_id)
+
+ token_references = query.filter_by(valid=True)
+ for token_ref in token_references:
+ token_ref_dict = token_ref.to_dict()
+ tokens.append(token_ref_dict['id'])
+ return tokens
def _list_tokens_for_user(self, user_id, tenant_id=None):
- session = sql.get_session()
- tokens = []
- now = timeutils.utcnow()
- query = session.query(TokenModel)
- query = query.filter(TokenModel.expires > now)
- query = query.filter(TokenModel.user_id == user_id)
-
- token_references = query.filter_by(valid=True)
- for token_ref in token_references:
- token_ref_dict = token_ref.to_dict()
- if self._tenant_matches(tenant_id, token_ref_dict):
- tokens.append(token_ref['id'])
- return tokens
+ with sql.session_for_read() as session:
+ tokens = []
+ now = timeutils.utcnow()
+ query = session.query(TokenModel)
+ query = query.filter(TokenModel.expires > now)
+ query = query.filter(TokenModel.user_id == user_id)
+
+ token_references = query.filter_by(valid=True)
+ for token_ref in token_references:
+ token_ref_dict = token_ref.to_dict()
+ if self._tenant_matches(tenant_id, token_ref_dict):
+ tokens.append(token_ref['id'])
+ return tokens
def _list_tokens_for_consumer(self, user_id, consumer_id):
tokens = []
- session = sql.get_session()
- with session.begin():
+ with sql.session_for_write() as session:
now = timeutils.utcnow()
query = session.query(TokenModel)
query = query.filter(TokenModel.expires > now)
@@ -225,19 +219,29 @@ class Token(token.persistence.TokenDriverV8):
return self._list_tokens_for_user(user_id, tenant_id)
def list_revoked_tokens(self):
- session = sql.get_session()
- tokens = []
- now = timeutils.utcnow()
- query = session.query(TokenModel.id, TokenModel.expires)
- query = query.filter(TokenModel.expires > now)
- token_references = query.filter_by(valid=False)
- for token_ref in token_references:
- record = {
- 'id': token_ref[0],
- 'expires': token_ref[1],
- }
- tokens.append(record)
- return tokens
+ with sql.session_for_read() as session:
+ tokens = []
+ now = timeutils.utcnow()
+ query = session.query(TokenModel.id, TokenModel.expires,
+ TokenModel.extra)
+ query = query.filter(TokenModel.expires > now)
+ token_references = query.filter_by(valid=False)
+ for token_ref in token_references:
+ token_data = token_ref[2]['token_data']
+ if 'access' in token_data:
+ # It's a v2 token.
+ audit_ids = token_data['access']['token']['audit_ids']
+ else:
+ # It's a v3 token.
+ audit_ids = token_data['token']['audit_ids']
+
+ record = {
+ 'id': token_ref[0],
+ 'expires': token_ref[1],
+ 'audit_id': audit_ids[0],
+ }
+ tokens.append(record)
+ return tokens
def _expiry_range_strategy(self, dialect):
"""Choose a token range expiration strategy
@@ -245,7 +249,6 @@ class Token(token.persistence.TokenDriverV8):
Based on the DB dialect, select an expiry range callable that is
appropriate.
"""
-
# DB2 and MySQL can both benefit from a batched strategy. On DB2 the
# transaction log can fill up and on MySQL w/Galera, large
# transactions can exceed the maximum write set size.
@@ -266,18 +269,18 @@ class Token(token.persistence.TokenDriverV8):
return _expiry_range_all
def flush_expired_tokens(self):
- session = sql.get_session()
- dialect = session.bind.dialect.name
- expiry_range_func = self._expiry_range_strategy(dialect)
- query = session.query(TokenModel.expires)
- total_removed = 0
- upper_bound_func = timeutils.utcnow
- for expiry_time in expiry_range_func(session, upper_bound_func):
- delete_query = query.filter(TokenModel.expires <=
- expiry_time)
- row_count = delete_query.delete(synchronize_session=False)
- total_removed += row_count
- LOG.debug('Removed %d total expired tokens', total_removed)
-
- session.flush()
- LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
+ with sql.session_for_write() as session:
+ dialect = session.bind.dialect.name
+ expiry_range_func = self._expiry_range_strategy(dialect)
+ query = session.query(TokenModel.expires)
+ total_removed = 0
+ upper_bound_func = timeutils.utcnow
+ for expiry_time in expiry_range_func(session, upper_bound_func):
+ delete_query = query.filter(TokenModel.expires <=
+ expiry_time)
+ row_count = delete_query.delete(synchronize_session=False)
+ total_removed += row_count
+ LOG.debug('Removed %d total expired tokens', total_removed)
+
+ session.flush()
+ LOG.info(_LI('Total expired tokens removed: %d'), total_removed)