summaryrefslogtreecommitdiffstats
path: root/app/test/fetch/link_finders/test_find_implicit_links.py
blob: 993168818d2b91bd120a7476ee4a5defdc0e67d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
###############################################################################
# Copyright (c) 2017 Koren Lev (Cisco Systems), Yaron Yogev (Cisco Systems)   #
# and others                                                                  #
#                                                                             #
# All rights reserved. This program and the accompanying materials            #
# are made available under the terms of the Apache License, Version 2.0       #
# which accompanies this distribution, and is available at                    #
# http://www.apache.org/licenses/LICENSE-2.0                                  #
###############################################################################
import bson

from discover.link_finders.find_implicit_links import FindImplicitLinks
from test.fetch.test_fetch import TestFetch
from unittest.mock import MagicMock
from test.fetch.link_finders.test_data.test_find_implicit_links import *

from utils.inventory_mgr import InventoryMgr


class TestFindImplicitLinks(TestFetch):
    """Unit tests for the FindImplicitLinks link finder.

    The fixture constants used below (ENV, LINK_*, CLIQUE_CONSTRAINTS,
    CONSTRAINTS, IMPLICIT_LINKS) are expected to come from the wildcard
    import of the test data module at the top of this file.
    """

    def setUp(self):
        """Build the fetcher under test and stub write_link / ObjectId."""
        super().setUp()
        self.configure_environment()
        self.fetcher = FindImplicitLinks()
        self.fetcher.set_env(ENV)
        self.fetcher.constraint_attributes = ['network']
        # Swap write_link and bson.ObjectId for identity functions.
        # The originals are saved so tearDown() can restore them.
        self.original_write_link = self.inv.write_link
        self.inv.write_link = lambda x: x
        self.original_objectid = bson.ObjectId
        bson.ObjectId = lambda x: x

    def tearDown(self):
        """Restore the attributes that setUp() replaced."""
        try:
            super().tearDown()
        finally:
            # Restore even when the parent tearDown raises: bson.ObjectId
            # is a module-level attribute, so a leftover stub would leak
            # into every test that runs afterwards.
            bson.ObjectId = self.original_objectid
            self.inv.write_link = self.original_write_link

    def test_get_constraint_attributes(self):
        """get_constraint_attributes() yields the values in CONSTRAINTS."""
        original_find = InventoryMgr.find
        InventoryMgr.find = MagicMock(return_value=CLIQUE_CONSTRAINTS)
        try:
            constraint_types = self.fetcher.get_constraint_attributes()
            self.assertEqual(sorted(constraint_types), sorted(CONSTRAINTS))
        finally:
            # The mock is installed on the class itself, so it must be
            # removed even if the call or the assertion above fails.
            InventoryMgr.find = original_find

    def test_constraints_match(self):
        """Check constraints_match() on pairs of link attribute fixtures."""
        matcher = self.fetcher.constraints_match
        self.assertTrue(matcher(LINK_ATTRIBUTES_NONE, LINK_ATTRIBUTES_NONE_2))
        self.assertTrue(matcher(LINK_ATTRIBUTES_NONE, LINK_ATTRIBUTES_EMPTY))
        self.assertTrue(matcher(LINK_ATTRIBUTES_NONE, LINK_ATTR_V1))
        self.assertTrue(matcher(LINK_ATTRIBUTES_EMPTY, LINK_ATTR_V1))
        self.assertTrue(matcher(LINK_ATTR_V1, LINK_ATTR_V1_2))
        self.assertTrue(matcher(LINK_ATTR_V1,
                                LINK_ATTR_V1_AND_A2V2))
        self.assertFalse(matcher(LINK_ATTR_V1, LINK_ATTR_V2))

    def test_links_match(self):
        """Check links_match() on pairs of link fixtures."""
        matcher = self.fetcher.links_match
        self.assertFalse(matcher(LINK_TYPE_1, LINK_TYPE_1_2))
        self.assertFalse(matcher(LINK_TYPE_1, LINK_TYPE_1_REVERSED))
        self.assertFalse(matcher(LINK_TYPE_4_NET1, LINK_TYPE_5_NET2))
        self.assertFalse(matcher(LINK_TYPE_1_2, LINK_TYPE_2))
        self.assertTrue(matcher(LINK_TYPE_1, LINK_TYPE_2))

    def test_get_link_constraint_attributes(self):
        """Check get_link_constraint_attributes() on pairs of links."""
        getter = self.fetcher.get_link_constraint_attributes
        self.assertEqual(getter(LINK_TYPE_1, LINK_TYPE_1_2), {})
        self.assertEqual(getter(LINK_TYPE_1, LINK_TYPE_4_NET1),
                         LINK_TYPE_4_NET1.get('attributes'))
        self.assertEqual(getter(LINK_TYPE_4_NET1, LINK_TYPE_1),
                         LINK_TYPE_4_NET1.get('attributes'))
        self.assertEqual(getter(LINK_TYPE_1, LINK_TYPE_5_NET2),
                         LINK_TYPE_5_NET2.get('attributes'))
        self.assertEqual(getter(LINK_TYPE_4_NET1, LINK_TYPE_6_NET1),
                         LINK_TYPE_4_NET1.get('attributes'))

    def test_get_attr(self):
        """Check get_attr() for missing, conflicting and agreeing values."""
        getter = self.fetcher.get_attr
        self.assertIsNone(getter('host', {}, {}))
        self.assertIsNone(getter('host', {'host': 'v1'}, {'host': 'v2'}))
        self.assertEqual(getter('host', {'host': 'v1'}, {}), 'v1')
        self.assertEqual(getter('host', {}, {'host': 'v2'}), 'v2')
        self.assertEqual(getter('host', {'host': 'v1'}, {'host': 'v1'}), 'v1')

    def test_add_implicit_link(self):
        """add_implicit_link() on two NET1 links returns LINK_TYPE_7_NET1."""
        # write_link and bson.ObjectId are already stubbed by setUp() and
        # restored by tearDown(), so no extra patching is done here.
        add_func = self.fetcher.add_implicit_link
        self.assertEqual(add_func(LINK_TYPE_4_NET1, LINK_TYPE_6_NET1),
                         LINK_TYPE_7_NET1)

    def test_get_transitive_closure(self):
        """Compare links added per pass against the IMPLICIT_LINKS fixture."""
        self.fetcher.links = [
            {'pass': 0, 'link': LINK_FULL_A2B},
            {'pass': 0, 'link': LINK_FULL_B2C},
            {'pass': 0, 'link': LINK_FULL_C2D},
            {'pass': 0, 'link': LINK_FULL_D2E},
        ]
        self.fetcher.get_transitive_closure()
        # NOTE(review): range(1, len(...)) combined with index pass_no-1
        # never checks the last entry of IMPLICIT_LINKS - confirm against
        # the fixture whether that is intended.
        for pass_no in range(1, len(IMPLICIT_LINKS)):
            implicit_links = [entry for entry in self.fetcher.links
                              if entry['pass'] == pass_no]
            self.assertEqual(implicit_links, IMPLICIT_LINKS[pass_no-1],
                             'incorrect links for pass #{}'.format(pass_no))