path: root/keystone-moon/keystone/token/persistence/backends/sql.py
# Copyright 2012 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import copy
import functools

from oslo_config import cfg
from oslo_log import log
from oslo_utils import timeutils

from keystone.common import sql
from keystone import exception
from keystone.i18n import _LI
from keystone import token
from keystone.token import provider


CONF = cfg.CONF
LOG = log.getLogger(__name__)


class TokenModel(sql.ModelBase, sql.DictBase):
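    """Soft-deleted token records.

    A revoked token has `valid` set to False but its row is kept until
    flush_expired_tokens() removes it after expiry.
    """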
    __tablename__ = 'token'
    attributes = ['id', 'expires', 'user_id', 'trust_id']
    id = sql.Column(sql.String(64), primary_key=True)
    expires = sql.Column(sql.DateTime(), default=None)
    extra = sql.Column(sql.JsonBlob())
    valid = sql.Column(sql.Boolean(), default=True, nullable=False)
    user_id = sql.Column(sql.String(64))
    trust_id = sql.Column(sql.String(64))
    __table_args__ = (
        sql.Index('ix_token_expires', 'expires'),
        sql.Index('ix_token_expires_valid', 'expires', 'valid'),
        sql.Index('ix_token_user_id', 'user_id'),
        sql.Index('ix_token_trust_id', 'trust_id')
    )


def _expiry_range_batched(session, upper_bound_func, batch_size):
    """Returns the stop point of the next batch for expiration.

    Return the timestamp of the next token that is `batch_size` rows from
    being the oldest expired token.
    """

    # This expiry strategy splits the tokens into roughly equal sized batches
    # to be deleted.  It does this by finding the timestamp of a token
    # `batch_size` rows from the oldest token and yielding that to the caller.
    # It's expected that the caller will then delete all rows with a timestamp
    # equal to or older than the one yielded.  This may delete slightly more
    # tokens than the batch_size, but that should be ok in almost all cases.
    LOG.debug('Token expiration batch size: %d', batch_size)
    query = session.query(TokenModel.expires)
    query = query.filter(TokenModel.expires < upper_bound_func())
    query = query.order_by(TokenModel.expires)
    query = query.offset(batch_size - 1)
    query = query.limit(1)
    while True:
        try:
            next_expiration = query.one()[0]
        except sql.NotFound:
            # There are fewer than `batch_size` rows remaining, so fall
            # through and yield the upper bound so the caller deletes the rest.
            break
        yield next_expiration
    yield upper_bound_func()


def _expiry_range_all(session, upper_bound_func):
    """Expires all tokens in one pass."""

    yield upper_bound_func()


class Token(token.persistence.Driver):
    # Public interface
    def get_token(self, token_id):
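        """Return the data for token_id.

        Raises TokenNotFound if the token does not exist or has been
        invalidated.
        """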
        if token_id is None:
            raise exception.TokenNotFound(token_id=token_id)
        session = sql.get_session()
        token_ref = session.query(TokenModel).get(token_id)
        if not token_ref or not token_ref.valid:
            raise exception.TokenNotFound(token_id=token_id)
        return token_ref.to_dict()

    def create_token(self, token_id, data):
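        """Store a new token record, defaulting 'expires' and 'user_id'."""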
        data_copy = copy.deepcopy(data)
        if not data_copy.get('expires'):
            data_copy['expires'] = provider.default_expire_time()
        if not data_copy.get('user_id'):
            data_copy['user_id'] = data_copy['user']['id']

        token_ref = TokenModel.from_dict(data_copy)
        token_ref.valid = True
        session = sql.get_session()
        with session.begin():
            session.add(token_ref)
        return token_ref.to_dict()

    def delete_token(self, token_id):
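        """Invalidate a token by marking it as no longer valid.

        The row is kept so that the token appears in list_revoked_tokens();
        flush_expired_tokens() purges it once it has expired.
        """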
        session = sql.get_session()
        with session.begin():
            token_ref = session.query(TokenModel).get(token_id)
            if not token_ref or not token_ref.valid:
                raise exception.TokenNotFound(token_id=token_id)
            token_ref.valid = False

    def delete_tokens(self, user_id, tenant_id=None, trust_id=None,
                      consumer_id=None):
        """Deletes all tokens in one session

        The user_id will be ignored if the trust_id is specified. user_id
        will always be specified.
        If using a trust, the token's user_id is set to the trustee's user ID
        or the trustor's user ID, so will use trust_id to query the tokens.

        """
        session = sql.get_session()
        with session.begin():
            now = timeutils.utcnow()
            query = session.query(TokenModel)
            query = query.filter_by(valid=True)
            query = query.filter(TokenModel.expires > now)
            if trust_id:
                query = query.filter(TokenModel.trust_id == trust_id)
            else:
                query = query.filter(TokenModel.user_id == user_id)

            for token_ref in query.all():
                if tenant_id:
                    token_ref_dict = token_ref.to_dict()
                    if not self._tenant_matches(tenant_id, token_ref_dict):
                        continue
                if consumer_id:
                    token_ref_dict = token_ref.to_dict()
                    if not self._consumer_matches(consumer_id, token_ref_dict):
                        continue

                token_ref.valid = False

    def _tenant_matches(self, tenant_id, token_ref_dict):
        return ((tenant_id is None) or
                (token_ref_dict.get('tenant') and
                 token_ref_dict['tenant'].get('id') == tenant_id))

    def _consumer_matches(self, consumer_id, ref):
        if consumer_id is None:
            return True
        else:
            try:
                oauth = ref['token_data']['token'].get('OS-OAUTH1', {})
                return oauth and oauth['consumer_id'] == consumer_id
            except KeyError:
                return False

    def _list_tokens_for_trust(self, trust_id):
        session = sql.get_session()
        tokens = []
        now = timeutils.utcnow()
        query = session.query(TokenModel)
        query = query.filter(TokenModel.expires > now)
        query = query.filter(TokenModel.trust_id == trust_id)

        token_references = query.filter_by(valid=True)
        for token_ref in token_references:
            token_ref_dict = token_ref.to_dict()
            tokens.append(token_ref_dict['id'])
        return tokens

    def _list_tokens_for_user(self, user_id, tenant_id=None):
        session = sql.get_session()
        tokens = []
        now = timeutils.utcnow()
        query = session.query(TokenModel)
        query = query.filter(TokenModel.expires > now)
        query = query.filter(TokenModel.user_id == user_id)

        token_references = query.filter_by(valid=True)
        for token_ref in token_references:
            token_ref_dict = token_ref.to_dict()
            if self._tenant_matches(tenant_id, token_ref_dict):
                tokens.append(token_ref_dict['id'])
        return tokens

    def _list_tokens_for_consumer(self, user_id, consumer_id):
        tokens = []
        session = sql.get_session()
        with session.begin():
            now = timeutils.utcnow()
            query = session.query(TokenModel)
            query = query.filter(TokenModel.expires > now)
            query = query.filter(TokenModel.user_id == user_id)
            token_references = query.filter_by(valid=True)

            for token_ref in token_references:
                token_ref_dict = token_ref.to_dict()
                if self._consumer_matches(consumer_id, token_ref_dict):
                    tokens.append(token_ref_dict['id'])
        return tokens

    def _list_tokens(self, user_id, tenant_id=None, trust_id=None,
                     consumer_id=None):
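        """List token IDs for a user, trust, or OAuth consumer.

        Returns an empty list when revocation by token ID is disabled.
        """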
        if not CONF.token.revoke_by_id:
            return []
        if trust_id:
            return self._list_tokens_for_trust(trust_id)
        if consumer_id:
            return self._list_tokens_for_consumer(user_id, consumer_id)
        else:
            return self._list_tokens_for_user(user_id, tenant_id)

    def list_revoked_tokens(self):
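        """List the id and expiry of every invalidated, unexpired token."""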
        session = sql.get_session()
        tokens = []
        now = timeutils.utcnow()
        query = session.query(TokenModel.id, TokenModel.expires)
        query = query.filter(TokenModel.expires > now)
        token_references = query.filter_by(valid=False)
        for token_ref in token_references:
            record = {
                'id': token_ref[0],
                'expires': token_ref[1],
            }
            tokens.append(record)
        return tokens

    def _expiry_range_strategy(self, dialect):
        """Choose a token range expiration strategy

        Based on the DB dialect, select an expiry range callable that is
        appropriate.
        """

        # DB2 and MySQL can both benefit from a batched strategy.  On DB2 the
        # transaction log can fill up, and on MySQL with Galera, large
        # transactions can exceed the maximum write-set size.
        if dialect == 'ibm_db_sa':
            # Limit of 100 is known to not fill a transaction log
            # of default maximum size while not significantly
            # impacting the performance of large token purges on
            # systems where the maximum transaction log size has
            # been increased beyond the default.
            return functools.partial(_expiry_range_batched,
                                     batch_size=100)
        elif dialect == 'mysql':
            # We want somewhat more than 100, since Galera replication delay is
            # at least RTT*2.  This can be a significant amount of time if
            # doing replication across a WAN.
            return functools.partial(_expiry_range_batched,
                                     batch_size=1000)
        return _expiry_range_all

    def flush_expired_tokens(self):
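        """Purge expired token rows, batching deletes where the dialect needs it."""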
        session = sql.get_session()
        dialect = session.bind.dialect.name
        expiry_range_func = self._expiry_range_strategy(dialect)
        query = session.query(TokenModel.expires)
        total_removed = 0
        upper_bound_func = timeutils.utcnow
        for expiry_time in expiry_range_func(session, upper_bound_func):
            delete_query = query.filter(TokenModel.expires <=
                                        expiry_time)
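            # synchronize_session=False lets SQLAlchemy issue the DELETE
            # without reconciling it against objects already loaded into the
            # session; nothing from this query is kept in memory, so that is
            # safe here.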
            row_count = delete_query.delete(synchronize_session=False)
            total_removed += row_count
            LOG.debug('Removed %d total expired tokens', total_removed)

        session.flush()
        LOG.info(_LI('Total expired tokens removed: %d'), total_removed)
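

# Usage sketch (not part of the driver), assuming a Kilo-era keystone
# deployment with the default option names: this SQL backend is selected via
# keystone.conf, and expired rows are purged with the keystone-manage CLI.
#
#     [token]
#     driver = keystone.token.persistence.backends.sql.Token
#
#     $ keystone-manage token_flush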