summaryrefslogtreecommitdiffstats
path: root/vstf/vstf/common/message.py
blob: 926091fbe96bf3549db60eb8d7d6db6f542035ba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import json
import uuid
import logging
import traceback
from vstf.common import constants

LOG = logging.getLogger(__name__)


def json_defaults(obj):
    """Fallback serializer handed to ``json.dumps`` via ``default=``.

    :param obj: the object the JSON encoder could not serialize itself.
    :return: a list with the members of ``obj`` when it is a ``set``,
             otherwise the placeholder string ``"unknow obj"``.
    """
    # Sets have no JSON form, so they travel as lists; any other
    # unsupported object is replaced by a fixed marker string.
    return list(obj) if isinstance(obj, set) else "unknow obj"


def encode(msg):
    """obj to string

    :param msg: a ``str`` (passed through untouched) or any other object,
                which is serialized with ``json.dumps``.
    :return: the string form of ``msg``.
    """
    if not isinstance(msg, str):
        # json_defaults takes over for objects json cannot encode natively.
        return json.dumps(msg, default=json_defaults)
    return msg


def decode(msg):
    """string to obj

    :param msg: a JSON document as ``str``, or an already decoded object.
    :return: the parsed JSON value when ``msg`` is a ``str``; otherwise
             ``msg`` itself, unchanged.
    """
    if not isinstance(msg, str):
        return msg
    return json.loads(msg)


def gen_corrid():
    """Create a fresh correlation id.

    :return: the string form of a newly generated random (version 4) UUID.
    """
    corrid = uuid.uuid4()
    return str(corrid)


def add_context(msg, **kwargs):
    """Wrap a message body together with its context.

    :param msg: the payload, stored under the ``'body'`` key.
    :param kwargs: context entries, stored as a dict under ``'head'``.
    :return: a dict of the form ``{'head': kwargs, 'body': msg}``.
    """
    envelope = dict(head=kwargs, body=msg)
    return envelope


def get_context(msg):
    """Return the context stored under the ``'head'`` key of a message.

    :param msg: a mapping, such as the dict produced by ``add_context``.
    :return: ``msg['head']`` when the key is present, otherwise the empty
             string ``""``.
    """
    # Test membership on the mapping itself: ``iterkeys()`` exists only on
    # Python 2 dicts, while ``in`` works on every Python version.
    if "head" in msg:
        return msg['head']
    return ""


def get_body(msg):
    """Return the payload stored under the ``'body'`` key of a message.

    :param msg: a mapping, such as the dict produced by ``add_context``.
    :return: ``msg['body']`` when the key is present, otherwise ``None``.
    """
    # Test membership on the mapping itself: ``iterkeys()`` exists only on
    # Python 2 dicts, while ``in`` works on every Python version.
    if "body" in msg:
        return msg['body']
    return None


def get_corrid(context):
    """Return the correlation id stored in a message context.

    :param context: a mapping that may hold a ``'corrid'`` key.
    :return: ``context['corrid']`` when the key is present, otherwise the
             empty string ``""``.
    """
    # Test membership on the mapping itself: ``iterkeys()`` exists only on
    # Python 2 dicts, while ``in`` works on every Python version.
    if "corrid" in context:
        return context['corrid']
    return ""


def send(func, data):
    """Prefix ``data`` with a length header and push it out through ``func``.

    :param func: callable that takes the not-yet-sent part of the framed
                 message and returns how much of it was accepted
                 (presumably a socket's ``send`` method -- confirm with
                 callers).
    :param data: the payload; must be a ``str``.
    :return: ``True`` when ``data`` is empty (nothing is sent), otherwise
             the length of ``data``.
    :raises ValueError: if ``data`` is not a ``str``, or if its length has
                        more digits than ``constants.MSG_FLAG_LEN``.
    :raises RuntimeError: if ``func`` reports that nothing was sent.
    """
    if not isinstance(data, str):
        raise ValueError("the data must be a string")

    # An empty payload is not transmitted at all.
    payload_len = len(data)
    if payload_len <= 0:
        return True

    # The decimal length has to fit into the fixed-width header.
    header_width = constants.MSG_FLAG_LEN
    if len(str(payload_len)) > header_width:
        raise ValueError("the data's len too long")

    frame = (constants.MSG_FLAG % (payload_len)) + data
    frame_len = payload_len + header_width

    # func may accept only part of what it is given, so keep handing over
    # the unsent tail until the whole frame is out.
    offset = 0
    while offset < frame_len:
        accepted = func(frame[offset:])
        if accepted == 0:
            raise RuntimeError("socket connection broken")
        offset += accepted

    return payload_len


def sendto(func, data, addr):
    """Prefix ``data`` with a length header and push it to ``addr`` via ``func``.

    :param func: callable invoked as ``func(pending, addr)`` that returns
                 how much of ``pending`` was accepted (presumably a
                 socket's ``sendto`` method -- confirm with callers).
    :param data: the payload; must be a ``str``.
    :param addr: destination, passed to ``func`` unchanged on every call.
    :return: ``True`` when ``data`` is empty (nothing is sent), otherwise
             the length of ``data``.
    :raises ValueError: if ``data`` is not a ``str``, or if its length has
                        more digits than ``constants.MSG_FLAG_LEN``.
    :raises RuntimeError: if ``func`` reports that nothing was sent.
    """
    if not isinstance(data, str):
        raise ValueError("the data must be a string")

    # An empty payload is not transmitted at all.
    payload_len = len(data)
    if payload_len <= 0:
        return True

    # The decimal length has to fit into the fixed-width header.
    header_width = constants.MSG_FLAG_LEN
    if len(str(payload_len)) > header_width:
        raise ValueError("the data's len too long")

    frame = (constants.MSG_FLAG % (payload_len)) + data
    frame_len = payload_len + header_width

    # func may accept only part of what it is given, so keep handing over
    # the unsent tail until the whole frame is out.
    offset = 0
    while offset < frame_len:
        accepted = func(frame[offset:], addr)
        if accepted == 0:
            raise RuntimeError("socket connection broken")
        offset += accepted

    return payload_len


def recv(func):
    """Read one length-prefixed message through ``func``.

    The wire format is the one written by ``send``: a header of
    ``constants.MSG_FLAG_LEN`` decimal digits holding the payload length,
    followed by the payload itself.

    :param func: callable invoked as ``func(max_size)`` that returns the
                 next piece of the stream, or an empty string once the
                 peer has closed (presumably a socket's ``recv`` method --
                 confirm with callers).
    :return: the payload as one string (``''`` for a zero-length message).
    :raises RuntimeError: if the stream ends before the message is complete.
    :raises ValueError: if the header is not made up of digits only.
    """
    header_len = constants.MSG_FLAG_LEN

    # A stream read may hand back fewer bytes than requested, so collect
    # the header piece by piece; parsing a truncated header would yield a
    # wrong length and desynchronise the stream.
    head_parts = []
    got = 0
    while got < header_len:
        piece = func(header_len - got)
        # the FIN change to '' in python
        if not piece:
            raise RuntimeError("socket connection broken")
        head_parts.append(piece)
        got += len(piece)
    head = ''.join(head_parts)

    if not head.isdigit():
        raise ValueError("the msg head is not a num.")

    msg_len = int(head)
    chunks = []
    count = 0
    while count < msg_len:
        # Never request more than one buffer's worth, nor more than is
        # still missing from this message.
        chunk = func(min(msg_len - count, constants.buff_size))
        if not chunk:
            raise RuntimeError("socket connection broken")
        chunks.append(chunk)
        count += len(chunk)

    return ''.join(chunks)


def dumpstrace():
    """Return the traceback of the exception currently being handled.

    :return: the text produced by ``traceback.format_exc()``.
    """
    trace_text = traceback.format_exc()
    return trace_text