# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # from __future__ import absolute_import import socket import struct import logging logger = logging.getLogger(__name__) from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer from io import BytesIO from collections import deque from contextlib import contextmanager from tornado import gen, iostream, ioloop, tcpserver, concurrent __all__ = ['TTornadoServer', 'TTornadoStreamTransport'] class _Lock(object): def __init__(self): self._waiters = deque() def acquired(self): return len(self._waiters) > 0 @gen.coroutine def acquire(self): blocker = self._waiters[-1] if self.acquired() else None future = concurrent.Future() self._waiters.append(future) if blocker: yield blocker raise gen.Return(self._lock_context()) def release(self): assert self.acquired(), 'Lock not aquired' future = self._waiters.popleft() future.set_result(None) @contextmanager def _lock_context(self): try: yield finally: self.release() class TTornadoStreamTransport(TTransportBase): """a framed, buffered transport over a Tornado stream""" def __init__(self, host, port, stream=None, io_loop=None): self.host = host self.port = port self.io_loop = io_loop or ioloop.IOLoop.current() self.__wbuf = BytesIO() self._read_lock = _Lock() # servers provide a ready-to-go stream self.stream = stream def with_timeout(self, timeout, future): return gen.with_timeout(timeout, future, self.io_loop) @gen.coroutine def open(self, timeout=None): logger.debug('socket connecting') sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) self.stream = iostream.IOStream(sock) try: connect = self.stream.connect((self.host, self.port)) if timeout is not None: yield self.with_timeout(timeout, connect) else: yield connect except (socket.error, IOError, ioloop.TimeoutError) as e: message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e) raise TTransportException( type=TTransportException.NOT_OPEN, message=message) raise gen.Return(self) def set_close_callback(self, callback): """ Should be called only after open() returns """ self.stream.set_close_callback(callback) def close(self): # don't raise if we intend to close self.stream.set_close_callback(None) self.stream.close() def read(self, _): # The generated code for Tornado shouldn't do individual reads -- only # frames at a time assert False, "you're doing it wrong" @contextmanager def io_exception_context(self): try: yield except (socket.error, IOError) as e: raise TTransportException( type=TTransportException.END_OF_FILE, message=str(e)) except iostream.StreamBufferFullError as e: raise TTransportException( type=TTransportException.UNKNOWN, message=str(e)) @gen.coroutine def readFrame(self): # IOStream processes reads one at a time with (yield self._read_lock.acquire()): with self.io_exception_context(): frame_header = yield self.stream.read_bytes(4) if len(frame_header) == 0: raise iostream.StreamClosedError('Read zero bytes from stream') frame_length, = struct.unpack('!i', frame_header) frame = yield self.stream.read_bytes(frame_length) raise gen.Return(frame) def write(self, buf): self.__wbuf.write(buf) def flush(self): frame = self.__wbuf.getvalue() # reset wbuf before write/flush to preserve state on underlying failure frame_length = struct.pack('!i', len(frame)) self.__wbuf = BytesIO() with self.io_exception_context(): return self.stream.write(frame_length + frame) class TTornadoServer(tcpserver.TCPServer): def __init__(self, processor, iprot_factory, oprot_factory=None, *args, **kwargs): super(TTornadoServer, self).__init__(*args, **kwargs) self._processor = processor self._iprot_factory = iprot_factory self._oprot_factory = (oprot_factory if oprot_factory is not None else iprot_factory) @gen.coroutine def handle_stream(self, stream, address): host, port = address trans = TTornadoStreamTransport(host=host, port=port, stream=stream, io_loop=self.io_loop) oprot = self._oprot_factory.getProtocol(trans) try: while not trans.stream.closed(): frame = yield trans.readFrame() tr = TMemoryBuffer(frame) iprot = self._iprot_factory.getProtocol(tr) yield self._processor.process(iprot, oprot) except Exception: logger.exception('thrift exception in handle_stream') trans.close() logger.info('client disconnected %s:%d', host, port)