From 2e35a7e46f0929438c1c206e3116caa829f07dc6 Mon Sep 17 00:00:00 2001 From: Thomas Duval Date: Fri, 5 Oct 2018 16:54:37 +0200 Subject: Update code to 4.6 official version Change-Id: Ibd0da0e476e24b2685f54693efc11f7a58d40a62 --- python_moondb/python_moondb/backends/sql.py | 496 +++++++++++++++++++--------- 1 file changed, 343 insertions(+), 153 deletions(-) (limited to 'python_moondb/python_moondb/backends/sql.py') diff --git a/python_moondb/python_moondb/backends/sql.py b/python_moondb/python_moondb/backends/sql.py index 7310e7f3..d25586ba 100644 --- a/python_moondb/python_moondb/backends/sql.py +++ b/python_moondb/python_moondb/backends/sql.py @@ -20,7 +20,8 @@ import sqlalchemy logger = logging.getLogger("moon.db.driver.sql") Base = declarative_base() -DEBUG = True if configuration.get_configuration("logging")['logging']['loggers']['moon']['level'] == "DEBUG" else False +DEBUG = True if configuration.get_configuration("logging")['logging']['loggers']['moon'][ + 'level'] == "DEBUG" else False class DictBase: @@ -50,7 +51,6 @@ class DictBase: class JsonBlob(sql_types.TypeDecorator): - impl = sql.Text def process_bind_param(self, value, dialect): @@ -174,6 +174,7 @@ class PerimeterDataBase(DictBase): id = sql.Column(sql.String(64), primary_key=True) name = sql.Column(sql.String(256), nullable=False) value = sql.Column(JsonBlob(), nullable=True) + @declared_attr def policy_id(cls): return sql.Column(sql.ForeignKey("policies.id"), nullable=False) @@ -254,10 +255,11 @@ class ActionAssignment(Base, PerimeterAssignmentBase): class MetaRule(Base, DictBase): __tablename__ = 'meta_rules' - attributes = ['id', 'name', 'subject_categories', 'object_categories', 'action_categories', 'value'] + attributes = ['id', 'name', 'subject_categories', 'object_categories', 'action_categories', + 'value'] id = sql.Column(sql.String(64), primary_key=True) name = sql.Column(sql.String(256), nullable=False) - subject_categories = sql.Column(JsonBlob(), nullable=True) + subject_categories = sql.Column(JsonBlob(), nullable=True) object_categories = sql.Column(JsonBlob(), nullable=True) action_categories = sql.Column(JsonBlob(), nullable=True) value = sql.Column(JsonBlob(), nullable=True) @@ -353,8 +355,10 @@ class PDPConnector(BaseConnector, PDPDriver): d.update(value_wo_name) setattr(ref, "value", d) return {ref.id: ref.to_dict()} - except sqlalchemy.exc.IntegrityError: - raise exceptions.PdpExisting + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.PdpExisting + raise error def delete_pdp(self, pdp_id): with self.get_session_for_write() as session: @@ -375,8 +379,10 @@ class PDPConnector(BaseConnector, PDPDriver): }) session.add(new) return {new.id: new.to_dict()} - except sqlalchemy.exc.IntegrityError: - raise exceptions.PdpExisting + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.PdpExisting + raise error def get_pdp(self, pdp_id=None): with self.get_session_for_read() as session: @@ -390,20 +396,27 @@ class PDPConnector(BaseConnector, PDPDriver): class PolicyConnector(BaseConnector, PolicyDriver): def update_policy(self, policy_id, value): - with self.get_session_for_write() as session: - query = session.query(Policy) - query = query.filter_by(id=policy_id) - ref = query.first() - if ref: - value_wo_other_info = copy.deepcopy(value) - value_wo_other_info.pop("name", None) - value_wo_other_info.pop("model_id", None) - ref.name = value["name"] - ref.model_id= value["model_id"] - d = dict(ref.value) - d.update(value_wo_other_info) - setattr(ref, "value", d) - return {ref.id: ref.to_dict()} + try: + with self.get_session_for_write() as session: + query = session.query(Policy) + query = query.filter_by(id=policy_id) + ref = query.first() + + if ref: + value_wo_other_info = copy.deepcopy(value) + value_wo_other_info.pop("name", None) + value_wo_other_info.pop("model_id", None) + ref.name = value["name"] + ref.model_id = value["model_id"] + d = dict(ref.value) + d.update(value_wo_other_info) + setattr(ref, "value", d) + return {ref.id: ref.to_dict()} + + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.PolicyExisting + raise error def delete_policy(self, policy_id): with self.get_session_for_write() as session: @@ -411,40 +424,49 @@ class PolicyConnector(BaseConnector, PolicyDriver): session.delete(ref) def add_policy(self, policy_id=None, value=None): - with self.get_session_for_write() as session: - value_wo_other_info = copy.deepcopy(value) - value_wo_other_info.pop("name", None) - value_wo_other_info.pop("model_id", None) - new = Policy.from_dict({ - "id": policy_id if policy_id else uuid4().hex, - "name": value["name"], - "model_id": value.get("model_id", ""), - "value": value_wo_other_info - }) - session.add(new) - return {new.id: new.to_dict()} + try: + with self.get_session_for_write() as session: + value_wo_other_info = copy.deepcopy(value) + value_wo_other_info.pop("name", None) + value_wo_other_info.pop("model_id", None) + new = Policy.from_dict({ + "id": policy_id if policy_id else uuid4().hex, + "name": value["name"], + "model_id": value.get("model_id", ""), + "value": value_wo_other_info + }) + session.add(new) + return {new.id: new.to_dict()} - def get_policies(self, policy_id=None): + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.PolicyExisting + raise error + + def get_policies(self, policy_id=None, policy_name=None): with self.get_session_for_read() as session: query = session.query(Policy) if policy_id: query = query.filter_by(id=policy_id) + elif policy_name: + query = query.filter_by(name=policy_name) + ref_list = query.all() return {_ref.id: _ref.to_dict() for _ref in ref_list} def __get_perimeters(self, ClassType, policy_id, perimeter_id=None): + # if not policy_id: + # raise exceptions.PolicyUnknown + with self.get_session_for_read() as session: query = session.query(ClassType) - ref_list = copy.deepcopy(query.all()) + if perimeter_id: - for _ref in ref_list: - _ref_value = _ref.to_return() - if perimeter_id == _ref.id: - if policy_id and policy_id in _ref_value["policy_list"]: - return {_ref.id: _ref_value} - else: - return {} - elif policy_id: + query = query.filter_by(id=perimeter_id) + + ref_list = copy.deepcopy(query.all()) + + if policy_id: results = [] for _ref in ref_list: _ref_value = _ref.to_return() @@ -453,9 +475,48 @@ class PolicyConnector(BaseConnector, PolicyDriver): return {_ref.id: _ref.to_return() for _ref in results} return {_ref.id: _ref.to_return() for _ref in ref_list} - def __set_perimeter(self, ClassType, ClassTypeException, policy_id, perimeter_id=None, value=None): + def __get_perimeter_by_name(self, ClassType, perimeter_name): + # if not policy_id: + # raise exceptions.PolicyUnknown + with self.get_session_for_read() as session: + query = session.query(ClassType) + if not perimeter_name or not perimeter_name.strip(): + raise exceptions.PerimeterContentError('invalid name') + query = query.filter_by(name=perimeter_name) + ref_list = copy.deepcopy(query.all()) + return {_ref.id: _ref.to_return() for _ref in ref_list} + + def __update_perimeter(self, class_type, class_type_exception, perimeter_id, value): + if not perimeter_id: + return exceptions.PerimeterContentError + with self.get_session_for_write() as session: + query = session.query(class_type) + query = query.filter_by(id=perimeter_id) + _perimeter = query.first() + if not _perimeter: + raise class_type_exception + temp_perimeter = copy.deepcopy(_perimeter.to_dict()) + if 'name' in value: + temp_perimeter['value']['name'] = value['name'] + if 'description' in value: + temp_perimeter['value']['description'] = value['description'] + if 'extra' in value: + temp_perimeter['value']['extra'] = value['extra'] + name = temp_perimeter['value']['name'] + temp_perimeter['value'].pop("name", None) + new_perimeter = class_type.from_dict({ + "id": temp_perimeter["id"], + "name": name, + "value": temp_perimeter["value"] + }) + _perimeter.value = new_perimeter.value + _perimeter.name = new_perimeter.name + return {_perimeter.id: _perimeter.to_return()} + + def __set_perimeter(self, ClassType, ClassTypeException, policy_id, perimeter_id=None, + value=None): if not value or "name" not in value or not value["name"].strip(): - raise exceptions.PerimeterNameInvalid + raise exceptions.PerimeterContentError('invalid name') with self.get_session_for_write() as session: _perimeter = None if perimeter_id: @@ -485,10 +546,16 @@ class PolicyConnector(BaseConnector, PolicyDriver): return {new.id: new.to_return()} else: _value = copy.deepcopy(_perimeter.to_dict()) - if "policy_list" not in _value["value"] or type(_value["value"]["policy_list"]) is not list: + if "policy_list" not in _value["value"] or type( + _value["value"]["policy_list"]) is not list: _value["value"]["policy_list"] = [] if policy_id and policy_id not in _value["value"]["policy_list"]: _value["value"]["policy_list"].append(policy_id) + else: + if policy_id: + raise exceptions.PolicyExisting + raise exceptions.PerimeterContentError + _value["value"].update(value) name = _value["value"]["name"] @@ -513,7 +580,10 @@ class PolicyConnector(BaseConnector, PolicyDriver): try: old_perimeter["value"]["policy_list"].remove(policy_id) new_perimeter = ClassType.from_dict(old_perimeter) - setattr(_perimeter, "value", getattr(new_perimeter, "value")) + if not new_perimeter.value["policy_list"]: + session.delete(_perimeter) + else: + setattr(_perimeter, "value", getattr(new_perimeter, "value")) except ValueError: if not _perimeter.value["policy_list"]: session.delete(_perimeter) @@ -521,11 +591,25 @@ class PolicyConnector(BaseConnector, PolicyDriver): def get_subjects(self, policy_id, perimeter_id=None): return self.__get_perimeters(Subject, policy_id, perimeter_id) + def get_subject_by_name(self, perimeter_name): + return self.__get_perimeter_by_name(Subject, perimeter_name) + def set_subject(self, policy_id, perimeter_id=None, value=None): try: - return self.__set_perimeter(Subject, exceptions.SubjectExisting, policy_id, perimeter_id=perimeter_id, value=value) - except sqlalchemy.exc.IntegrityError: - raise exceptions.SubjectExisting + return self.__set_perimeter(Subject, exceptions.SubjectExisting, policy_id, + perimeter_id=perimeter_id, value=value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.SubjectExisting + raise error + + def update_subject(self, perimeter_id, value): + try: + return self.__update_perimeter(Subject, exceptions.SubjectExisting, perimeter_id, value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.SubjectExisting + raise error def delete_subject(self, policy_id, perimeter_id): self.__delete_perimeter(Subject, exceptions.SubjectUnknown, policy_id, perimeter_id) @@ -533,12 +617,26 @@ class PolicyConnector(BaseConnector, PolicyDriver): def get_objects(self, policy_id, perimeter_id=None): return self.__get_perimeters(Object, policy_id, perimeter_id) + def get_object_by_name(self, perimeter_name): + return self.__get_perimeter_by_name(Object, perimeter_name) + def set_object(self, policy_id, perimeter_id=None, value=None): try: - return self.__set_perimeter(Object, exceptions.ObjectExisting, policy_id, perimeter_id=perimeter_id, value=value) - except sqlalchemy.exc.IntegrityError as e: - logger.exception("IntegrityError {}".format(e)) - raise exceptions.ObjectExisting + return self.__set_perimeter(Object, exceptions.ObjectExisting, policy_id, + perimeter_id=perimeter_id, value=value) + except sqlalchemy.exc.IntegrityError as error: + logger.exception("IntegrityError {}".format(error)) + if 'UNIQUE constraint' in str(error): + raise exceptions.ObjectExisting + raise error + + def update_object(self, perimeter_id, value): + try: + return self.__update_perimeter(Object, exceptions.ObjectExisting, perimeter_id, value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ObjectExisting + raise error def delete_object(self, policy_id, perimeter_id): self.__delete_perimeter(Object, exceptions.ObjectUnknown, policy_id, perimeter_id) @@ -546,18 +644,31 @@ class PolicyConnector(BaseConnector, PolicyDriver): def get_actions(self, policy_id, perimeter_id=None): return self.__get_perimeters(Action, policy_id, perimeter_id) + def get_action_by_name(self, perimeter_name): + return self.__get_perimeter_by_name(Action, perimeter_name) + def set_action(self, policy_id, perimeter_id=None, value=None): try: - return self.__set_perimeter(Action, exceptions.ActionExisting, policy_id, perimeter_id=perimeter_id, value=value) - except sqlalchemy.exc.IntegrityError: - raise exceptions.ActionExisting + return self.__set_perimeter(Action, exceptions.ActionExisting, policy_id, + perimeter_id=perimeter_id, value=value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ActionExisting + raise error + + def update_action(self, perimeter_id, value): + try: + return self.__update_perimeter(Action, exceptions.ActionExisting, perimeter_id, value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ActionExisting + raise error def delete_action(self, policy_id, perimeter_id): self.__delete_perimeter(Action, exceptions.ActionUnknown, policy_id, perimeter_id) - def __is_data_exist(self, ClassType, data_id=None, category_id=None): - if not data_id: - return False + def __is_data_exist(self, ClassType, category_id=None): + with self.get_session_for_read() as session: query = session.query(ClassType) query = query.filter_by(category_id=category_id) @@ -573,8 +684,13 @@ class PolicyConnector(BaseConnector, PolicyDriver): query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) elif policy_id and category_id: query = query.filter_by(policy_id=policy_id, category_id=category_id) - else: + elif category_id: query = query.filter_by(category_id=category_id) + elif policy_id: + query = query.filter_by(policy_id=policy_id) + else: + raise exceptions.PolicyUnknown + ref_list = query.all() return { "policy_id": policy_id, @@ -582,7 +698,8 @@ class PolicyConnector(BaseConnector, PolicyDriver): "data": {_ref.id: _ref.to_dict() for _ref in ref_list} } - def __set_data(self, ClassType, ClassTypeData, policy_id, data_id=None, category_id=None, value=None): + def __set_data(self, ClassType, ClassTypeData, policy_id, data_id=None, category_id=None, + value=None): with self.get_session_for_write() as session: query = session.query(ClassTypeData) query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) @@ -612,65 +729,81 @@ class PolicyConnector(BaseConnector, PolicyDriver): "data": {ref.id: ref.to_dict()} } - def __delete_data(self, ClassType, policy_id, data_id): + def __delete_data(self, ClassType, policy_id, category_id, data_id): + + if not data_id: + raise exceptions.DataUnknown with self.get_session_for_write() as session: query = session.query(ClassType) - query = query.filter_by(policy_id=policy_id, id=data_id) + if category_id: + query = query.filter_by(policy_id=policy_id, category_id=category_id, id=data_id) + else: + query = query.filter_by(policy_id=policy_id, id=data_id) ref = query.first() if ref: session.delete(ref) - def is_subject_data_exist(self, data_id=None, category_id=None): - return self.__is_data_exist(SubjectData, data_id=data_id, category_id=category_id) + def is_subject_data_exist(self, category_id=None): + return self.__is_data_exist(SubjectData, category_id=category_id) def get_subject_data(self, policy_id, data_id=None, category_id=None): return self.__get_data(SubjectData, policy_id, data_id=data_id, category_id=category_id) def set_subject_data(self, policy_id, data_id=None, category_id=None, value=None): try: - return self.__set_data(Subject, SubjectData, policy_id, data_id=data_id, category_id=category_id, value=value) - except sqlalchemy.exc.IntegrityError: - raise exceptions.SubjectScopeExisting + return self.__set_data(Subject, SubjectData, policy_id, data_id=data_id, + category_id=category_id, value=value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.SubjectScopeExisting + raise error - def delete_subject_data(self, policy_id, data_id): - return self.__delete_data(SubjectData, policy_id, data_id) + def delete_subject_data(self, policy_id, category_id, data_id): + return self.__delete_data(SubjectData, policy_id, category_id, data_id) - def is_object_data_exist(self, data_id=None, category_id=None): - return self.__is_data_exist(ObjectData, data_id=data_id, category_id=category_id) + def is_object_data_exist(self, category_id=None): + return self.__is_data_exist(ObjectData, category_id=category_id) def get_object_data(self, policy_id, data_id=None, category_id=None): return self.__get_data(ObjectData, policy_id, data_id=data_id, category_id=category_id) def set_object_data(self, policy_id, data_id=None, category_id=None, value=None): try: - return self.__set_data(Object, ObjectData, policy_id, data_id=data_id, category_id=category_id, value=value) - except sqlalchemy.exc.IntegrityError: - raise exceptions.ObjectScopeExisting + return self.__set_data(Object, ObjectData, policy_id, data_id=data_id, + category_id=category_id, value=value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ObjectScopeExisting + raise error - def delete_object_data(self, policy_id, data_id): - return self.__delete_data(ObjectData, policy_id, data_id) + def delete_object_data(self, policy_id, category_id, data_id): + return self.__delete_data(ObjectData, policy_id, category_id, data_id) - def is_action_data_exist(self, data_id=None,category_id=None): - return self.__is_data_exist(ActionData, data_id=data_id, category_id=category_id) + def is_action_data_exist(self, category_id=None): + return self.__is_data_exist(ActionData, category_id=category_id) def get_action_data(self, policy_id, data_id=None, category_id=None): return self.__get_data(ActionData, policy_id, data_id=data_id, category_id=category_id) def set_action_data(self, policy_id, data_id=None, category_id=None, value=None): try: - return self.__set_data(Action, ActionData, policy_id, data_id=data_id, category_id=category_id, value=value) - except sqlalchemy.exc.IntegrityError: - raise exceptions.ActionScopeExisting + return self.__set_data(Action, ActionData, policy_id, data_id=data_id, + category_id=category_id, value=value) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ActionScopeExisting + raise error - def delete_action_data(self, policy_id, data_id): - return self.__delete_data(ActionData, policy_id, data_id) + def delete_action_data(self, policy_id, category_id, data_id): + return self.__delete_data(ActionData, policy_id, category_id, data_id) def get_subject_assignments(self, policy_id, subject_id=None, category_id=None): with self.get_session_for_write() as session: query = session.query(SubjectAssignment) if subject_id and category_id: - #TODO change the subject_id to perimeter_id to allow code refactoring - query = query.filter_by(policy_id=policy_id, subject_id=subject_id, category_id=category_id) + # TODO change the subject_id to perimeter_id to allow code refactoring + query = query.filter_by(policy_id=policy_id, subject_id=subject_id, + category_id=category_id) elif subject_id: query = query.filter_by(policy_id=policy_id, subject_id=subject_id) else: @@ -681,7 +814,8 @@ class PolicyConnector(BaseConnector, PolicyDriver): def add_subject_assignment(self, policy_id, subject_id, category_id, data_id): with self.get_session_for_write() as session: query = session.query(SubjectAssignment) - query = query.filter_by(policy_id=policy_id, subject_id=subject_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, subject_id=subject_id, + category_id=category_id) ref = query.first() if ref: old_ref = copy.deepcopy(ref.to_dict()) @@ -704,10 +838,27 @@ class PolicyConnector(BaseConnector, PolicyDriver): session.add(ref) return {ref.id: ref.to_dict()} + def is_subject_category_has_assignment(self, category_id): + return self.__is_category_has_assignment(SubjectAssignment, category_id) + + def is_object_category_has_assignment(self, category_id): + return self.__is_category_has_assignment(ObjectAssignment, category_id) + + def is_action_category_has_assignment(self, category_id): + return self.__is_category_has_assignment(ActionAssignment, category_id) + + def __is_category_has_assignment(self, ClassType, category_id): + with self.get_session_for_write() as session: + query = session.query(ClassType) + query = query.filter_by(category_id=category_id) + count = query.count() + return count > 0 + def delete_subject_assignment(self, policy_id, subject_id, category_id, data_id): with self.get_session_for_write() as session: query = session.query(SubjectAssignment) - query = query.filter_by(policy_id=policy_id, subject_id=subject_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, subject_id=subject_id, + category_id=category_id) ref = query.first() if ref: old_ref = copy.deepcopy(ref.to_dict()) @@ -724,8 +875,9 @@ class PolicyConnector(BaseConnector, PolicyDriver): with self.get_session_for_write() as session: query = session.query(ObjectAssignment) if object_id and category_id: - #TODO change the object_id to perimeter_id to allow code refactoring - query = query.filter_by(policy_id=policy_id, object_id=object_id, category_id=category_id) + # TODO change the object_id to perimeter_id to allow code refactoring + query = query.filter_by(policy_id=policy_id, object_id=object_id, + category_id=category_id) elif object_id: query = query.filter_by(policy_id=policy_id, object_id=object_id) else: @@ -736,7 +888,8 @@ class PolicyConnector(BaseConnector, PolicyDriver): def add_object_assignment(self, policy_id, object_id, category_id, data_id): with self.get_session_for_write() as session: query = session.query(ObjectAssignment) - query = query.filter_by(policy_id=policy_id, object_id=object_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, object_id=object_id, + category_id=category_id) ref = query.first() if ref: old_ref = copy.deepcopy(ref.to_dict()) @@ -762,7 +915,8 @@ class PolicyConnector(BaseConnector, PolicyDriver): def delete_object_assignment(self, policy_id, object_id, category_id, data_id): with self.get_session_for_write() as session: query = session.query(ObjectAssignment) - query = query.filter_by(policy_id=policy_id, object_id=object_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, object_id=object_id, + category_id=category_id) ref = query.first() if ref: old_ref = copy.deepcopy(ref.to_dict()) @@ -777,12 +931,17 @@ class PolicyConnector(BaseConnector, PolicyDriver): def get_action_assignments(self, policy_id, action_id=None, category_id=None): with self.get_session_for_write() as session: + if not policy_id: + return exceptions.PolicyUnknown query = session.query(ActionAssignment) if action_id and category_id: # TODO change the action_id to perimeter_id to allow code refactoring - query = query.filter_by(policy_id=policy_id, action_id=action_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, action_id=action_id, + category_id=category_id) elif action_id: query = query.filter_by(policy_id=policy_id, action_id=action_id) + elif category_id: + query = query.filter_by(policy_id=policy_id, category_id=category_id) else: query = query.filter_by(policy_id=policy_id) ref_list = query.all() @@ -791,7 +950,8 @@ class PolicyConnector(BaseConnector, PolicyDriver): def add_action_assignment(self, policy_id, action_id, category_id, data_id): with self.get_session_for_write() as session: query = session.query(ActionAssignment) - query = query.filter_by(policy_id=policy_id, action_id=action_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, action_id=action_id, + category_id=category_id) ref = query.first() if ref: old_ref = copy.deepcopy(ref.to_dict()) @@ -817,7 +977,8 @@ class PolicyConnector(BaseConnector, PolicyDriver): def delete_action_assignment(self, policy_id, action_id, category_id, data_id): with self.get_session_for_write() as session: query = session.query(ActionAssignment) - query = query.filter_by(policy_id=policy_id, action_id=action_id, category_id=category_id) + query = query.filter_by(policy_id=policy_id, action_id=action_id, + category_id=category_id) ref = query.first() if ref: old_ref = copy.deepcopy(ref.to_dict()) @@ -837,7 +998,7 @@ class PolicyConnector(BaseConnector, PolicyDriver): query = query.filter_by(policy_id=policy_id, rule_id=rule_id) ref = query.first() return {ref.id: ref.to_dict()} - elif meta_rule_id: + elif meta_rule_id and policy_id: query = query.filter_by(policy_id=policy_id, meta_rule_id=meta_rule_id) ref_list = query.all() return { @@ -853,6 +1014,14 @@ class PolicyConnector(BaseConnector, PolicyDriver): "rules": list(map(lambda x: x.to_dict(), ref_list)) } + def is_meta_rule_has_rules(self, meta_rule_id): + with self.get_session_for_read() as session: + query = session.query(Rule) + + query = query.filter_by(meta_rule_id=meta_rule_id) + count = query.count() + return count > 0 + def add_rule(self, policy_id, meta_rule_id, value): try: rules = self.get_rules(policy_id, meta_rule_id=meta_rule_id) @@ -870,8 +1039,10 @@ class PolicyConnector(BaseConnector, PolicyDriver): ) session.add(ref) return {ref.id: ref.to_dict()} - except sqlalchemy.exc.IntegrityError: - raise exceptions.RuleExisting + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.RuleExisting + raise error def delete_rule(self, policy_id, rule_id): with self.get_session_for_write() as session: @@ -885,19 +1056,24 @@ class PolicyConnector(BaseConnector, PolicyDriver): class ModelConnector(BaseConnector, ModelDriver): def update_model(self, model_id, value): - with self.get_session_for_write() as session: - query = session.query(Model) - if model_id: - query = query.filter_by(id=model_id) - ref = query.first() - if ref: - value_wo_name = copy.deepcopy(value) - value_wo_name.pop("name", None) - setattr(ref, "name", value["name"]) - d = dict(ref.value) - d.update(value_wo_name) - setattr(ref, "value", d) - return {ref.id: ref.to_dict()} + try: + with self.get_session_for_write() as session: + query = session.query(Model) + if model_id: + query = query.filter_by(id=model_id) + ref = query.first() + if ref: + value_wo_name = copy.deepcopy(value) + value_wo_name.pop("name", None) + setattr(ref, "name", value["name"]) + d = dict(ref.value) + d.update(value_wo_name) + setattr(ref, "value", d) + return {ref.id: ref.to_dict()} + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ModelExisting + raise error def delete_model(self, model_id): with self.get_session_for_write() as session: @@ -916,8 +1092,9 @@ class ModelConnector(BaseConnector, ModelDriver): }) session.add(new) return {new.id: new.to_dict()} - except sqlalchemy.exc.IntegrityError as e: - raise exceptions.ModelExisting + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ModelExisting def get_models(self, model_id=None): with self.get_session_for_read() as session: @@ -931,37 +1108,44 @@ class ModelConnector(BaseConnector, ModelDriver): return r def set_meta_rule(self, meta_rule_id, value): - with self.get_session_for_write() as session: - value_wo_other_data = copy.deepcopy(value) - value_wo_other_data.pop("name", None) - value_wo_other_data.pop("subject_categories", None) - value_wo_other_data.pop("object_categories", None) - value_wo_other_data.pop("action_categories", None) - if meta_rule_id is None: - try: - ref = MetaRule.from_dict( - { - "id": uuid4().hex, - "name": value["name"], - "subject_categories": value["subject_categories"], - "object_categories": value["object_categories"], - "action_categories": value["action_categories"], - "value": value_wo_other_data - } - ) - session.add(ref) - except sqlalchemy.exc.IntegrityError as e: - raise exceptions.MetaRuleExisting - else: - query = session.query(MetaRule) - query = query.filter_by(id=meta_rule_id) - ref = query.first() - setattr(ref, "name", value["name"]) - setattr(ref, "subject_categories", value["subject_categories"]) - setattr(ref, "object_categories", value["object_categories"]) - setattr(ref, "action_categories", value["action_categories"]) - setattr(ref, "value", value_wo_other_data) - return {ref.id: ref.to_dict()} + try: + with self.get_session_for_write() as session: + value_wo_other_data = copy.deepcopy(value) + value_wo_other_data.pop("name", None) + value_wo_other_data.pop("subject_categories", None) + value_wo_other_data.pop("object_categories", None) + value_wo_other_data.pop("action_categories", None) + if meta_rule_id is None: + try: + ref = MetaRule.from_dict( + { + "id": uuid4().hex, + "name": value["name"], + "subject_categories": value["subject_categories"], + "object_categories": value["object_categories"], + "action_categories": value["action_categories"], + "value": value_wo_other_data + } + ) + session.add(ref) + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.MetaRuleExisting + raise error + else: + query = session.query(MetaRule) + query = query.filter_by(id=meta_rule_id) + ref = query.first() + setattr(ref, "name", value["name"]) + setattr(ref, "subject_categories", value["subject_categories"]) + setattr(ref, "object_categories", value["object_categories"]) + setattr(ref, "action_categories", value["action_categories"]) + setattr(ref, "value", value_wo_other_data) + return {ref.id: ref.to_dict()} + except sqlalchemy.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.MetaRuleExisting + raise error def get_meta_rules(self, meta_rule_id=None): with self.get_session_for_read() as session: @@ -988,7 +1172,7 @@ class ModelConnector(BaseConnector, ModelDriver): return {_ref.id: _ref.to_dict() for _ref in ref_list} def __add_perimeter_category(self, ClassType, name, description, uuid=None): - if not name.strip(): + if not name or not name.strip(): raise exceptions.CategoryNameInvalid with self.get_session_for_write() as session: ref = ClassType.from_dict( @@ -1015,8 +1199,10 @@ class ModelConnector(BaseConnector, ModelDriver): def add_subject_category(self, name, description, uuid=None): try: return self.__add_perimeter_category(SubjectCategory, name, description, uuid=uuid) - except sql.exc.IntegrityError as e: - raise exceptions.SubjectCategoryExisting() + except sql.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.SubjectCategoryExisting + raise error def delete_subject_category(self, category_id): self.__delete_perimeter_category(SubjectCategory, category_id) @@ -1027,8 +1213,10 @@ class ModelConnector(BaseConnector, ModelDriver): def add_object_category(self, name, description, uuid=None): try: return self.__add_perimeter_category(ObjectCategory, name, description, uuid=uuid) - except sql.exc.IntegrityError as e: - raise exceptions.ObjectCategoryExisting() + except sql.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ObjectCategoryExisting + raise error def delete_object_category(self, category_id): self.__delete_perimeter_category(ObjectCategory, category_id) @@ -1040,8 +1228,10 @@ class ModelConnector(BaseConnector, ModelDriver): def add_action_category(self, name, description, uuid=None): try: return self.__add_perimeter_category(ActionCategory, name, description, uuid=uuid) - except sql.exc.IntegrityError as e: - raise exceptions.ActionCategoryExisting() + except sql.exc.IntegrityError as error: + if 'UNIQUE constraint' in str(error): + raise exceptions.ActionCategoryExisting + raise error def delete_action_category(self, category_id): self.__delete_perimeter_category(ActionCategory, category_id) -- cgit 1.2.3-korg