summaryrefslogtreecommitdiffstats
path: root/app/utils/mongo_access.py
diff options
context:
space:
mode:
Diffstat (limited to 'app/utils/mongo_access.py')
-rw-r--r--app/utils/mongo_access.py137
1 files changed, 137 insertions, 0 deletions
diff --git a/app/utils/mongo_access.py b/app/utils/mongo_access.py
new file mode 100644
index 0000000..1425017
--- /dev/null
+++ b/app/utils/mongo_access.py
@@ -0,0 +1,137 @@
+###############################################################################
+# Copyright (c) 2017 Koren Lev (Cisco Systems), Yaron Yogev (Cisco Systems) #
+# and others #
+# #
+# All rights reserved. This program and the accompanying materials #
+# are made available under the terms of the Apache License, Version 2.0 #
+# which accompanies this distribution, and is available at #
+# http://www.apache.org/licenses/LICENSE-2.0 #
+###############################################################################
+import os
+
+from pymongo import MongoClient
+
+from utils.config_file import ConfigFile
+from utils.dict_naming_converter import DictNamingConverter
+from utils.logging.console_logger import ConsoleLogger
+from utils.logging.file_logger import FileLogger
+
+
+# Provides access to MongoDB using PyMongo library
+#
+# Notes on authentication:
+# default config file is calipso_mongo_access.conf
+# you can also specify name of file from CLI with --mongo_config
+
+
+class MongoAccess(DictNamingConverter):
+ client = None
+ db = None
+ default_conf_file = '/local_dir/calipso_mongo_access.conf'
+ config_file = None
+
+ DB_NAME = 'calipso'
+ LOG_FILENAME = 'mongo_access.log'
+ DEFAULT_LOG_FILE = os.path.join(os.path.abspath("."), LOG_FILENAME)
+
+ def __init__(self):
+ super().__init__()
+ self.log_file = os.path.join(FileLogger.LOG_DIRECTORY,
+ MongoAccess.LOG_FILENAME)
+
+ try:
+ self.log = FileLogger(self.log_file)
+ except OSError as e:
+ ConsoleLogger().warning("Couldn't use file {} for logging. "
+ "Using default location: {}.\n"
+ "Error: {}"
+ .format(self.log_file,
+ self.DEFAULT_LOG_FILE,
+ e))
+
+ self.log_file = self.DEFAULT_LOG_FILE
+ self.log = FileLogger(self.log_file)
+
+ self.connect_params = {}
+ self.mongo_connect(self.config_file)
+
+ def is_db_ready(self) -> bool:
+ return MongoAccess.client is not None
+
+ @staticmethod
+ def set_config_file(_conf_file):
+ MongoAccess.config_file = _conf_file
+
+ def mongo_connect(self, config_file_path=""):
+ if MongoAccess.client:
+ return
+
+ self.connect_params = {
+ "server": "localhost",
+ "port": 27017
+ }
+
+ if not config_file_path:
+ config_file_path = self.default_conf_file
+
+ try:
+ config_file = ConfigFile(config_file_path)
+ # read connection parameters from config file
+ config_params = config_file.read_config()
+ self.connect_params.update(config_params)
+ except Exception as e:
+ self.log.exception(e)
+ raise
+
+ self.prepare_connect_uri()
+ MongoAccess.client = MongoClient(
+ self.connect_params["server"],
+ self.connect_params["port"]
+ )
+ MongoAccess.db = getattr(MongoAccess.client,
+ config_params.get('auth_db', self.DB_NAME))
+ self.log.info('Connected to MongoDB')
+
+ def prepare_connect_uri(self):
+ params = self.connect_params
+ self.log.debug('connecting to MongoDb server: {}'
+ .format(params['server']))
+ uri = 'mongodb://'
+ if 'password' in params:
+ uri = uri + params['user'] + ':' + params['password'] + '@'
+ uri = uri + params['server']
+ if 'auth_db' in params:
+ uri = uri + '/' + params['auth_db']
+ self.connect_params['server'] = uri
+
+ @staticmethod
+ def update_document(collection, document, upsert=False):
+ if isinstance(collection, str):
+ collection = MongoAccess.db[collection]
+ doc_id = document.pop('_id')
+ collection.update_one({'_id': doc_id}, {'$set': document},
+ upsert=upsert)
+ document['_id'] = doc_id
+
+ @staticmethod
+ def encode_dots(s):
+ return s.replace(".", "[dot]")
+
+ @staticmethod
+ def decode_dots(s):
+ return s.replace("[dot]", ".")
+
+ # Mongo will not accept dot (".") in keys, or $ in start of keys
+ # $ in beginning of key does not happen in OpenStack,
+ # so need to translate only "." --> "[dot]"
+ @staticmethod
+ def encode_mongo_keys(item):
+ return MongoAccess.change_dict_naming_convention(item, MongoAccess.encode_dots)
+
+ @staticmethod
+ def decode_mongo_keys(item):
+ return MongoAccess.change_dict_naming_convention(item, MongoAccess.decode_dots)
+
+ @staticmethod
+ def decode_object_id(item: dict):
+ return dict(item, **{"_id": str(item["_id"])}) if item and "_id" in item else item