summaryrefslogtreecommitdiffstats
path: root/laas-fog/source/utilities.py
blob: bbe0946700c82fdcdef2989be6e9849640f678d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
#############################################################################
#Copyright 2017 Parker Berberian and others                                 #
#                                                                           #
#Licensed under the Apache License, Version 2.0 (the "License");            #
#you may not use this file except in compliance with the License.           #
#You may obtain a copy of the License at                                    #
#                                                                           #
#    http://www.apache.org/licenses/LICENSE-2.0                             #
#                                                                           #
#Unless required by applicable law or agreed to in writing, software        #
#distributed under the License is distributed on an "AS IS" BASIS,          #
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   #
#See the License for the specific language governing permissions and        #
#limitations under the License.                                             #
#############################################################################
"""

import os
import logging
import string
import sys
import subprocess
import xml.dom
import xml.dom.minidom
import re
import random
import yaml
from database import HostDataBase, BookingDataBase
from api.vpn import VPN
# Directory that Utilities.createLogger() writes per-host "<name>.log" files
# into ("" means the current working directory). createLogger() overwrites
# this global whenever it is called with a different log_dir.
LOGGING_DIR = ""


class Utilities:
    """
    This class defines some useful functions that may be needed
    throughout the provisioning and deployment stage.
    The utility object is carried through most of the deployment process.
    """

    # Matches a dotted-quad IPv4 address; used to pull the leased address
    # out of a DHCPACK log line.
    _IP_REGEX = re.compile(r'(\d+\.\d+\.\d+\.\d+)')

    def __init__(self, host_ip, hostname, conf):
        """
        init function
        host_ip is the ip of the target host
        hostname is the FOG hostname of the host
        conf is the parsed config file
        """
        self.host = host_ip
        self.hostname = hostname
        root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        self.scripts = os.path.join(root_dir, "hostScripts/")
        self.remoteDir = "/root/hostScripts/"
        self.conf = conf
        self.logger = logging.getLogger(hostname)

    def execRemoteScript(self, script, args=None):
        """
        Executes the given script on the remote host with the given args.
        script must be found in laas/hostScripts and already be copied to
        the remote host (see copyScripts).
        args is an optional sequence of extra command-line arguments.
        Returns the exit status of the underlying ssh command.
        """
        cmd = [self.remoteDir + script]
        if args:
            cmd.extend(args)
        return self.sshExec(cmd)

    def waitForBoot(self, max_tries=30):
        """
        Repeatedly pings the host, waiting for it to boot.
        Each attempt is one pingHost() call (which itself retries).
        If the host has not answered after max_tries attempts, an error is
        logged and the process exits with status 1.
        """
        # Test the attempt budget before pinging, so a reply is never
        # observed and then reported as a failure.
        for _ in range(max_tries):
            if self.pingHost():
                return
        self.logger.error("Host %s has not booted", self.host)
        sys.exit(1)

    def checkHost(self):
        """
        returns true if the host responds to two pings.
        Sometimes, while a host is pxe booting, a host will
        respond to one ping but quickly go back offline.
        """
        return bool(self.pingHost() and self.pingHost())

    def pingHost(self, attempts=10):
        """
        returns true if the host responds to a ping.
        Sends up to `attempts` single-packet pings and stops at the first
        reply.
        """
        cmd = ['ping', '-c', '1', self.host]
        with open(os.devnull, 'w') as nul:
            for _ in range(attempts):
                if subprocess.call(cmd, stdout=nul, stderr=nul) == 0:
                    return True
        return False

    def copyDir(self, localDir, remoteDir):
        """
        uses scp to recursively copy localDir to the remote host.
        remoteDir is first created on the remote host (mkdir -p).
        NOTE: the scp destination is always /root, so the copied directory
        lands under /root; this lines up with remoteDir only when remoteDir
        is /root/<name of localDir> (as it is for copyScripts).
        Returns the exit status of scp.
        """
        self.sshExec(['mkdir', '-p', remoteDir])
        cmd = ['scp', '-o', 'StrictHostKeyChecking=no', '-r',
               localDir, 'root@' + self.host + ':/root']
        with open(os.devnull, 'w') as nul:
            return subprocess.call(cmd, stdout=nul, stderr=nul)

    def copyScripts(self):
        """
        Copies the hostScripts dir to the remote host.
        """
        self.copyDir(self.scripts, self.remoteDir)

    def sshExec(self, args):
        """
        executes args as an ssh command on the remote host, as root.
        args is a sequence of command words; ssh joins them with spaces and
        the remote shell interprets the result. Output is discarded.
        Returns the exit status of ssh.
        """
        cmd = ['ssh', 'root@' + self.host] + list(args)
        with open(os.devnull, 'w') as nul:
            return subprocess.call(cmd, stdout=nul, stderr=nul)

    def resetKnownHosts(self, known_hosts='/root/.ssh/known_hosts'):
        """
        edits your known hosts file to remove the previous entries of host.
        Sometimes, the flashing process gives the remote host a new
        signature, and ssh complains about it.
        Only entries whose host field names this host exactly (or as
        [host]:port) are removed; comments, blank lines and other hosts are
        kept verbatim. Hashed host entries cannot be matched and are kept.
        Does nothing if the file does not exist.
        """
        if not os.path.exists(known_hosts):
            return
        with open(known_hosts, 'r') as sshFile:
            lines = sshFile.readlines()  # keeps line endings intact
        with open(known_hosts, 'w') as sshFile:
            sshFile.writelines(
                line for line in lines
                if not self._knownHostsEntryMatches(line, self.host))

    @staticmethod
    def _knownHostsEntryMatches(line, host):
        """
        Returns true if a known_hosts line lists host in its host field.
        """
        fields = line.split()
        if fields and fields[0].startswith('@'):
            # skip a leading marker such as @cert-authority or @revoked
            fields = fields[1:]
        if not fields or fields[0].startswith('#'):
            return False
        for name in fields[0].split(','):
            if name.startswith('[') and ']' in name:
                # non-default port entries look like [host]:port
                name = name[1:name.index(']')]
            if name == host:
                return True
        return False

    def restartHost(self):
        """
        restarts the remote host.
        Returns the exit status of the ssh command.
        """
        return self.sshExec(['shutdown', '-r', 'now'])

    @staticmethod
    def randoString(length):
        """
        this is an adapted version of the code found here:
        https://stackoverflow.com/questions/2257441/
        random-string-generation-with-upper-case-letters-and-digits-in-python
        generates a random alphanumeric string (uppercase letters and
        digits) of length length, using the OS-backed SystemRandom.
        """
        rng = random.SystemRandom()
        chars = string.ascii_uppercase + string.digits
        return ''.join(rng.choice(chars) for _ in range(length))

    def changePassword(self):
        """
        Sets the root password to a random string and returns it.
        An error is logged if the remote command reports failure.
        """
        paswd = self.randoString(15)
        # The password is purely alphanumeric (see randoString), so it is
        # safe to hand unquoted to the remote shell.
        ret = self.sshExec(['printf', paswd, '|', 'passwd', '--stdin', 'root'])
        if ret != 0:
            self.logger.error("Failed to set root password on %s", self.host)
        return paswd

    def markHostDeployed(self):
        """
        Tells the database that this host has finished its deployment
        """
        db = HostDataBase(self.conf['database'])
        try:
            db.makeHostDeployed(self.hostname)
        finally:
            db.close()

    def make_vpn_user(self):
        """
        Creates a vpn user and associates it with this booking
        """
        with open(self.conf['vpn_config']) as vpn_conf:
            config = yaml.safe_load(vpn_conf)
        myVpn = VPN(config)
        # a user name (e.g. from the dashboard) may be passed if wanted
        u, p, uid = myVpn.makeNewUser()
        self.logger.info("%s", "created new vpn user")
        self.logger.info("username: %s", u)
        self.logger.info("password: %s", p)
        self.logger.info("vpn user uid: %s", uid)
        self.add_vpn_user(uid)

    def add_vpn_user(self, uid):
        """
        Adds the dn of the vpn user to the database
        so that we can clean it once the booking ends.
        The host is identified by its pharos resource id, looked up in the
        inventory file; -1 is stored if this hostname is not listed.
        """
        # converts from hostname to pharos resource id
        with open(self.conf['inventory']) as inv_file:
            inventory = yaml.safe_load(inv_file)
        host_id = next(
            (rid for rid, name in inventory.items() if name == self.hostname),
            -1)
        db = BookingDataBase(self.conf['database'])
        # NOTE(review): this connection is never closed; confirm whether
        # BookingDataBase provides close() like HostDataBase does.
        db.setVPN(host_id, uid)

    def finishDeployment(self):
        """
        Last method call once a host is finished being deployed.
        It notifies the database, creates a vpn user for the booking and
        changes the root password to a random string
        """
        self.markHostDeployed()
        self.make_vpn_user()
        passwd = self.changePassword()
        self.logger.info("host %s provisioning done", self.hostname)
        self.logger.info("You may access the host at %s", self.host)
        self.logger.info("The password is %s", passwd)
        notice = "You should change all passwords for security"
        self.logger.warning('%s', notice)

    @staticmethod
    def restartRemoteHost(host_ip):
        """
        Reboots host_ip over ssh and returns the ssh exit status.
        This method assumes that you already have ssh access to the target
        """
        cmd = ['ssh', '-o', 'StrictHostKeyChecking=no',
               'root@' + host_ip,
               'shutdown', '-r', 'now']
        with open(os.devnull, 'w') as nul:
            return subprocess.call(cmd, stdout=nul, stderr=nul)

    @staticmethod
    def getName(xmlString):
        """
        Gets the text of the first <name> element in xml. for example:
        <name>Parker</name> returns Parker
        """
        xmlDoc = xml.dom.minidom.parseString(xmlString)
        # Search from the Document rather than documentElement: the latter
        # only looks at descendants and would miss a root <name> element.
        nameNodes = xmlDoc.getElementsByTagName('name')
        return str(nameNodes[0].firstChild.nodeValue)

    @staticmethod
    def getXMLFiles(directory):
        """
        searches directory non-recursively and returns the paths of all
        regular files with a .xml extension, in os.listdir order
        """
        xmlFiles = []
        for item in os.listdir(directory):
            path = os.path.join(directory, item)
            if item.endswith('.xml') and os.path.isfile(path):
                xmlFiles.append(path)
        return xmlFiles

    @staticmethod
    def createLogger(name, log_dir=None):
        """
        Initializes the logger if it does not yet exist, and returns it.
        Because of how python logging works, calling logging.getLogger()
        with the same name always returns a reference to the same log file.
        So we can call this method from anywhere with the hostname as
        the name argument and it will return the log file for that host.
        The formatting includes the level of importance and the time stamp.
        log_dir defaults to the module-level LOGGING_DIR, i.e. the last
        directory passed to this method ("" means the current directory).
        """
        global LOGGING_DIR
        # Resolve the default at call time; a LOGGING_DIR default in the
        # signature would be frozen when the class is defined.
        if log_dir is None:
            log_dir = LOGGING_DIR
        else:
            LOGGING_DIR = log_dir
        log = logging.getLogger(name)
        if log.handlers:  # this logger is already initialized
            return log
        log.setLevel(logging.DEBUG)
        han = logging.FileHandler(os.path.join(log_dir, name + ".log"))
        han.setLevel(logging.DEBUG)
        log_format = '[%(levelname)s] %(asctime)s [#] %(message)s'
        han.setFormatter(logging.Formatter(fmt=log_format))
        log.addHandler(han)
        return log

    @staticmethod
    def getIPfromMAC(macAddr, logFile, remote=None):
        """
        searches through the dhcp logs for the given mac
        and returns the associated ip. Will retrieve the
        logFile from a remote host if remote is given.
        if given, remote should be an ip address or hostname that
        we can ssh to. The temporary local copy is always removed.
        """
        if remote is None:
            return Utilities.getIPfromLog(macAddr, logFile)
        localCopy = Utilities.retrieveFile(remote, logFile)
        try:
            return Utilities.getIPfromLog(macAddr, localCopy)
        finally:
            # guard: if scp failed there is nothing to remove, and raising
            # here would mask the original error
            if os.path.exists(localCopy):
                os.remove(localCopy)

    @staticmethod
    def retrieveFile(host, remote_loc, local_loc=None):
        """
        Retrieves file from host and puts it in the current directory
        unless local_loc is given. Returns the expected local path.
        """
        if local_loc is None:
            # resolved per call; an os.getcwd() default in the signature
            # would be frozen at import time
            local_loc = os.getcwd()
        subprocess.call(['scp', 'root@' + host + ':' + remote_loc, local_loc])
        return os.path.join(local_loc, os.path.basename(remote_loc))

    @staticmethod
    def getIPfromLog(macAddr, logFile):
        """
        Helper method for getIPfromMAC.
        Finds the last line of logFile that mentions macAddr and DHCPACK
        and returns the first IPv4 address on it, or None if there is no
        such line or it contains no address.
        Logs an error and exits with status 1 if logFile cannot be read.
        """
        try:
            with open(logFile, "r") as messagesFile:
                allLines = messagesFile.readlines()
        except (IOError, OSError, ValueError):
            logging.getLogger(__name__).error(
                "Cannot read log file %s", logFile)
            sys.exit(1)
        lastAck = None
        for line in allLines:
            if macAddr in line and "DHCPACK" in line:
                lastAck = line
        if lastAck is None:
            return None
        match = Utilities._IP_REGEX.search(lastAck)
        return match.group(1) if match else None