Pārlūkot izejas kodu

authorization for submission access (#2276)

* authorization for submission access
resolves taipy-enterprise#562

* lint

* fix test

* fix test

---------

Co-authored-by: Fred Lefévère-Laoide <Fred.Lefevere-Laoide@Taipy.io>
Fred Lefévère-Laoide 5 mēneši atpakaļ
vecāks
revīzija
b1894c8ca9

+ 30 - 25
taipy/gui_core/_context.py

@@ -70,6 +70,7 @@ from ._adapters import (
     _GuiCoreScenarioProperties,
     _invoke_action,
 )
+from ._utils import _ClientStatus
 from .filters import CustomScenarioFilter
 
 
@@ -92,7 +93,7 @@ class _GuiCoreContext(CoreEventConsumerBase):
         self.data_nodes_by_owner: t.Optional[t.Dict[t.Optional[str], t.List[DataNode]]] = None
         self.scenario_configs: t.Optional[t.List[t.Tuple[str, str]]] = None
         self.jobs_list: t.Optional[t.List[Job]] = None
-        self.client_submission: t.Dict[str, SubmissionStatus] = {}
+        self.client_submission: t.Dict[str, _ClientStatus] = {}
         # register to taipy core notification
         reg_id, reg_queue = Notifier.register()
         # locks
@@ -162,28 +163,32 @@ class _GuiCoreContext(CoreEventConsumerBase):
         self.broadcast_core_changed({"scenario": scenario_id or True})
 
     def submission_status_callback(self, submission_id: t.Optional[str] = None, event: t.Optional[Event] = None):
-        if not submission_id or not is_readable(t.cast(SubmissionId, submission_id)):
+        if not submission_id:
             return
         submission = None
         new_status = None
         payload: t.Optional[t.Dict[str, t.Any]] = None
         client_id: t.Optional[str] = None
         try:
-            last_status = self.client_submission.get(submission_id)
-            if not last_status:
+            last_client_status = self.client_submission.get(submission_id)
+            if not last_client_status:
                 return
 
-            submission = t.cast(Submission, core_get(submission_id))
-            if not submission or not submission.entity_id:
-                return
+            client_id = last_client_status.client_id
+
+            with self.gui._get_authorization(client_id):
+                if not is_readable(t.cast(SubmissionId, submission_id)):
+                    return
+                submission = t.cast(Submission, core_get(submission_id))
+                if not submission or not submission.entity_id:
+                    return
 
-            payload = {}
-            new_status = t.cast(SubmissionStatus, submission.submission_status)
+                payload = {}
+                new_status = t.cast(SubmissionStatus, submission.submission_status)
 
-            client_id = submission.properties.get("client_id")
-            if client_id:
-                running_tasks = {}
-                with self.gui._get_authorization(client_id):
+                if client_id:
+                    running_tasks = {}
+                    # with self.gui._get_authorization(client_id):
                     for job in submission.jobs:
                         job = job if isinstance(job, Job) else t.cast(Job, core_get(job))
                         running_tasks[job.task.id] = (
@@ -195,7 +200,7 @@ class _GuiCoreContext(CoreEventConsumerBase):
                         )
                     payload.update(tasks=running_tasks)
 
-                    if last_status is not new_status:
+                    if last_client_status.submission_status is not new_status:
                         # callback
                         submission_name = submission.properties.get("on_submission")
                         if submission_name:
@@ -213,15 +218,15 @@ class _GuiCoreContext(CoreEventConsumerBase):
                                 submission.properties.get("module_context"),
                             )
 
-            with self.submissions_lock:
-                if new_status in (
-                    SubmissionStatus.COMPLETED,
-                    SubmissionStatus.FAILED,
-                    SubmissionStatus.CANCELED,
-                ):
+            if new_status in (
+                SubmissionStatus.COMPLETED,
+                SubmissionStatus.FAILED,
+                SubmissionStatus.CANCELED,
+            ):
+                with self.submissions_lock:
                     self.client_submission.pop(submission_id, None)
-                else:
-                    self.client_submission[submission_id] = new_status
+            else:
+                last_client_status.submission_status = new_status
 
         except Exception as e:
             _warn(f"Submission ({submission_id}) is not available", e)
@@ -634,11 +639,11 @@ class _GuiCoreContext(CoreEventConsumerBase):
                     client_id=self.gui._get_client_id(),
                     module_context=self.gui._get_locals_context(),
                 )
+                client_status = _ClientStatus(self.gui._get_client_id(), submission_entity.submission_status)
                 with self.submissions_lock:
-                    self.client_submission[submission_entity.id] = submission_entity.submission_status
+                    self.client_submission[submission_entity.id] = client_status
                 if Config.core.mode == "development":
-                    with self.submissions_lock:
-                        self.client_submission[submission_entity.id] = SubmissionStatus.SUBMITTED
+                    client_status.submission_status = SubmissionStatus.SUBMITTED
                     self.submission_status_callback(submission_entity.id)
                 _GuiCoreContext.__assign_var(state, error_var, "")
         except Exception as e:

+ 20 - 0
taipy/gui_core/_utils.py

@@ -0,0 +1,20 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+import typing as t
+from dataclasses import dataclass
+
+from taipy.core.submission.submission_status import SubmissionStatus
+
+
+@dataclass
+class _ClientStatus:
+    client_id: t.Optional[str]
+    submission_status: SubmissionStatus

+ 40 - 37
tests/gui_core/test_context_is_readable.py

@@ -27,8 +27,9 @@ from taipy.core.scenario._scenario_manager_factory import _ScenarioManagerFactor
 from taipy.core.submission._submission_manager_factory import _SubmissionManagerFactory
 from taipy.core.submission.submission import Submission, SubmissionStatus
 from taipy.core.task._task_manager_factory import _TaskManagerFactory
-from taipy.gui import Gui
+from taipy.gui import Gui, State
 from taipy.gui_core._context import _GuiCoreContext
+from taipy.gui_core._utils import _ClientStatus
 
 a_cycle = Cycle(Frequency.DAILY, {}, datetime.now(), datetime.now(), datetime.now(), id=CycleId("CYCLE_id"))
 a_scenario = Scenario("scenario_config_id", None, {}, sequences={"sequence": {}})
@@ -66,9 +67,14 @@ def mock_core_get(entity_id):
     return a_task
 
 
-class MockState:
+class MockState(State):
     def __init__(self, **kwargs) -> None:
-        self.assign = kwargs.get("assign")
+        self.assign = t.cast(t.Callable, kwargs.get("assign")) # type: ignore[method-assign]
+        self.gui = t.cast(Gui, kwargs.get("gui"))
+    def get_gui(self):
+        return self.gui
+    def broadcast(self, name: str, value: t.Any):
+        pass
 
 
 class TestGuiCoreContext_is_readable:
@@ -96,7 +102,7 @@ class TestGuiCoreContext_is_readable:
     def test_cycle_adapter(self):
         with patch("taipy.gui_core._context.core_get", side_effect=mock_core_get):
             gui_core_context = _GuiCoreContext(Mock())
-            gui_core_context.scenario_by_cycle = {"a": 1}
+            gui_core_context.scenario_by_cycle = t.cast(dict, {"a": 1})
             outcome = gui_core_context.cycle_adapter(a_cycle)
             assert isinstance(outcome, list)
             assert outcome[0] == a_cycle.id
@@ -120,9 +126,9 @@ class TestGuiCoreContext_is_readable:
             gui_core_context = _GuiCoreContext(Mock())
             assign = Mock()
             gui_core_context.crud_scenario(
-                MockState(assign=assign),
+                MockState(assign=assign, gui=gui_core_context.gui),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         "",
                         "",
@@ -132,7 +138,7 @@ class TestGuiCoreContext_is_readable:
                         {"name": "name", "id": a_scenario.id},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_not_called()
 
@@ -141,7 +147,7 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.crud_scenario(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             "",
                             "",
@@ -151,7 +157,7 @@ class TestGuiCoreContext_is_readable:
                             {"name": "name", "id": a_scenario.id},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"
@@ -164,12 +170,12 @@ class TestGuiCoreContext_is_readable:
             gui_core_context.edit_entity(
                 MockState(assign=assign),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         {"name": "name", "id": a_scenario.id},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_called_once()
             assert assign.call_args.args[0] == "error_var"
@@ -180,12 +186,12 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.edit_entity(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             {"name": "name", "id": a_scenario.id},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"
@@ -198,10 +204,7 @@ class TestGuiCoreContext_is_readable:
             mockGui._get_authorization = lambda s: contextlib.nullcontext()
             gui_core_context = _GuiCoreContext(mockGui)
 
-            def sub_cb():
-                return True
-
-            gui_core_context.client_submission[a_submission.id] = SubmissionStatus.UNDEFINED
+            gui_core_context.client_submission[a_submission.id] = _ClientStatus("client_id", SubmissionStatus.UNDEFINED)
             gui_core_context.submission_status_callback(a_submission.id)
             mockget.assert_called()
             found = False
@@ -248,12 +251,12 @@ class TestGuiCoreContext_is_readable:
             gui_core_context.act_on_jobs(
                 MockState(assign=assign),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         {"id": [a_job.id], "action": "delete"},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_called_once()
             assert assign.call_args.args[0] == "error_var"
@@ -263,12 +266,12 @@ class TestGuiCoreContext_is_readable:
             gui_core_context.act_on_jobs(
                 MockState(assign=assign),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         {"id": [a_job.id], "action": "cancel"},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_called_once()
             assert assign.call_args.args[0] == "error_var"
@@ -279,12 +282,12 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.act_on_jobs(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             {"id": [a_job.id], "action": "delete"},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"
@@ -294,12 +297,12 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.act_on_jobs(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             {"id": [a_job.id], "action": "cancel"},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"
@@ -312,12 +315,12 @@ class TestGuiCoreContext_is_readable:
             gui_core_context.edit_data_node(
                 MockState(assign=assign),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         {"id": a_datanode.id},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_called_once()
             assert assign.call_args.args[0] == "error_var"
@@ -328,12 +331,12 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.edit_data_node(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             {"id": a_datanode.id},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"
@@ -348,12 +351,12 @@ class TestGuiCoreContext_is_readable:
             gui_core_context.lock_datanode_for_edit(
                 MockState(assign=assign),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         {"id": a_datanode.id},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_called_once()
             assert assign.call_args.args[0] == "error_var"
@@ -364,12 +367,12 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.lock_datanode_for_edit(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             {"id": a_datanode.id},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"
@@ -395,12 +398,12 @@ class TestGuiCoreContext_is_readable:
             gui_core_context.update_data(
                 MockState(assign=assign),
                 "",
-                {
+                t.cast(dict, {
                     "args": [
                         {"id": a_datanode.id},
                     ],
                     "error_id": "error_var",
-                },
+                }),
             )
             assign.assert_called()
             assert assign.call_args_list[0].args[0] == "error_var"
@@ -411,12 +414,12 @@ class TestGuiCoreContext_is_readable:
                 gui_core_context.update_data(
                     MockState(assign=assign),
                     "",
-                    {
+                    t.cast(dict, {
                         "args": [
                             {"id": a_datanode.id},
                         ],
                         "error_id": "error_var",
-                    },
+                    }),
                 )
                 assign.assert_called_once()
                 assert assign.call_args.args[0] == "error_var"