Ver Fonte

propagate protecting class attribute to other entities

Toan Quach há 9 meses atrás
pai
commit
06f2082f94

+ 20 - 1
taipy/core/cycle/cycle.py

@@ -14,13 +14,14 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, Optional
 
+from taipy.config.common._validate_id import _validate_id
 from taipy.config.common.frequency import Frequency
 
 from .._entity._entity import _Entity
 from .._entity._labeled import _Labeled
 from .._entity._properties import _Properties
 from .._entity._reload import _Reloader, _self_reload, _self_setter
-from ..exceptions.exceptions import _SuspiciousFileOperation
+from ..exceptions.exceptions import AttributeKeyAlreadyExisted, _SuspiciousFileOperation
 from ..notification.event import Event, EventEntityType, EventOperation, _make_event
 from .cycle_id import CycleId
 
@@ -102,6 +103,7 @@ class Cycle(_Entity, _Labeled):
     _ID_PREFIX = "CYCLE"
     __SEPARATOR = "_"
     _MANAGER_NAME = "cycle"
+    __CHECK_INIT_DONE_ATTR_NAME = "_init_done"
 
     def __init__(
         self,
@@ -121,6 +123,8 @@ class Cycle(_Entity, _Labeled):
         self.id = id or self._new_id(self._name)
         self._properties = _Properties(self, **properties)
 
+        self._init_done = True
+
     def _new_name(self, name: Optional[str] = None) -> str:
         if name:
             return name
@@ -211,6 +215,21 @@ class Cycle(_Entity, _Labeled):
 
         return CycleId(_get_valid_filename(Cycle.__SEPARATOR.join([Cycle._ID_PREFIX, name, str(uuid.uuid4())])))
 
+    def __setattr__(self, name: str, value: Any) -> None:
+        if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
+            return super().__setattr__(name, value)
+        else:
+            protected_attribute_name = _validate_id(name)
+            try:
+                if protected_attribute_name not in self._properties:
+                    raise AttributeError
+                raise AttributeKeyAlreadyExisted(name)
+            except AttributeError:
+                return super().__setattr__(name, value)
+
+    def _get_attributes(self, protected_attribute_name, attribute_name):
+        raise AttributeError
+
     def __getattr__(self, attribute_name):
         protected_attribute_name = attribute_name
         if protected_attribute_name in self._properties:

+ 14 - 3
taipy/core/data/data_node.py

@@ -29,7 +29,7 @@ from .._entity._ready_to_run_property import _ReadyToRunProperty
 from .._entity._reload import _Reloader, _self_reload, _self_setter
 from .._version._version_manager_factory import _VersionManagerFactory
 from ..common._warnings import _warn_deprecated
-from ..exceptions.exceptions import DataNodeIsBeingEdited, NoData
+from ..exceptions.exceptions import AttributeKeyAlreadyExisted, DataNodeIsBeingEdited, NoData
 from ..job.job_id import JobId
 from ..notification.event import Event, EventEntityType, EventOperation, _make_event
 from ..reason import DataNodeEditInProgress, DataNodeIsNotWritten
@@ -139,6 +139,7 @@ class DataNode(_Entity, _Labeled):
     _MANAGER_NAME: str = "data"
     _PATH_KEY = "path"
     __EDIT_TIMEOUT = 30
+    __CHECK_INIT_DONE_ATTR_NAME = "_init_done"
 
     _TAIPY_PROPERTIES: Set[str] = set()
 
@@ -174,6 +175,7 @@ class DataNode(_Entity, _Labeled):
         self._edits: List[Edit] = edits or []
 
         self._properties: _Properties = _Properties(self, **kwargs)
+        self._init_done = True
 
     @staticmethod
     def _new_id(config_id: str) -> DataNodeId:
@@ -347,8 +349,17 @@ class DataNode(_Entity, _Labeled):
     def __setstate__(self, state):
         vars(self).update(state)
 
-    # def __setattr__(self, name: str, value: Any) -> None:
-    #     return super().__setattr__(name, value)
+    def __setattr__(self, name: str, value: Any) -> None:
+        if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
+            return super().__setattr__(name, value)
+        else:
+            protected_attribute_name = _validate_id(name)
+            try:
+                if protected_attribute_name not in self._properties:
+                    raise AttributeError
+                raise AttributeKeyAlreadyExisted(name)
+            except AttributeError:
+                return super().__setattr__(name, value)
 
     def _get_attributes(self, protected_attribute_name, attribute_name):
         raise AttributeError

+ 4 - 4
taipy/core/data/mongo.py

@@ -143,15 +143,15 @@ class MongoCollectionDataNode(DataNode):
             properties.get(self.__COLLECTION_KEY, "")
         ]
 
-        self.custom_document = properties[self._CUSTOM_DOCUMENT_PROPERTY]
+        self.custom_mongo_document = properties[self._CUSTOM_DOCUMENT_PROPERTY]
 
         self._decoder = self._default_decoder
-        custom_decoder = getattr(self.custom_document, "decode", None)
+        custom_decoder = getattr(self.custom_mongo_document, "decode", None)
         if callable(custom_decoder):
             self._decoder = custom_decoder
 
         self._encoder = self._default_encoder
-        custom_encoder = getattr(self.custom_document, "encode", None)
+        custom_encoder = getattr(self.custom_mongo_document, "encode", None)
         if callable(custom_encoder):
             self._encoder = custom_encoder
 
@@ -275,7 +275,7 @@ class MongoCollectionDataNode(DataNode):
         Returns:
             A custom document object.
         """
-        return self.custom_document(**document)
+        return self.custom_mongo_document(**document)
 
     def _default_encoder(self, document_object: Any) -> Dict:
         """Encode a custom document object to a dictionary for writing to MongoDB.

+ 1 - 1
taipy/core/scenario/scenario.py

@@ -193,7 +193,7 @@ class Scenario(_Entity, Submittable, _Labeled):
             except AttributeError:
                 return super().__setattr__(name, value)
 
-    def _get_attributes(self, protected_attribute_name, attribute_name):
+    def _get_attributes(self, protected_attribute_name, attribute_name) -> Union[Sequence, Task, DataNode]:
         sequences = self._get_sequences()
         if protected_attribute_name in sequences:
             return sequences[protected_attribute_name]

+ 25 - 5
taipy/core/sequence/sequence.py

@@ -27,7 +27,7 @@ from .._version._version_manager_factory import _VersionManagerFactory
 from ..common._listattributes import _ListAttributes
 from ..common._utils import _Subscriber
 from ..data.data_node import DataNode
-from ..exceptions.exceptions import NonExistingTask
+from ..exceptions.exceptions import AttributeKeyAlreadyExisted, NonExistingTask
 from ..job.job import Job
 from ..notification.event import Event, EventEntityType, EventOperation, _make_event
 from ..submission.submission import Submission
@@ -126,6 +126,7 @@ class Sequence(_Entity, Submittable, _Labeled):
     _ID_PREFIX = "SEQUENCE"
     _SEPARATOR = "_"
     _MANAGER_NAME = "sequence"
+    __CHECK_INIT_DONE_ATTR_NAME = "_init_done"
 
     def __init__(
         self,
@@ -144,6 +145,7 @@ class Sequence(_Entity, Submittable, _Labeled):
         self._parent_ids = parent_ids or set()
         self._properties = _Properties(self, **properties)
         self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
+        self._init_done = True
 
     @staticmethod
     def _new_id(sequence_name: str, scenario_id) -> SequenceId:
@@ -156,10 +158,21 @@ class Sequence(_Entity, Submittable, _Labeled):
     def __eq__(self, other):
         return isinstance(other, Sequence) and self.id == other.id
 
-    def __getattr__(self, attribute_name):
-        protected_attribute_name = _validate_id(attribute_name)
-        if protected_attribute_name in self._properties:
-            return _tpl._replace_templates(self._properties[protected_attribute_name])
+    def __setattr__(self, name: str, value: Any) -> None:
+        if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
+            return super().__setattr__(name, value)
+        else:
+            protected_attribute_name = _validate_id(name)
+            try:
+                if protected_attribute_name not in self._properties and not self._get_attributes(
+                    protected_attribute_name, name
+                ):
+                    raise AttributeError
+                raise AttributeKeyAlreadyExisted(name)
+            except AttributeError:
+                return super().__setattr__(name, value)
+
+    def _get_attributes(self, protected_attribute_name, attribute_name) -> Union[Task, DataNode]:
         tasks = self._get_tasks()
         if protected_attribute_name in tasks:
             return tasks[protected_attribute_name]
@@ -170,6 +183,13 @@ class Sequence(_Entity, Submittable, _Labeled):
                 return task.output[protected_attribute_name]
         raise AttributeError(f"{attribute_name} is not an attribute of sequence {self.id}")
 
+    def __getattr__(self, attribute_name):
+        protected_attribute_name = _validate_id(attribute_name)
+        if protected_attribute_name in self._properties:
+            return _tpl._replace_templates(self._properties[protected_attribute_name])
+
+        return self._get_attributes(protected_attribute_name, attribute_name)
+
     @property  # type: ignore
     @_self_reload(_MANAGER_NAME)
     def tasks(self) -> Dict[str, Task]:

+ 24 - 4
taipy/core/task/task.py

@@ -22,6 +22,7 @@ from .._entity._properties import _Properties
 from .._entity._reload import _Reloader, _self_reload, _self_setter
 from .._version._version_manager_factory import _VersionManagerFactory
 from ..data.data_node import DataNode
+from ..exceptions import AttributeKeyAlreadyExisted
 from ..notification.event import Event, EventEntityType, EventOperation, _make_event
 from ..submission.submission import Submission
 from .task_id import TaskId
@@ -97,6 +98,7 @@ class Task(_Entity, _Labeled):
     _ID_PREFIX = "TASK"
     __ID_SEPARATOR = "_"
     _MANAGER_NAME = "task"
+    __CHECK_INIT_DONE_ATTR_NAME = "_init_done"
 
     def __init__(
         self,
@@ -121,6 +123,7 @@ class Task(_Entity, _Labeled):
         self._version = version or _VersionManagerFactory._build_manager()._get_latest_version()
         self._skippable = skippable
         self._properties = _Properties(self, **properties)
+        self._init_done = True
 
     def __hash__(self):
         return hash(self.id)
@@ -134,16 +137,33 @@ class Task(_Entity, _Labeled):
     def __setstate__(self, state):
         vars(self).update(state)
 
-    def __getattr__(self, attribute_name):
-        protected_attribute_name = _validate_id(attribute_name)
-        if protected_attribute_name in self._properties:
-            return _tpl._replace_templates(self._properties[protected_attribute_name])
+    def __setattr__(self, name: str, value: Any) -> None:
+        if self.__CHECK_INIT_DONE_ATTR_NAME not in dir(self) or name in dir(self):
+            return super().__setattr__(name, value)
+        else:
+            protected_attribute_name = _validate_id(name)
+            try:
+                if protected_attribute_name not in self._properties and not self._get_attributes(
+                    protected_attribute_name, name
+                ):
+                    raise AttributeError
+                raise AttributeKeyAlreadyExisted(name)
+            except AttributeError:
+                return super().__setattr__(name, value)
+
+    def _get_attributes(self, protected_attribute_name, attribute_name) -> DataNode:
         if protected_attribute_name in self.input:
             return self.input[protected_attribute_name]
         if protected_attribute_name in self.output:
             return self.output[protected_attribute_name]
         raise AttributeError(f"{attribute_name} is not an attribute of task {self.id}")
 
+    def __getattr__(self, attribute_name):
+        protected_attribute_name = _validate_id(attribute_name)
+        if protected_attribute_name in self._properties:
+            return _tpl._replace_templates(self._properties[protected_attribute_name])
+        return self._get_attributes(protected_attribute_name, attribute_name)
+
     @property
     def properties(self):
         self._properties = _Reloader()._reload(self._MANAGER_NAME, self)._properties

+ 31 - 0
tests/core/cycle/test_cycle.py

@@ -8,14 +8,18 @@
 # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
+
 import datetime
 from datetime import timedelta
 
+import pytest
+
 from taipy.config.common.frequency import Frequency
 from taipy.core import CycleId
 from taipy.core.cycle._cycle_manager import _CycleManager
 from taipy.core.cycle._cycle_manager_factory import _CycleManagerFactory
 from taipy.core.cycle.cycle import Cycle
+from taipy.core.exceptions import AttributeKeyAlreadyExisted
 from taipy.core.task.task import Task
 
 
@@ -120,6 +124,33 @@ def test_add_property_to_scenario(current_datetime):
     assert cycle.new_key == "new_value"
 
 
+def test_get_set_property_and_attribute(current_datetime):
+    cycle_manager = _CycleManagerFactory()._build_manager()
+
+    cycle = Cycle(
+        Frequency.WEEKLY,
+        {"key": "value"},
+        current_datetime,
+        current_datetime,
+        current_datetime,
+        name="foo",
+    )
+    cycle_manager._set(cycle)
+
+    assert cycle.properties == {"key": "value"}
+
+    cycle.properties["new_key"] = "new_value"
+    cycle.another_key = "another_value"
+
+    assert cycle.key == "value"
+    assert cycle.new_key == "new_value"
+    assert cycle.another_key == "another_value"
+    assert cycle.properties["new_key"] == "new_value"
+
+    with pytest.raises(AttributeKeyAlreadyExisted):
+        cycle.key = "KeyAlreadyUsed"
+
+
 def test_auto_set_and_reload(current_datetime):
     cycle_1 = Cycle(
         Frequency.WEEKLY,

+ 19 - 1
tests/core/data/test_data_node.py

@@ -25,7 +25,7 @@ from taipy.core.data._data_manager_factory import _DataManagerFactory
 from taipy.core.data.data_node import DataNode
 from taipy.core.data.data_node_id import DataNodeId
 from taipy.core.data.in_memory import InMemoryDataNode
-from taipy.core.exceptions.exceptions import DataNodeIsBeingEdited, NoData
+from taipy.core.exceptions.exceptions import AttributeKeyAlreadyExisted, DataNodeIsBeingEdited, NoData
 from taipy.core.job.job_id import JobId
 from taipy.core.task.task import Task
 
@@ -119,6 +119,24 @@ class TestDataNode:
         with pytest.raises(InvalidConfigurationId):
             DataNode("foo bar")
 
+    def test_get_set_property_and_attribute(self):
+        dn_cfg = Config.configure_data_node("bar", key="value")
+        dn = _DataManager._create_and_set(dn_cfg, "", "")
+
+        assert "key" in dn.properties.keys()
+        assert dn.key == "value"
+
+        dn.properties["new_key"] = "new_value"
+        dn.another_key = "another_value"
+
+        assert dn.key == "value"
+        assert dn.new_key == "new_value"
+        assert dn.another_key == "another_value"
+        assert dn.properties["new_key"] == "new_value"
+
+        with pytest.raises(AttributeKeyAlreadyExisted):
+            dn.key = "KeyAlreadyUsed"
+
     def test_read_write(self):
         dn = FakeDataNode("foo_bar")
         with pytest.raises(NoData):

+ 1 - 1
tests/core/data/test_mongo_data_node.py

@@ -88,7 +88,7 @@ class TestMongoCollectionDataNode:
         assert mongo_dn.owner_id is None
         assert mongo_dn.job_ids == []
         assert mongo_dn.is_ready_for_reading
-        assert mongo_dn.custom_document == MongoDefaultDocument
+        assert mongo_dn.custom_mongo_document == MongoDefaultDocument
 
     @pytest.mark.parametrize("properties", __properties)
     def test_get_user_properties(self, properties):

+ 1 - 1
tests/core/scenario/test_scenario.py

@@ -163,7 +163,7 @@ def test_create_scenario_and_add_sequences():
     assert scenario.sequences == {"sequence_1": scenario.sequence_1, "sequence_2": scenario.sequence_2}
 
 
-def test_get_set_property_to_scenario():
+def test_get_set_property_and_attribute():
     dn_cfg = Config.configure_data_node("bar")
     s_cfg = Config.configure_scenario("foo", additional_data_node_configs=[dn_cfg], key="value")
     scenario = create_scenario(s_cfg)

+ 29 - 0
tests/core/sequence/test_sequence.py

@@ -20,6 +20,7 @@ from taipy.core.data._data_manager_factory import _DataManagerFactory
 from taipy.core.data.data_node import DataNode
 from taipy.core.data.in_memory import InMemoryDataNode
 from taipy.core.data.pickle import PickleDataNode
+from taipy.core.exceptions import AttributeKeyAlreadyExisted, PropertyKeyAlreadyExisted
 from taipy.core.scenario._scenario_manager import _ScenarioManager
 from taipy.core.scenario.scenario import Scenario
 from taipy.core.sequence._sequence_manager import _SequenceManager
@@ -126,6 +127,34 @@ def test_create_sequence():
         assert sequence_2.get_simple_label() == sequence_2.name
 
 
+def test_get_set_property_and_attribute():
+    dn_cfg = Config.configure_data_node("bar")
+    task_config = Config.configure_task("print", print, [dn_cfg], None)
+    scenario_config = Config.configure_scenario("scenario", [task_config])
+
+    scenario = _ScenarioManager._create(scenario_config)
+    scenario.add_sequences({"seq": list(scenario.tasks.values())})
+    sequence = scenario.sequences["seq"]
+    sequence.properties["key"] = "value"
+
+    assert sequence.properties == {"name": "seq", "key": "value"}
+    assert sequence.key == "value"
+
+    sequence.properties["new_key"] = "new_value"
+    sequence.another_key = "another_value"
+
+    assert sequence.key == "value"
+    assert sequence.new_key == "new_value"
+    assert sequence.another_key == "another_value"
+    assert sequence.properties == {"name": "seq", "key": "value", "new_key": "new_value"}
+
+    with pytest.raises(AttributeKeyAlreadyExisted):
+        sequence.bar = "KeyAlreadyUsed"
+
+    with pytest.raises(PropertyKeyAlreadyExisted):
+        sequence.properties["bar"] = "KeyAlreadyUsed"
+
+
 def test_check_consistency():
     sequence_1 = Sequence({}, [], "name_1")
     assert sequence_1._is_consistent()

+ 29 - 0
tests/core/task/test_task.py

@@ -21,6 +21,8 @@ from taipy.core.data._data_manager import _DataManager
 from taipy.core.data.csv import CSVDataNode
 from taipy.core.data.data_node import DataNode
 from taipy.core.data.in_memory import InMemoryDataNode
+from taipy.core.exceptions import AttributeKeyAlreadyExisted, PropertyKeyAlreadyExisted
+from taipy.core.scenario._scenario_manager import _ScenarioManager
 from taipy.core.task._task_manager import _TaskManager
 from taipy.core.task._task_manager_factory import _TaskManagerFactory
 from taipy.core.task.task import Task
@@ -110,6 +112,33 @@ def test_create_task():
         assert task.get_simple_label() == task.config_id
 
 
+def test_get_set_property_and_attribute():
+    dn_cfg = Config.configure_data_node("bar")
+    task_config = Config.configure_task("print", print, [dn_cfg], None)
+    scenario_config = Config.configure_scenario("scenario", [task_config])
+    scenario = _ScenarioManager._create(scenario_config)
+    task = scenario.tasks["print"]
+
+    task.properties["key"] = "value"
+
+    assert task.properties == {"key": "value"}
+    assert task.key == "value"
+
+    task.properties["new_key"] = "new_value"
+    task.another_key = "another_value"
+
+    assert task.key == "value"
+    assert task.new_key == "new_value"
+    assert task.another_key == "another_value"
+    assert task.properties == {"key": "value", "new_key": "new_value"}
+
+    with pytest.raises(AttributeKeyAlreadyExisted):
+        task.bar = "KeyAlreadyUsed"
+
+    with pytest.raises(PropertyKeyAlreadyExisted):
+        task.properties["bar"] = "KeyAlreadyUsed"
+
+
 def test_can_not_change_task_output(output):
     task = Task("name_1", {}, print, output=output)