summaryrefslogtreecommitdiffstats
path: root/lib/thrift/TTornado.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/thrift/TTornado.py')
-rw-r--r--lib/thrift/TTornado.py182
1 files changed, 182 insertions, 0 deletions
diff --git a/lib/thrift/TTornado.py b/lib/thrift/TTornado.py
new file mode 100644
index 0000000..ef3e0f2
--- /dev/null
+++ b/lib/thrift/TTornado.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from __future__ import absolute_import
+import socket
+import struct
+
+import logging
+logger = logging.getLogger(__name__)
+
+from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
+
+from io import BytesIO
+from collections import deque
+from contextlib import contextmanager
+from tornado import gen, iostream, ioloop, tcpserver, concurrent
+
+__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
+
+
+class _Lock(object):
+ def __init__(self):
+ self._waiters = deque()
+
+ def acquired(self):
+ return len(self._waiters) > 0
+
+ @gen.coroutine
+ def acquire(self):
+ blocker = self._waiters[-1] if self.acquired() else None
+ future = concurrent.Future()
+ self._waiters.append(future)
+ if blocker:
+ yield blocker
+
+ raise gen.Return(self._lock_context())
+
+ def release(self):
+ assert self.acquired(), 'Lock not aquired'
+ future = self._waiters.popleft()
+ future.set_result(None)
+
+ @contextmanager
+ def _lock_context(self):
+ try:
+ yield
+ finally:
+ self.release()
+
+
+class TTornadoStreamTransport(TTransportBase):
+ """a framed, buffered transport over a Tornado stream"""
+ def __init__(self, host, port, stream=None, io_loop=None):
+ self.host = host
+ self.port = port
+ self.io_loop = io_loop or ioloop.IOLoop.current()
+ self.__wbuf = BytesIO()
+ self._read_lock = _Lock()
+
+ # servers provide a ready-to-go stream
+ self.stream = stream
+
+ def with_timeout(self, timeout, future):
+ return gen.with_timeout(timeout, future, self.io_loop)
+
+ @gen.coroutine
+ def open(self, timeout=None):
+ logger.debug('socket connecting')
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ self.stream = iostream.IOStream(sock)
+
+ try:
+ connect = self.stream.connect((self.host, self.port))
+ if timeout is not None:
+ yield self.with_timeout(timeout, connect)
+ else:
+ yield connect
+ except (socket.error, IOError, ioloop.TimeoutError) as e:
+ message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
+ raise TTransportException(
+ type=TTransportException.NOT_OPEN,
+ message=message)
+
+ raise gen.Return(self)
+
+ def set_close_callback(self, callback):
+ """
+ Should be called only after open() returns
+ """
+ self.stream.set_close_callback(callback)
+
+ def close(self):
+ # don't raise if we intend to close
+ self.stream.set_close_callback(None)
+ self.stream.close()
+
+ def read(self, _):
+ # The generated code for Tornado shouldn't do individual reads -- only
+ # frames at a time
+ assert False, "you're doing it wrong"
+
+ @contextmanager
+ def io_exception_context(self):
+ try:
+ yield
+ except (socket.error, IOError) as e:
+ raise TTransportException(
+ type=TTransportException.END_OF_FILE,
+ message=str(e))
+ except iostream.StreamBufferFullError as e:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message=str(e))
+
+ @gen.coroutine
+ def readFrame(self):
+ # IOStream processes reads one at a time
+ with (yield self._read_lock.acquire()):
+ with self.io_exception_context():
+ frame_header = yield self.stream.read_bytes(4)
+ if len(frame_header) == 0:
+ raise iostream.StreamClosedError('Read zero bytes from stream')
+ frame_length, = struct.unpack('!i', frame_header)
+ frame = yield self.stream.read_bytes(frame_length)
+ raise gen.Return(frame)
+
+ def write(self, buf):
+ self.__wbuf.write(buf)
+
+ def flush(self):
+ frame = self.__wbuf.getvalue()
+ # reset wbuf before write/flush to preserve state on underlying failure
+ frame_length = struct.pack('!i', len(frame))
+ self.__wbuf = BytesIO()
+ with self.io_exception_context():
+ return self.stream.write(frame_length + frame)
+
+
+class TTornadoServer(tcpserver.TCPServer):
+ def __init__(self, processor, iprot_factory, oprot_factory=None,
+ *args, **kwargs):
+ super(TTornadoServer, self).__init__(*args, **kwargs)
+
+ self._processor = processor
+ self._iprot_factory = iprot_factory
+ self._oprot_factory = (oprot_factory if oprot_factory is not None
+ else iprot_factory)
+
+ @gen.coroutine
+ def handle_stream(self, stream, address):
+ host, port = address
+ trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
+ io_loop=self.io_loop)
+ oprot = self._oprot_factory.getProtocol(trans)
+
+ try:
+ while not trans.stream.closed():
+ frame = yield trans.readFrame()
+ tr = TMemoryBuffer(frame)
+ iprot = self._iprot_factory.getProtocol(tr)
+ yield self._processor.process(iprot, oprot)
+ except Exception:
+ logger.exception('thrift exception in handle_stream')
+ trans.close()
+
+ logger.info('client disconnected %s:%d', host, port)