Selaa lähdekoodia

support async callback with state (#2295)

* support async callback with state
resolves #2288

* fix test
add a simple example

---------

Co-authored-by: Fred Lefévère-Laoide <Fred.Lefevere-Laoide@Taipy.io>
Fred Lefévère-Laoide 5 kuukautta sitten
vanhempi
säilyke
69111574ef

+ 52 - 0
doc/gui/examples/async_callback.py

@@ -0,0 +1,52 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+# -----------------------------------------------------------------------------------------
+# To execute this script, make sure that the taipy-gui package is installed in your
+# Python environment and run:
+#     python <script>
+# -----------------------------------------------------------------------------------------
+# Demonstrate how to update the value of a variable across multiple clients.
+# This application creates a thread that sets a variable to the current time.
+# The value is updated for every client when Gui.broadcast_change() is invoked.
+# -----------------------------------------------------------------------------------------
+import asyncio
+
+import taipy.gui.builder as tgb
+from taipy.gui import Gui, State
+
+
+# This callback is invoked inside a separate thread
+# it can access the state but cannot return a value
+async def heavy_function(state: State):
+    state.logs = "Starting...\n"
+    state.logs += "Searching documents\n"
+    await asyncio.sleep(5)
+    state.logs += "Responding to user\n"
+    await asyncio.sleep(5)
+    state.logs += "Fact Checking\n"
+    await asyncio.sleep(5)
+    state.result = "Done!"
+
+logs = ""
+result = "No response yet"
+
+with tgb.Page() as main_page:
+    # the async callback is used as any other callback
+    tgb.button("Respond", on_action=heavy_function)
+    with tgb.part("card"):
+        tgb.text("{logs}", mode="pre")
+
+    tgb.text("# Result", mode="md")
+    tgb.text("{result}")
+
+
+if __name__ == "__main__":
+    Gui(main_page).run(title="Async - Callback")

+ 13 - 5
taipy/gui/gui.py

@@ -25,7 +25,7 @@ import uuid
 import warnings
 from importlib import metadata, util
 from importlib.util import find_spec
-from inspect import currentframe, getabsfile, ismethod, ismodule
+from inspect import currentframe, getabsfile, iscoroutinefunction, ismethod, ismodule
 from pathlib import Path
 from threading import Thread, Timer
 from types import FrameType, FunctionType, LambdaType, ModuleType, SimpleNamespace
@@ -73,7 +73,7 @@ from .extension.library import Element, ElementLibrary
 from .page import Page
 from .partial import Partial
 from .server import _Server
-from .state import State, _GuiState
+from .state import State, _AsyncState, _GuiState
 from .types import _WsType
 from .utils import (
     _delscopeattr,
@@ -115,6 +115,7 @@ from .utils._evaluator import _Evaluator
 from .utils._variable_directory import _is_moduled_variable, _VariableDirectory
 from .utils.chart_config_builder import _build_chart_config
 from .utils.table_col_builder import _enhance_columns
+from .utils.threads import _invoke_async_callback
 
 
 class Gui:
@@ -1143,7 +1144,6 @@ class Gui:
             for var, val in state_context.items():
                 self._update_var(var, val, True, forward=False)
 
-
     @staticmethod
     def set_unsupported_data_converter(converter: t.Optional[t.Callable[[t.Any], t.Any]]) -> None:
         """Set a custom converter for unsupported data types.
@@ -1588,7 +1588,12 @@ class Gui:
 
     def _call_function_with_state(self, user_function: t.Callable, args: t.Optional[t.List[t.Any]] = None) -> t.Any:
         cp_args = [] if args is None else args.copy()
-        cp_args.insert(0, self.__get_state())
+        cp_args.insert(
+            0,
+            _AsyncState(t.cast(_GuiState, self.__get_state()))
+            if iscoroutinefunction(user_function)
+            else self.__get_state(),
+        )
         argcount = user_function.__code__.co_argcount
         if argcount > 0 and ismethod(user_function):
             argcount -= 1
@@ -1597,7 +1602,10 @@ class Gui:
         else:
             cp_args = cp_args[:argcount]
         with self.__event_manager:
-            return user_function(*cp_args)
+            if iscoroutinefunction(user_function):
+                return _invoke_async_callback(user_function, cp_args)
+            else:
+                return user_function(*cp_args)
 
     def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]:
         return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext()

+ 9 - 7
taipy/gui/gui_actions.py

@@ -15,6 +15,7 @@ import typing as t
 from ._warnings import _warn
 from .gui import Gui
 from .state import State
+from .utils.callable import _is_function
 
 
 def download(
@@ -382,19 +383,20 @@ def invoke_long_callback(
     """
     if not state or not isinstance(state._gui, Gui):
         _warn("'invoke_long_callback()' must be called in the context of a callback.")
+        return
 
     if user_status_function_args is None:
         user_status_function_args = []
     if user_function_args is None:
         user_function_args = []
 
-    state_id = get_state_id(state)
-    module_context = get_module_context(state)
+    this_gui = state.get_gui()
+
+    state_id = this_gui._get_client_id()
+    module_context = this_gui._get_locals_context()
     if not isinstance(state_id, str) or not isinstance(module_context, str):
         return
 
-    this_gui = state._gui
-
     def callback_on_exception(state: State, function_name: str, e: Exception):
         if not this_gui._call_on_exception(function_name, e):
             _warn(f"invoke_long_callback(): Exception raised in function {function_name}()", e)
@@ -405,10 +407,10 @@ def invoke_long_callback(
         function_name: t.Optional[str] = None,
         function_result: t.Optional[t.Any] = None,
     ):
-        if callable(user_status_function):
+        if _is_function(user_status_function):
             this_gui.invoke_callback(
                 str(state_id),
-                user_status_function,
+                t.cast(t.Callable, user_status_function),
                 [status] + list(user_status_function_args) + [function_result],  # type: ignore
                 str(module_context),
             )
@@ -438,5 +440,5 @@ def invoke_long_callback(
 
     thread = threading.Thread(target=user_function_in_thread, args=user_function_args)
     thread.start()
-    if isinstance(period, int) and period >= 500 and callable(user_status_function):
+    if isinstance(period, int) and period >= 500 and _is_function(user_status_function):
         thread_status(thread.name, period / 1000.0, 0)

+ 25 - 4
taipy/gui/state.py

@@ -171,10 +171,7 @@ class _GuiState(State):
         "_get_placeholder_attrs",
         "_add_attribute",
     )
-    __placeholder_attrs = (
-        "_taipy_p1",
-        "_current_context",
-    )
+    __placeholder_attrs = ("_taipy_p1", "_current_context", "__state_id")
     __excluded_attrs = __attrs + __methods + __placeholder_attrs
 
     def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None:
@@ -278,3 +275,27 @@ class _GuiState(State):
             gui = super().__getattribute__(_GuiState.__gui_attr)
             return gui._bind_var_val(name, default_value)
         return False
+
+
+class _AsyncState(_GuiState):
+    def __init__(self, state: State) -> None:
+        super().__init__(state.get_gui(), [], [])
+        self._set_placeholder("__state_id", state.get_gui()._get_client_id())
+
+    @staticmethod
+    def __set_var_in_state(state: State, var_name: str, value: t.Any):
+        setattr(state, var_name, value)
+
+    @staticmethod
+    def __get_var_from_state(state: State, var_name: str):
+        return getattr(state, var_name)
+
+    def __setattr__(self, var_name: str, var_value: t.Any) -> None:
+        self.get_gui().invoke_callback(
+            t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__set_var_in_state, [var_name, var_value]
+        )
+
+    def __getattr__(self, var_name: str) -> t.Any:
+        return self.get_gui().invoke_callback(
+            t.cast(str, self._get_placeholder("__state_id")), _AsyncState.__get_var_from_state, [var_name]
+        )

+ 22 - 0
taipy/gui/utils/threads.py

@@ -0,0 +1,22 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+import asyncio
+import threading
+import typing as t
+
+
+def _thread_async_target(user_function, args: t.List[t.Any]):
+    asyncio.run(user_function(*args))
+
+
+def _invoke_async_callback(user_function, args: t.List[t.Any]):
+    thread = threading.Thread(target=_thread_async_target, args=[user_function, args])
+    thread.start()

+ 1 - 0
tests/gui/gui_specific/test_state.py

@@ -54,6 +54,7 @@ def test_state(gui: Gui):
         assert state._get_placeholder_attrs() == (
             "_taipy_p1",
             "_current_context",
+            "__state_id"
         )
 
         assert get_a(state) == 20