Преглед изворни кода

support async callback with state (#2295)

* support async callback with state
resolves #2288

* fix test
add a simple example

---------

Co-authored-by: Fred Lefévère-Laoide <Fred.Lefevere-Laoide@Taipy.io>
Fred Lefévère-Laoide пре 5 месеци
родитељ
комит
69111574ef

+ 52 - 0
doc/gui/examples/async_callback.py

@@ -0,0 +1,52 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+# -----------------------------------------------------------------------------------------
+# To execute this script, make sure that the taipy-gui package is installed in your
+# Python environment and run:
+#     python <script>
+# -----------------------------------------------------------------------------------------
+# Demonstrate how to update the value of a variable across multiple clients.
+# This application creates a thread that sets a variable to the current time.
+# The value is updated for every client when Gui.broadcast_change() is invoked.
+# -----------------------------------------------------------------------------------------
+import asyncio
+
+import taipy.gui.builder as tgb
+from taipy.gui import Gui, State
+
+
+# This callback is invoked inside a separate thread
+# it can access the state but cannot return a value
+async def heavy_function(state: State):
+    state.logs = "Starting...\n"
+    state.logs += "Searching documents\n"
+    await asyncio.sleep(5)
+    state.logs += "Responding to user\n"
+    await asyncio.sleep(5)
+    state.logs += "Fact Checking\n"
+    await asyncio.sleep(5)
+    state.result = "Done!"
+
+logs = ""
+result = "No response yet"
+
+with tgb.Page() as main_page:
+    # the async callback is used as any other callback
+    tgb.button("Respond", on_action=heavy_function)
+    with tgb.part("card"):
+        tgb.text("{logs}", mode="pre")
+
+    tgb.text("# Result", mode="md")
+    tgb.text("{result}")
+
+
+if __name__ == "__main__":
+    Gui(main_page).run(title="Async - Callback")

+ 13 - 5
taipy/gui/gui.py

@@ -25,7 +25,7 @@ import uuid
 import warnings
 import warnings
 from importlib import metadata, util
 from importlib import metadata, util
 from importlib.util import find_spec
 from importlib.util import find_spec
-from inspect import currentframe, getabsfile, ismethod, ismodule
+from inspect import currentframe, getabsfile, iscoroutinefunction, ismethod, ismodule
 from pathlib import Path
 from pathlib import Path
 from threading import Thread, Timer
 from threading import Thread, Timer
 from types import FrameType, FunctionType, LambdaType, ModuleType, SimpleNamespace
 from types import FrameType, FunctionType, LambdaType, ModuleType, SimpleNamespace
@@ -73,7 +73,7 @@ from .extension.library import Element, ElementLibrary
 from .page import Page
 from .page import Page
 from .partial import Partial
 from .partial import Partial
 from .server import _Server
 from .server import _Server
-from .state import State, _GuiState
+from .state import State, _AsyncState, _GuiState
 from .types import _WsType
 from .types import _WsType
 from .utils import (
 from .utils import (
     _delscopeattr,
     _delscopeattr,
@@ -115,6 +115,7 @@ from .utils._evaluator import _Evaluator
 from .utils._variable_directory import _is_moduled_variable, _VariableDirectory
 from .utils._variable_directory import _is_moduled_variable, _VariableDirectory
 from .utils.chart_config_builder import _build_chart_config
 from .utils.chart_config_builder import _build_chart_config
 from .utils.table_col_builder import _enhance_columns
 from .utils.table_col_builder import _enhance_columns
+from .utils.threads import _invoke_async_callback
 
 
 
 
 class Gui:
 class Gui:
@@ -1143,7 +1144,6 @@ class Gui:
             for var, val in state_context.items():
             for var, val in state_context.items():
                 self._update_var(var, val, True, forward=False)
                 self._update_var(var, val, True, forward=False)
 
 
-
     @staticmethod
     @staticmethod
     def set_unsupported_data_converter(converter: t.Optional[t.Callable[[t.Any], t.Any]]) -> None:
     def set_unsupported_data_converter(converter: t.Optional[t.Callable[[t.Any], t.Any]]) -> None:
         """Set a custom converter for unsupported data types.
         """Set a custom converter for unsupported data types.
@@ -1588,7 +1588,12 @@ class Gui:
 
 
     def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[t.List[t.Any]] = None) -> t.Any:
     def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[t.List[t.Any]] = None) -> t.Any:
         cp_args = [] if args is None else args.copy()
         cp_args = [] if args is None else args.copy()
-        cp_args.insert(0, self.__get_state())
+        cp_args.insert(
+            0,
+            _AsyncState(t.cast(_GuiState, self.__get_state()))
+            if iscoroutinefunction(user_function)
+            else self.__get_state(),
+        )
         argcount = user_function.__code__.co_argcount
         argcount = user_function.__code__.co_argcount
         if argcount > 0 and ismethod(user_function):
         if argcount > 0 and ismethod(user_function):
             argcount -= 1
             argcount -= 1
@@ -1597,7 +1602,10 @@ class Gui:
         else:
         else:
             cp_args = cp_args[:argcount]
             cp_args = cp_args[:argcount]
         with self.__event_manager:
         with self.__event_manager:
-            return user_function(*cp_args)
+            if iscoroutinefunction(user_function):
+                return _invoke_async_callback(user_function, cp_args)
+            else:
+                return user_function(*cp_args)
 
 
     def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]:
     def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]:
         return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext()
         return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext()

+ 9 - 7
taipy/gui/gui_actions.py

@@ -15,6 +15,7 @@ import typing as t
 from ._warnings import _warn
 from ._warnings import _warn
 from .gui import Gui
 from .gui import Gui
 from .state import State
 from .state import State
+from .utils.callable import _is_function
 
 
 
 
 def download(
 def download(
@@ -382,19 +383,20 @@ def invoke_long_callback(
     """
     """
     if not state or not isinstance(state._gui, Gui):
     if not state or not isinstance(state._gui, Gui):
         _warn("'invoke_long_callback()' must be called in the context of a callback.")
         _warn("'invoke_long_callback()' must be called in the context of a callback.")
+        return
 
 
     if user_status_function_args is None:
     if user_status_function_args is None:
         user_status_function_args = []
         user_status_function_args = []
     if user_function_args is None:
     if user_function_args is None:
         user_function_args = []
         user_function_args = []
 
 
-    state_id = get_state_id(state)
-    module_context = get_module_context(state)
+    this_gui = state.get_gui()
+
+    state_id = this_gui._get_client_id()
+    module_context = this_gui._get_locals_context()
     if not isinstance(state_id, str) or not isinstance(module_context, str):
     if not isinstance(state_id, str) or not isinstance(module_context, str):
         return
         return
 
 
-    this_gui = state._gui
-
     def callback_on_exception(state: State, function_name: str, e: Exception):
     def callback_on_exception(state: State, function_name: str, e: Exception):
         if not this_gui._call_on_exception(function_name, e):
         if not this_gui._call_on_exception(function_name, e):
             _warn(f"invoke_long_callback(): Exception raised in function {function_name}()", e)
             _warn(f"invoke_long_callback(): Exception raised in function {function_name}()", e)
@@ -405,10 +407,10 @@ def invoke_long_callback(
         function_name: t.Optional[str] = None,
         function_name: t.Optional[str] = None,
         function_result: t.Optional[t.Any] = None,
         function_result: t.Optional[t.Any] = None,
     ):
     ):
-        if callable(user_status_function):
+        if _is_function(user_status_function):
             this_gui.invoke_callback(
             this_gui.invoke_callback(
                 str(state_id),
                 str(state_id),
-                user_status_function,
+                t.cast(t.Callable, user_status_function),
                 [status] + list(user_status_function_args) + [function_result],  # type: ignore
                 [status] + list(user_status_function_args) + [function_result],  # type: ignore
                 str(module_context),
                 str(module_context),
             )
             )
@@ -438,5 +440,5 @@ def invoke_long_callback(
 
 
     thread = threading.Thread(target=user_function_in_thread, args=user_function_args)
     thread = threading.Thread(target=user_function_in_thread, args=user_function_args)
     thread.start()
     thread.start()
-    if isinstance(period, int) and period >= 500 and callable(user_status_function):
+    if isinstance(period, int) and period >= 500 and _is_function(user_status_function):
         thread_status(thread.name, period / 1000.0, 0)
         thread_status(thread.name, period / 1000.0, 0)

+ 25 - 4
taipy/gui/state.py

@@ -171,10 +171,7 @@ class _GuiState(State):
         "_get_placeholder_attrs",
         "_get_placeholder_attrs",
         "_add_attribute",
         "_add_attribute",
     )
     )
-    __placeholder_attrs = (
-        "_taipy_p1",
-        "_current_context",
-    )
+    __placeholder_attrs = ("_taipy_p1", "_current_context", "__state_id")
     __excluded_attrs = __attrs + __methods + __placeholder_attrs
     __excluded_attrs = __attrs + __methods + __placeholder_attrs
 
 
     def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None:
     def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None:
@@ -278,3 +275,27 @@ class _GuiState(State):
             gui = super().__getattribute__(_GuiState.__gui_attr)
             gui = super().__getattribute__(_GuiState.__gui_attr)
             return gui._bind_var_val(name, default_value)
             return gui._bind_var_val(name, default_value)
         return False
         return False
+
+
+class _AsyncState(_GuiState):
+    def __init__(self, state: State) -> None:
+        super().__init__(state.get_gui(), [], [])
+        self._set_placeholder("__state_id", state.get_gui()._get_client_id())
+
+    @staticmethod
+    def __set_var_in_state(state: State, var_name: str, value: t.Any):
+        setattr(state, var_name, value)
+
+    @staticmethod
+    def __get_var_from_state(state: State, var_name: str):
+        return getattr(state, var_name)
+
+    def __setattr__(self, var_name: str, var_value: t.Any) -> None:
+        self.get_gui().invoke_callback(
+            t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__set_var_in_state, [var_name, var_value]
+        )
+
+    def __getattr__(self, var_name: str) -> t.Any:
+        return self.get_gui().invoke_callback(
+            t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__get_var_from_state, [var_name]
+        )

+ 22 - 0
taipy/gui/utils/threads.py

@@ -0,0 +1,22 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+import asyncio
+import threading
+import typing as t
+
+
+def _thread_async_target(user_function, args: t.List[t.Any]):
+    asyncio.run(user_function(*args))
+
+
+def _invoke_async_callback(user_function, args: t.List[t.Any]):
+    thread = threading.Thread(target=_thread_async_target, args=[user_function, args])
+    thread.start()

+ 1 - 0
tests/gui/gui_specific/test_state.py

@@ -54,6 +54,7 @@ def test_state(gui: Gui):
         assert state._get_placeholder_attrs() == (
         assert state._get_placeholder_attrs() == (
             "_taipy_p1",
             "_taipy_p1",
             "_current_context",
             "_current_context",
+            "__state_id"
         )
         )
 
 
         assert get_a(state) == 20
         assert get_a(state) == 20