diff options
Diffstat (limited to 'python_moondb')
-rw-r--r-- | python_moondb/Changelog | 4 | ||||
-rw-r--r-- | python_moondb/python_moondb/__init__.py | 2 | ||||
-rw-r--r-- | python_moondb/python_moondb/backends/sql.py | 536 |
3 files changed, 140 insertions, 402 deletions
diff --git a/python_moondb/Changelog b/python_moondb/Changelog index a7d10b17..b1e8e0ce 100644 --- a/python_moondb/Changelog +++ b/python_moondb/Changelog @@ -57,3 +57,7 @@ CHANGES ----- - Code cleaning +1.2.6 +----- +- Remove some code duplication in moon_db +- handle the extra field for the perimeter diff --git a/python_moondb/python_moondb/__init__.py b/python_moondb/python_moondb/__init__.py index de7c772e..8de4de66 100644 --- a/python_moondb/python_moondb/__init__.py +++ b/python_moondb/python_moondb/__init__.py @@ -3,5 +3,5 @@ # license which can be found in the file 'LICENSE' in this package distribution # or at 'http://www.apache.org/licenses/LICENSE-2.0'. -__version__ = "1.2.5" +__version__ = "1.2.6" diff --git a/python_moondb/python_moondb/backends/sql.py b/python_moondb/python_moondb/backends/sql.py index 1ce8d016..a8e7740b 100644 --- a/python_moondb/python_moondb/backends/sql.py +++ b/python_moondb/python_moondb/backends/sql.py @@ -9,7 +9,7 @@ from uuid import uuid4 import sqlalchemy as sql import logging from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy import create_engine from contextlib import contextmanager from sqlalchemy import types as sql_types @@ -103,36 +103,30 @@ class PDP(Base, DictBase): } -class SubjectCategory(Base, DictBase): - __tablename__ = 'subject_categories' +class PerimeterCategoryBase(DictBase): attributes = ['id', 'name', 'description'] id = sql.Column(sql.String(64), primary_key=True) name = sql.Column(sql.String(256), nullable=False) description = sql.Column(sql.String(256), nullable=True) -class ObjectCategory(Base, DictBase): +class SubjectCategory(Base, PerimeterCategoryBase): + __tablename__ = 'subject_categories' + + +class ObjectCategory(Base, PerimeterCategoryBase): __tablename__ = 'object_categories' - attributes = ['id', 'name', 'description'] - id = sql.Column(sql.String(64), primary_key=True) - name = sql.Column(sql.String(256), nullable=False) - description = sql.Column(sql.String(256), nullable=True) -class ActionCategory(Base, DictBase): +class ActionCategory(Base, PerimeterCategoryBase): __tablename__ = 'action_categories' - attributes = ['id', 'name', 'description'] - id = sql.Column(sql.String(64), primary_key=True) - name = sql.Column(sql.String(256), nullable=False) - description = sql.Column(sql.String(256), nullable=True) -class Subject(Base, DictBase): - __tablename__ = 'subjects' +class PerimeterBase(DictBase): attributes = ['id', 'value'] id = sql.Column(sql.String(64), primary_key=True) value = sql.Column(JsonBlob(), nullable=True) - + __mapper_args__ = {'concrete': True} def __repr__(self): return "{}: {}".format(self.id, json.dumps(self.value)) @@ -142,7 +136,7 @@ class Subject(Base, DictBase): 'name': self.value.get("name", ""), 'description': self.value.get("description", ""), 'email': self.value.get("email", ""), - 'partner_id': self.value.get("partner_id", ""), + 'extra': self.value.get("extra", dict()), 'policy_list': self.value.get("policy_list", []) } @@ -153,63 +147,25 @@ class Subject(Base, DictBase): } -class Object(Base, DictBase): - __tablename__ = 'objects' - attributes = ['id', 'value'] - id = sql.Column(sql.String(64), primary_key=True) - value = sql.Column(JsonBlob(), nullable=True) - - def __repr__(self): - return "{}: {}".format(self.id, json.dumps(self.value)) +class Subject(Base, PerimeterBase): + __tablename__ = 'subjects' - def to_dict(self): - return { - 'id': self.id, - 'value': self.value - } - def to_return(self): - return { - 'id': self.id, - 'name': self.value.get("name", ""), - 'description': self.value.get("description", ""), - 'partner_id': self.value.get("partner_id", ""), - 'policy_list': self.value.get("policy_list", []) - } +class Object(Base, PerimeterBase): + __tablename__ = 'objects' -class Action(Base, DictBase): +class Action(Base, PerimeterBase): __tablename__ = 'actions' - attributes = ['id', 'value'] - id = sql.Column(sql.String(64), primary_key=True) - value = sql.Column(JsonBlob(), nullable=True) - - def __repr__(self): - return "{}: {}".format(self.id, json.dumps(self.value)) - def to_dict(self): - return { - 'id': self.id, - 'value': self.value - } - def to_return(self): - return { - 'id': self.id, - 'name': self.value.get("name", ""), - 'description': self.value.get("description", ""), - 'partner_id': self.value.get("partner_id", ""), - 'policy_list': self.value.get("policy_list", []) - } - - -class SubjectData(Base, DictBase): - __tablename__ = 'subject_data' +class PerimeterDataBase(DictBase): attributes = ['id', 'value', 'category_id', 'policy_id'] id = sql.Column(sql.String(64), primary_key=True) value = sql.Column(JsonBlob(), nullable=True) - category_id = sql.Column(sql.ForeignKey("subject_categories.id"), nullable=False) - policy_id = sql.Column(sql.ForeignKey("policies.id"), nullable=False) + @declared_attr + def policy_id(cls): + return sql.Column(sql.ForeignKey("policies.id"), nullable=False) def to_dict(self): return { @@ -221,79 +177,68 @@ class SubjectData(Base, DictBase): } -class ObjectData(Base, DictBase): +class SubjectData(Base, PerimeterDataBase): + __tablename__ = 'subject_data' + category_id = sql.Column(sql.ForeignKey("subject_categories.id"), nullable=False) + + +class ObjectData(Base, PerimeterDataBase): __tablename__ = 'object_data' - attributes = ['id', 'value', 'category_id', 'policy_id'] - id = sql.Column(sql.String(64), primary_key=True) - value = sql.Column(JsonBlob(), nullable=True) category_id = sql.Column(sql.ForeignKey("object_categories.id"), nullable=False) - policy_id = sql.Column(sql.ForeignKey("policies.id"), nullable=False) -class ActionData(Base, DictBase): +class ActionData(Base, PerimeterDataBase): __tablename__ = 'action_data' - attributes = ['id', 'value', 'category_id', 'policy_id'] - id = sql.Column(sql.String(64), primary_key=True) - value = sql.Column(JsonBlob(), nullable=True) category_id = sql.Column(sql.ForeignKey("action_categories.id"), nullable=False) - policy_id = sql.Column(sql.ForeignKey("policies.id"), nullable=False) -class SubjectAssignment(Base, DictBase): - __tablename__ = 'subject_assignments' +class PerimeterAssignmentBase(DictBase): attributes = ['id', 'assignments', 'policy_id', 'subject_id', 'category_id'] id = sql.Column(sql.String(64), primary_key=True) assignments = sql.Column(JsonBlob(), nullable=True) - policy_id = sql.Column(sql.ForeignKey("policies.id"), nullable=False) - subject_id = sql.Column(sql.ForeignKey("subjects.id"), nullable=False) - category_id = sql.Column(sql.ForeignKey("subject_categories.id"), nullable=False) + category_id = None - def to_dict(self): + @declared_attr + def policy_id(cls): + return sql.Column(sql.ForeignKey("policies.id"), nullable=False) + + def _to_dict(self, element_key, element_value): return { "id": self.id, "policy_id": self.policy_id, - "subject_id": self.subject_id, + element_key: element_value, "category_id": self.category_id, "assignments": self.assignments, } -class ObjectAssignment(Base, DictBase): +class SubjectAssignment(Base, PerimeterAssignmentBase): + __tablename__ = 'subject_assignments' + subject_id = sql.Column(sql.ForeignKey("subjects.id"), nullable=False) + category_id = sql.Column(sql.ForeignKey("subject_categories.id"), nullable=False) + + def to_dict(self): + return self._to_dict("subject_id", self.subject_id) + + +class ObjectAssignment(Base, PerimeterAssignmentBase): __tablename__ = 'object_assignments' attributes = ['id', 'assignments', 'policy_id', 'object_id', 'category_id'] - id = sql.Column(sql.String(64), primary_key=True) - assignments = sql.Column(JsonBlob(), nullable=True) - policy_id = sql.Column(sql.ForeignKey("policies.id"), nullable=False) object_id = sql.Column(sql.ForeignKey("objects.id"), nullable=False) category_id = sql.Column(sql.ForeignKey("object_categories.id"), nullable=False) def to_dict(self): - return { - "id": self.id, - "policy_id": self.policy_id, - "object_id": self.object_id, - "category_id": self.category_id, - "assignments": self.assignments, - } + return self._to_dict("object_id", self.object_id) -class ActionAssignment(Base, DictBase): +class ActionAssignment(Base, PerimeterAssignmentBase): __tablename__ = 'action_assignments' attributes = ['id', 'assignments', 'policy_id', 'action_id', 'category_id'] - id = sql.Column(sql.String(64), primary_key=True) - assignments = sql.Column(JsonBlob(), nullable=True) - policy_id = sql.Column(sql.ForeignKey("policies.id"), nullable=False) action_id = sql.Column(sql.ForeignKey("actions.id"), nullable=False) category_id = sql.Column(sql.ForeignKey("action_categories.id"), nullable=False) def to_dict(self): - return { - "id": self.id, - "policy_id": self.policy_id, - "action_id": self.action_id, - "category_id": self.category_id, - "assignments": self.assignments, - } + return self._to_dict("action_id", self.action_id) class MetaRule(Base, DictBase): @@ -446,9 +391,9 @@ class PolicyConnector(BaseConnector, PolicyDriver): ref_list = query.all() return {_ref.id: _ref.to_dict() for _ref in ref_list} - def get_subjects(self, policy_id, perimeter_id=None): + def __get_perimeters(self, ClassType, policy_id, perimeter_id=None): with self.get_session_for_read() as session: - query = session.query(Subject) + query = session.query(ClassType) ref_list = copy.deepcopy(query.all()) if perimeter_id: for _ref in ref_list: @@ -467,190 +412,83 @@ class PolicyConnector(BaseConnector, PolicyDriver): return {_ref.id: _ref.to_return() for _ref in results} return {_ref.id: _ref.to_return() for _ref in ref_list} - def set_subject(self, policy_id, perimeter_id=None, value=None): - _subject = None + def __set_perimeter(self, ClassType, policy_id, perimeter_id=None, value=None): + _perimeter = None with self.get_session_for_write() as session: if perimeter_id: - query = session.query(Subject) + query = session.query(ClassType) query = query.filter_by(id=perimeter_id) - _subject = query.first() - if not _subject: + _perimeter = query.first() + if not _perimeter: if "policy_list" not in value or type(value["policy_list"]) is not list: value["policy_list"] = [] if policy_id and policy_id not in value["policy_list"]: value["policy_list"] = [policy_id, ] - new = Subject.from_dict({ + new = ClassType.from_dict({ "id": perimeter_id if perimeter_id else uuid4().hex, "value": value }) session.add(new) return {new.id: new.to_return()} else: - _value = copy.deepcopy(_subject.to_dict()) + _value = copy.deepcopy(_perimeter.to_dict()) if "policy_list" not in _value["value"] or type(_value["value"]["policy_list"]) is not list: _value["value"]["policy_list"] = [] if policy_id and policy_id not in _value["value"]["policy_list"]: _value["value"]["policy_list"].append(policy_id) - new_subject = Subject.from_dict(_value) + new_perimeter = ClassType.from_dict(_value) # setattr(_subject, "value", _value["value"]) - setattr(_subject, "value", getattr(new_subject, "value")) - return {_subject.id: _subject.to_return()} + setattr(_perimeter, "value", getattr(new_perimeter, "value")) + return {_perimeter.id: _perimeter.to_return()} - def delete_subject(self, policy_id, perimeter_id): + def __delete_perimeter(self,ClassType, ClassUnknownException, policy_id, perimeter_id): with self.get_session_for_write() as session: - query = session.query(Subject) + query = session.query(ClassType) query = query.filter_by(id=perimeter_id) - _subject = query.first() - if not _subject: - raise SubjectUnknown - old_subject = copy.deepcopy(_subject.to_dict()) + _perimeter = query.first() + if not _perimeter: + raise ClassUnknownException + old_perimeter = copy.deepcopy(_perimeter.to_dict()) # value = _subject.to_dict() try: - old_subject["value"]["policy_list"].remove(policy_id) - new_user = Subject.from_dict(old_subject) - setattr(_subject, "value", getattr(new_user, "value")) + old_perimeter["value"]["policy_list"].remove(policy_id) + new_perimeter = ClassType.from_dict(old_perimeter) + setattr(_perimeter, "value", getattr(new_perimeter, "value")) except ValueError: - if not _subject.value["policy_list"]: - session.delete(_subject) + if not _perimeter.value["policy_list"]: + session.delete(_perimeter) + + def get_subjects(self, policy_id, perimeter_id=None): + return self.__get_perimeters(Subject, policy_id, perimeter_id) + + def set_subject(self, policy_id, perimeter_id=None, value=None): + return self.__set_perimeter(Subject, policy_id, perimeter_id=perimeter_id, value=value) + + def delete_subject(self, policy_id, perimeter_id): + self.__delete_perimeter(Subject, SubjectUnknown, policy_id, perimeter_id) def get_objects(self, policy_id, perimeter_id=None): - with self.get_session_for_read() as session: - query = session.query(Object) - ref_list = copy.deepcopy(query.all()) - if perimeter_id: - for _ref in ref_list: - _ref_value = _ref.to_return() - if perimeter_id == _ref.id: - if policy_id and policy_id in _ref_value["policy_list"]: - return {_ref.id: _ref_value} - else: - return {} - elif policy_id: - results = [] - for _ref in ref_list: - _ref_value = _ref.to_return() - if policy_id in _ref_value["policy_list"]: - results.append(_ref) - return {_ref.id: _ref.to_return() for _ref in results} - return {_ref.id: _ref.to_return() for _ref in ref_list} + return self.__get_perimeters(Object, policy_id, perimeter_id) def set_object(self, policy_id, perimeter_id=None, value=None): - _object = None - with self.get_session_for_write() as session: - if perimeter_id: - query = session.query(Object) - query = query.filter_by(id=perimeter_id) - _object = query.first() - if not _object: - if "policy_list" not in value or type(value["policy_list"]) is not list: - value["policy_list"] = [] - if policy_id and policy_id not in value["policy_list"]: - value["policy_list"] = [policy_id, ] - new = Object.from_dict({ - "id": perimeter_id if perimeter_id else uuid4().hex, - "value": value - }) - session.add(new) - return {new.id: new.to_return()} - else: - _value = copy.deepcopy(_object.to_dict()) - if "policy_list" not in _value["value"] or type(_value["value"]["policy_list"]) is not list: - _value["value"]["policy_list"] = [] - if policy_id and policy_id not in _value["value"]["policy_list"]: - _value["value"]["policy_list"].append(policy_id) - new_object = Object.from_dict(_value) - # setattr(_object, "value", _value["value"]) - setattr(_object, "value", getattr(new_object, "value")) - return {_object.id: _object.to_return()} + return self.__set_perimeter(Object, policy_id, perimeter_id=perimeter_id, value=value) def delete_object(self, policy_id, perimeter_id): - with self.get_session_for_write() as session: - query = session.query(Object) - query = query.filter_by(id=perimeter_id) - _object = query.first() - if not _object: - raise ObjectUnknown - old_object = copy.deepcopy(_object.to_dict()) - # value = _object.to_dict() - try: - old_object["value"]["policy_list"].remove(policy_id) - new_user = Object.from_dict(old_object) - setattr(_object, "value", getattr(new_user, "value")) - except ValueError: - if not _object.value["policy_list"]: - session.delete(_object) + self.__delete_perimeter(Object, ObjectUnknown, policy_id, perimeter_id) def get_actions(self, policy_id, perimeter_id=None): - with self.get_session_for_read() as session: - query = session.query(Action) - ref_list = copy.deepcopy(query.all()) - if perimeter_id: - for _ref in ref_list: - _ref_value = _ref.to_return() - if perimeter_id == _ref.id: - if policy_id and policy_id in _ref_value["policy_list"]: - return {_ref.id: _ref_value} - else: - return {} - elif policy_id: - results = [] - for _ref in ref_list: - _ref_value = _ref.to_return() - if policy_id in _ref_value["policy_list"]: - results.append(_ref) - return {_ref.id: _ref.to_return() for _ref in results} - return {_ref.id: _ref.to_return() for _ref in ref_list} + return self.__get_perimeters(Action, policy_id, perimeter_id) def set_action(self, policy_id, perimeter_id=None, value=None): - _action = None - with self.get_session_for_write() as session: - if perimeter_id: - query = session.query(Action) - query = query.filter_by(id=perimeter_id) - _action = query.first() - if not _action: - if "policy_list" not in value or type(value["policy_list"]) is not list: - value["policy_list"] = [] - if policy_id and policy_id not in value["policy_list"]: - value["policy_list"] = [policy_id, ] - new = Action.from_dict({ - "id": perimeter_id if perimeter_id else uuid4().hex, - "value": value - }) - session.add(new) - return {new.id: new.to_return()} - else: - _value = copy.deepcopy(_action.to_dict()) - if "policy_list" not in _value["value"] or type(_value["value"]["policy_list"]) is not list: - _value["value"]["policy_list"] = [] - if policy_id and policy_id not in _value["value"]["policy_list"]: - _value["value"]["policy_list"].append(policy_id) - new_action = Action.from_dict(_value) - # setattr(_action, "value", _value["value"]) - setattr(_action, "value", getattr(new_action, "value")) - return {_action.id: _action.to_return()} + return self.__set_perimeter(Action, policy_id, perimeter_id=perimeter_id, value=value) def delete_action(self, policy_id, perimeter_id): - with self.get_session_for_write() as session: - query = session.query(Action) - query = query.filter_by(id=perimeter_id) - _action = query.first() - if not _action: - raise ActionUnknown - old_action = copy.deepcopy(_action.to_dict()) - # value = _action.to_dict() - try: - old_action["value"]["policy_list"].remove(policy_id) - new_user = Action.from_dict(old_action) - setattr(_action, "value", getattr(new_user, "value")) - except ValueError: - if not _action.value["policy_list"]: - session.delete(_action) + self.__delete_perimeter(Action, ActionUnknown, policy_id, perimeter_id) - def get_subject_data(self, policy_id, data_id=None, category_id=None): + def __get_perimeter_data(self, ClassType, policy_id, data_id=None, category_id=None): logger.info("driver {} {} {}".format(policy_id, data_id, category_id)) with self.get_session_for_read() as session: - query = session.query(SubjectData) + query = session.query(ClassType) if data_id: query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) else: @@ -663,13 +501,13 @@ class PolicyConnector(BaseConnector, PolicyDriver): "data": {_ref.id: _ref.to_dict() for _ref in ref_list} } - def set_subject_data(self, policy_id, data_id=None, category_id=None, value=None): + def __set_perimeter_data(self, ClassType, ClassTypeData, policy_id, data_id=None, category_id=None, value=None): with self.get_session_for_write() as session: - query = session.query(SubjectData) + query = session.query(ClassTypeData) query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) ref = query.first() if not ref: - new_ref = SubjectData.from_dict( + new_ref = ClassTypeData.from_dict( { "id": data_id if data_id else uuid4().hex, 'value': value, @@ -680,7 +518,7 @@ class PolicyConnector(BaseConnector, PolicyDriver): session.add(new_ref) ref = new_ref else: - for attr in Subject.attributes: + for attr in ClassType.attributes: if attr != 'id': setattr(ref, attr, getattr(ref, attr)) # session.flush() @@ -690,116 +528,46 @@ class PolicyConnector(BaseConnector, PolicyDriver): "data": {ref.id: ref.to_dict()} } - def delete_subject_data(self, policy_id, data_id): + def __delete_perimeter_data(self, ClassType, policy_id, data_id): with self.get_session_for_write() as session: - query = session.query(SubjectData) + query = session.query(ClassType) query = query.filter_by(policy_id=policy_id, id=data_id) ref = query.first() if ref: session.delete(ref) + def get_subject_data(self, policy_id, data_id=None, category_id=None): + return self.__get_perimeter_data(SubjectData, policy_id, data_id=data_id, category_id=category_id) + + def set_subject_data(self, policy_id, data_id=None, category_id=None, value=None): + return self.__set_perimeter_data(Subject, SubjectData, policy_id, data_id=data_id, category_id=category_id, value=value) + + def delete_subject_data(self, policy_id, data_id): + return self.__delete_perimeter_data(SubjectData, policy_id, data_id) + def get_object_data(self, policy_id, data_id=None, category_id=None): - with self.get_session_for_read() as session: - query = session.query(ObjectData) - if data_id: - query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) - else: - query = query.filter_by(policy_id=policy_id, category_id=category_id) - ref_list = query.all() - return { - "policy_id": policy_id, - "category_id": category_id, - "data": {_ref.id: _ref.to_dict() for _ref in ref_list} - } + return self.__get_perimeter_data(ObjectData, policy_id, data_id=data_id, category_id=category_id) def set_object_data(self, policy_id, data_id=None, category_id=None, value=None): - with self.get_session_for_write() as session: - query = session.query(ObjectData) - query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) - ref = query.first() - if not ref: - new_ref = ObjectData.from_dict( - { - "id": data_id if data_id else uuid4().hex, - 'value': value, - 'category_id': category_id, - 'policy_id': policy_id, - } - ) - session.add(new_ref) - ref = new_ref - else: - for attr in Object.attributes: - if attr != 'id': - setattr(ref, attr, getattr(ref, attr)) - # session.flush() - return { - "policy_id": policy_id, - "category_id": category_id, - "data": {ref.id: ref.to_dict()} - } + return self.__set_perimeter_data(Object, ObjectData, policy_id, data_id=data_id, category_id=category_id, value=value) def delete_object_data(self, policy_id, data_id): - with self.get_session_for_write() as session: - query = session.query(ObjectData) - query = query.filter_by(policy_id=policy_id, id=data_id) - ref = query.first() - if ref: - session.delete(ref) + return self.__delete_perimeter_data(ObjectData, policy_id, data_id) def get_action_data(self, policy_id, data_id=None, category_id=None): - with self.get_session_for_read() as session: - query = session.query(ActionData) - if data_id: - query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) - else: - query = query.filter_by(policy_id=policy_id, category_id=category_id) - ref_list = query.all() - return { - "policy_id": policy_id, - "category_id": category_id, - "data": {_ref.id: _ref.to_dict() for _ref in ref_list} - } + return self.__get_perimeter_data(ActionData, policy_id, data_id=data_id, category_id=category_id) def set_action_data(self, policy_id, data_id=None, category_id=None, value=None): - with self.get_session_for_write() as session: - query = session.query(ActionData) - query = query.filter_by(policy_id=policy_id, id=data_id, category_id=category_id) - ref = query.first() - if not ref: - new_ref = ActionData.from_dict( - { - "id": data_id if data_id else uuid4().hex, - 'value': value, - 'category_id': category_id, - 'policy_id': policy_id, - } - ) - session.add(new_ref) - ref = new_ref - else: - for attr in Action.attributes: - if attr != 'id': - setattr(ref, attr, getattr(ref, attr)) - # session.flush() - return { - "policy_id": policy_id, - "category_id": category_id, - "data": {ref.id: ref.to_dict()} - } + return self.__set_perimeter_data(Action, ActionData, policy_id, data_id=data_id, category_id=category_id, value=value) def delete_action_data(self, policy_id, data_id): - with self.get_session_for_write() as session: - query = session.query(ActionData) - query = query.filter_by(policy_id=policy_id, id=data_id) - ref = query.first() - if ref: - session.delete(ref) + return self.__delete_perimeter_data(ActionData, policy_id, data_id) def get_subject_assignments(self, policy_id, subject_id=None, category_id=None): with self.get_session_for_write() as session: query = session.query(SubjectAssignment) if subject_id and category_id: + #TODO change the subject_id to perimeter_id to allow code refactoring query = query.filter_by(policy_id=policy_id, subject_id=subject_id, category_id=category_id) elif subject_id: query = query.filter_by(policy_id=policy_id, subject_id=subject_id) @@ -852,6 +620,7 @@ class PolicyConnector(BaseConnector, PolicyDriver): with self.get_session_for_write() as session: query = session.query(ObjectAssignment) if object_id and category_id: + #TODO change the object_id to perimeter_id to allow code refactoring query = query.filter_by(policy_id=policy_id, object_id=object_id, category_id=category_id) elif object_id: query = query.filter_by(policy_id=policy_id, object_id=object_id) @@ -904,6 +673,7 @@ class PolicyConnector(BaseConnector, PolicyDriver): with self.get_session_for_write() as session: query = session.query(ActionAssignment) if action_id and category_id: + # TODO change the action_id to perimeter_id to allow code refactoring query = query.filter_by(policy_id=policy_id, action_id=action_id, category_id=category_id) elif action_id: query = query.filter_by(policy_id=policy_id, action_id=action_id) @@ -1074,21 +844,21 @@ class ModelConnector(BaseConnector, ModelDriver): if ref: session.delete(ref) - def get_subject_categories(self, category_id=None): + def __get_perimeter_categories(self, ClassType, category_id=None): with self.get_session_for_read() as session: - query = session.query(SubjectCategory) + query = session.query(ClassType) if category_id: query = query.filter_by(id=category_id) ref_list = query.all() return {_ref.id: _ref.to_dict() for _ref in ref_list} - def add_subject_category(self, name, description, uuid=None): + def __add_perimeter_category(self, ClassType, name, description, uuid=None): with self.get_session_for_write() as session: - query = session.query(SubjectCategory) + query = session.query(ClassType) query = query.filter_by(name=name) ref = query.first() if not ref: - ref = SubjectCategory.from_dict( + ref = ClassType.from_dict( { "id": uuid if uuid else uuid4().hex, "name": name, @@ -1098,77 +868,41 @@ class ModelConnector(BaseConnector, ModelDriver): session.add(ref) return {ref.id: ref.to_dict()} - def delete_subject_category(self, category_id): + def __delete_perimeter_category(self, ClassType, category_id): with self.get_session_for_write() as session: - query = session.query(SubjectCategory) + query = session.query(ClassType) query = query.filter_by(id=category_id) ref = query.first() if ref: session.delete(ref) + def get_subject_categories(self, category_id=None): + return self.__get_perimeter_categories(SubjectCategory, category_id=category_id) + + def add_subject_category(self, name, description, uuid=None): + return self.__add_perimeter_category(SubjectCategory, name, description, uuid=uuid) + + def delete_subject_category(self, category_id): + self.__delete_perimeter_category(SubjectCategory, category_id) + def get_object_categories(self, category_id=None): - with self.get_session_for_read() as session: - query = session.query(ObjectCategory) - if category_id: - query = query.filter_by(id=category_id) - ref_list = query.all() - return {_ref.id: _ref.to_dict() for _ref in ref_list} + return self.__get_perimeter_categories(ObjectCategory, category_id=category_id) def add_object_category(self, name, description, uuid=None): - with self.get_session_for_write() as session: - query = session.query(ObjectCategory) - query = query.filter_by(name=name) - ref = query.first() - if not ref: - ref = ObjectCategory.from_dict( - { - "id": uuid if uuid else uuid4().hex, - "name": name, - "description": description - } - ) - session.add(ref) - return {ref.id: ref.to_dict()} + return self.__add_perimeter_category(ObjectCategory, name, description, uuid=uuid) def delete_object_category(self, category_id): - with self.get_session_for_write() as session: - query = session.query(ObjectCategory) - query = query.filter_by(id=category_id) - ref = query.first() - if ref: - session.delete(ref) + self.__delete_perimeter_category(SubjectCategory, category_id) def get_action_categories(self, category_id=None): - with self.get_session_for_read() as session: - query = session.query(ActionCategory) - if category_id: - query = query.filter_by(id=category_id) - ref_list = query.all() - return {_ref.id: _ref.to_dict() for _ref in ref_list} + return self.__get_perimeter_categories(ActionCategory, category_id=category_id) def add_action_category(self, name, description, uuid=None): - with self.get_session_for_write() as session: - query = session.query(ActionCategory) - query = query.filter_by(name=name) - ref = query.first() - if not ref: - ref = ActionCategory.from_dict( - { - "id": uuid if uuid else uuid4().hex, - "name": name, - "description": description - } - ) - session.add(ref) - return {ref.id: ref.to_dict()} + return self.__add_perimeter_category(ActionCategory, name, description, uuid=uuid) def delete_action_category(self, category_id): - with self.get_session_for_write() as session: - query = session.query(ActionCategory) - query = query.filter_by(id=category_id) - ref = query.first() - if ref: - session.delete(ref) + self.__delete_perimeter_category(SubjectCategory, category_id) + # Getter and Setter for subject_category # def get_subject_categories_dict(self, intra_extension_id): |