summaryrefslogtreecommitdiffstats
path: root/lib/thrift/transport/TTwisted.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/thrift/transport/TTwisted.py')
-rw-r--r--lib/thrift/transport/TTwisted.py324
1 files changed, 324 insertions, 0 deletions
diff --git a/lib/thrift/transport/TTwisted.py b/lib/thrift/transport/TTwisted.py
new file mode 100644
index 0000000..29bbd4c
--- /dev/null
+++ b/lib/thrift/transport/TTwisted.py
@@ -0,0 +1,324 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import struct
+from cStringIO import StringIO
+
+from zope.interface import implements, Interface, Attribute
+from twisted.internet.protocol import ServerFactory, ClientFactory, \
+ connectionDone
+from twisted.internet import defer
+from twisted.internet.threads import deferToThread
+from twisted.protocols import basic
+from twisted.web import server, resource, http
+
+from thrift.transport import TTransport
+
+
+class TMessageSenderTransport(TTransport.TTransportBase):
+
+ def __init__(self):
+ self.__wbuf = StringIO()
+
+ def write(self, buf):
+ self.__wbuf.write(buf)
+
+ def flush(self):
+ msg = self.__wbuf.getvalue()
+ self.__wbuf = StringIO()
+ return self.sendMessage(msg)
+
+ def sendMessage(self, message):
+ raise NotImplementedError
+
+
+class TCallbackTransport(TMessageSenderTransport):
+
+ def __init__(self, func):
+ TMessageSenderTransport.__init__(self)
+ self.func = func
+
+ def sendMessage(self, message):
+ return self.func(message)
+
+
+class ThriftClientProtocol(basic.Int32StringReceiver):
+
+ MAX_LENGTH = 2 ** 31 - 1
+
+ def __init__(self, client_class, iprot_factory, oprot_factory=None):
+ self._client_class = client_class
+ self._iprot_factory = iprot_factory
+ if oprot_factory is None:
+ self._oprot_factory = iprot_factory
+ else:
+ self._oprot_factory = oprot_factory
+
+ self.recv_map = {}
+ self.started = defer.Deferred()
+
+ def dispatch(self, msg):
+ self.sendString(msg)
+
+ def connectionMade(self):
+ tmo = TCallbackTransport(self.dispatch)
+ self.client = self._client_class(tmo, self._oprot_factory)
+ self.started.callback(self.client)
+
+ def connectionLost(self, reason=connectionDone):
+ for k, v in self.client._reqs.iteritems():
+ tex = TTransport.TTransportException(
+ type=TTransport.TTransportException.END_OF_FILE,
+ message='Connection closed')
+ v.errback(tex)
+
+ def stringReceived(self, frame):
+ tr = TTransport.TMemoryBuffer(frame)
+ iprot = self._iprot_factory.getProtocol(tr)
+ (fname, mtype, rseqid) = iprot.readMessageBegin()
+
+ try:
+ method = self.recv_map[fname]
+ except KeyError:
+ method = getattr(self.client, 'recv_' + fname)
+ self.recv_map[fname] = method
+
+ method(iprot, mtype, rseqid)
+
+
+class ThriftSASLClientProtocol(ThriftClientProtocol):
+
+ START = 1
+ OK = 2
+ BAD = 3
+ ERROR = 4
+ COMPLETE = 5
+
+ MAX_LENGTH = 2 ** 31 - 1
+
+ def __init__(self, client_class, iprot_factory, oprot_factory=None,
+ host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+ """
+ host: the name of the server, from a SASL perspective
+ service: the name of the server's service, from a SASL perspective
+ mechanism: the name of the preferred mechanism to use
+
+ All other kwargs will be passed to the puresasl.client.SASLClient
+ constructor.
+ """
+
+ from puresasl.client import SASLClient
+ self.SASLCLient = SASLClient
+
+ ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
+
+ self._sasl_negotiation_deferred = None
+ self._sasl_negotiation_status = None
+ self.client = None
+
+ if host is not None:
+ self.createSASLClient(host, service, mechanism, **sasl_kwargs)
+
+ def createSASLClient(self, host, service, mechanism, **kwargs):
+ self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
+
+ def dispatch(self, msg):
+ encoded = self.sasl.wrap(msg)
+ len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
+ ThriftClientProtocol.dispatch(self, len_and_encoded)
+
+ @defer.inlineCallbacks
+ def connectionMade(self):
+ self._sendSASLMessage(self.START, self.sasl.mechanism)
+ initial_message = yield deferToThread(self.sasl.process)
+ self._sendSASLMessage(self.OK, initial_message)
+
+ while True:
+ status, challenge = yield self._receiveSASLMessage()
+ if status == self.OK:
+ response = yield deferToThread(self.sasl.process, challenge)
+ self._sendSASLMessage(self.OK, response)
+ elif status == self.COMPLETE:
+ if not self.sasl.complete:
+ msg = "The server erroneously indicated that SASL " \
+ "negotiation was complete"
+ raise TTransport.TTransportException(msg, message=msg)
+ else:
+ break
+ else:
+ msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
+ raise TTransport.TTransportException(msg, message=msg)
+
+ self._sasl_negotiation_deferred = None
+ ThriftClientProtocol.connectionMade(self)
+
+ def _sendSASLMessage(self, status, body):
+ if body is None:
+ body = ""
+ header = struct.pack(">BI", status, len(body))
+ self.transport.write(header + body)
+
+ def _receiveSASLMessage(self):
+ self._sasl_negotiation_deferred = defer.Deferred()
+ self._sasl_negotiation_status = None
+ return self._sasl_negotiation_deferred
+
+ def connectionLost(self, reason=connectionDone):
+ if self.client:
+ ThriftClientProtocol.connectionLost(self, reason)
+
+ def dataReceived(self, data):
+ if self._sasl_negotiation_deferred:
+ # we got a sasl challenge in the format (status, length, challenge)
+ # save the status, let IntNStringReceiver piece the challenge data together
+ self._sasl_negotiation_status, = struct.unpack("B", data[0])
+ ThriftClientProtocol.dataReceived(self, data[1:])
+ else:
+ # normal frame, let IntNStringReceiver piece it together
+ ThriftClientProtocol.dataReceived(self, data)
+
+ def stringReceived(self, frame):
+ if self._sasl_negotiation_deferred:
+ # the frame is just a SASL challenge
+ response = (self._sasl_negotiation_status, frame)
+ self._sasl_negotiation_deferred.callback(response)
+ else:
+ # there's a second 4 byte length prefix inside the frame
+ decoded_frame = self.sasl.unwrap(frame[4:])
+ ThriftClientProtocol.stringReceived(self, decoded_frame)
+
+
+class ThriftServerProtocol(basic.Int32StringReceiver):
+
+ MAX_LENGTH = 2 ** 31 - 1
+
+ def dispatch(self, msg):
+ self.sendString(msg)
+
+ def processError(self, error):
+ self.transport.loseConnection()
+
+ def processOk(self, _, tmo):
+ msg = tmo.getvalue()
+
+ if len(msg) > 0:
+ self.dispatch(msg)
+
+ def stringReceived(self, frame):
+ tmi = TTransport.TMemoryBuffer(frame)
+ tmo = TTransport.TMemoryBuffer()
+
+ iprot = self.factory.iprot_factory.getProtocol(tmi)
+ oprot = self.factory.oprot_factory.getProtocol(tmo)
+
+ d = self.factory.processor.process(iprot, oprot)
+ d.addCallbacks(self.processOk, self.processError,
+ callbackArgs=(tmo,))
+
+
+class IThriftServerFactory(Interface):
+
+ processor = Attribute("Thrift processor")
+
+ iprot_factory = Attribute("Input protocol factory")
+
+ oprot_factory = Attribute("Output protocol factory")
+
+
+class IThriftClientFactory(Interface):
+
+ client_class = Attribute("Thrift client class")
+
+ iprot_factory = Attribute("Input protocol factory")
+
+ oprot_factory = Attribute("Output protocol factory")
+
+
+class ThriftServerFactory(ServerFactory):
+
+ implements(IThriftServerFactory)
+
+ protocol = ThriftServerProtocol
+
+ def __init__(self, processor, iprot_factory, oprot_factory=None):
+ self.processor = processor
+ self.iprot_factory = iprot_factory
+ if oprot_factory is None:
+ self.oprot_factory = iprot_factory
+ else:
+ self.oprot_factory = oprot_factory
+
+
+class ThriftClientFactory(ClientFactory):
+
+ implements(IThriftClientFactory)
+
+ protocol = ThriftClientProtocol
+
+ def __init__(self, client_class, iprot_factory, oprot_factory=None):
+ self.client_class = client_class
+ self.iprot_factory = iprot_factory
+ if oprot_factory is None:
+ self.oprot_factory = iprot_factory
+ else:
+ self.oprot_factory = oprot_factory
+
+ def buildProtocol(self, addr):
+ p = self.protocol(self.client_class, self.iprot_factory,
+ self.oprot_factory)
+ p.factory = self
+ return p
+
+
+class ThriftResource(resource.Resource):
+
+ allowedMethods = ('POST',)
+
+ def __init__(self, processor, inputProtocolFactory,
+ outputProtocolFactory=None):
+ resource.Resource.__init__(self)
+ self.inputProtocolFactory = inputProtocolFactory
+ if outputProtocolFactory is None:
+ self.outputProtocolFactory = inputProtocolFactory
+ else:
+ self.outputProtocolFactory = outputProtocolFactory
+ self.processor = processor
+
+ def getChild(self, path, request):
+ return self
+
+ def _cbProcess(self, _, request, tmo):
+ msg = tmo.getvalue()
+ request.setResponseCode(http.OK)
+ request.setHeader("content-type", "application/x-thrift")
+ request.write(msg)
+ request.finish()
+
+ def render_POST(self, request):
+ request.content.seek(0, 0)
+ data = request.content.read()
+ tmi = TTransport.TMemoryBuffer(data)
+ tmo = TTransport.TMemoryBuffer()
+
+ iprot = self.inputProtocolFactory.getProtocol(tmi)
+ oprot = self.outputProtocolFactory.getProtocol(tmo)
+
+ d = self.processor.process(iprot, oprot)
+ d.addCallback(self._cbProcess, request, tmo)
+ return server.NOT_DONE_YET