Parcourir la source

Adding a mock State class that can be instantiated and queried while testing (#2101)

* Adding a testing State class that can be instantiated and queried during tests
I hoped it would not have any impact ont the real code but it has. Any discussion is welcome before merge.
resolves #2098

* tricky to get gui :-)

* consistency

* linter

* test => mock
more tests

* following Long comments

* Fab's right

Co-authored-by: Fabien Lelaquais <86590727+FabienLelaquais@users.noreply.github.com>

---------

Co-authored-by: Fred Lefévère-Laoide <Fred.Lefevere-Laoide@Taipy.io>
Co-authored-by: Fabien Lelaquais <86590727+FabienLelaquais@users.noreply.github.com>
Fred Lefévère-Laoide il y a 7 mois
Parent
commit
691af7133d

+ 7 - 3
taipy/gui/gui.py

@@ -73,7 +73,7 @@ from .extension.library import Element, ElementLibrary
 from .page import Page
 from .page import Page
 from .partial import Partial
 from .partial import Partial
 from .server import _Server
 from .server import _Server
-from .state import State
+from .state import State, _GuiState
 from .types import _WsType
 from .types import _WsType
 from .utils import (
 from .utils import (
     _delscopeattr,
     _delscopeattr,
@@ -2292,7 +2292,9 @@ class Gui:
             if isinstance(callback, str)
             if isinstance(callback, str)
             else _get_lambda_id(t.cast(LambdaType, callback))
             else _get_lambda_id(t.cast(LambdaType, callback))
             if _is_unnamed_function(callback)
             if _is_unnamed_function(callback)
-            else callback.__name__ if callback is not None else None
+            else callback.__name__
+            if callback is not None
+            else None
         )
         )
         func = self.__get_on_cancel_block_ui(action_name)
         func = self.__get_on_cancel_block_ui(action_name)
         def_action_name = func.__name__
         def_action_name = func.__name__
@@ -2809,7 +2811,9 @@ class Gui:
         self.__var_dir.set_default(self.__frame)
         self.__var_dir.set_default(self.__frame)
 
 
         if self.__state is None or is_reloading:
         if self.__state is None or is_reloading:
-            self.__state = State(self, self.__locals_context.get_all_keys(), self.__locals_context.get_all_context())
+            self.__state = _GuiState(
+                self, self.__locals_context.get_all_keys(), self.__locals_context.get_all_context()
+            )
 
 
         if _is_in_notebook():
         if _is_in_notebook():
             # Allow gui.state.x in notebook mode
             # Allow gui.state.x in notebook mode

+ 10 - 0
taipy/gui/mock/__init__.py

@@ -0,0 +1,10 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.

+ 62 - 0
taipy/gui/mock/mock_state.py

@@ -0,0 +1,62 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+import typing as t
+
+from .. import Gui, State
+from ..utils import _MapDict
+
+
+class MockState(State):
+    """A Mock implementation for `State`.
+    TODO
+    example of use:
+    ```py
+    def test_callback():
+        ms = MockState(Gui(""), a = 1)
+        on_action(ms) # function to test
+        assert ms.a == 2
+    ```
+    """
+
+    __VARS = "vars"
+
+    def __init__(self, gui: Gui, **kwargs) -> None:
+        super().__setattr__(MockState.__VARS, {k: _MapDict(v) if isinstance(v, dict) else v for k, v in kwargs.items()})
+        self._gui = gui
+        super().__init__()
+
+    def get_gui(self) -> Gui:
+        return self._gui
+
+    def __getattribute__(self, name: str) -> t.Any:
+        if attr := t.cast(dict, super().__getattribute__(MockState.__VARS)).get(name):
+            return attr
+        try:
+            return super().__getattribute__(name)
+        except Exception:
+            return None
+
+    def __setattr__(self, name: str, value: t.Any) -> None:
+        t.cast(dict, super().__getattribute__(MockState.__VARS))[name] = (
+            _MapDict(value) if isinstance(value, dict) else value
+        )
+
+    def __getitem__(self, key: str):
+        return self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        return True
+
+    def broadcast(self, name: str, value: t.Any):
+        pass

+ 109 - 94
taipy/gui/state.py

@@ -11,6 +11,7 @@
 
 
 import inspect
 import inspect
 import typing as t
 import typing as t
+from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from contextlib import nullcontext
 from operator import attrgetter
 from operator import attrgetter
 from pathlib import Path
 from pathlib import Path
@@ -25,7 +26,7 @@ if t.TYPE_CHECKING:
     from .gui import Gui
     from .gui import Gui
 
 
 
 
-class State:
+class State(ABC):
     """Accessor to the bound variables from callbacks.
     """Accessor to the bound variables from callbacks.
 
 
     `State` is used when you need to access the value of variables
     `State` is used when you need to access the value of variables
@@ -73,6 +74,87 @@ class State:
     ```
     ```
     """
     """
 
 
+    def __init__(self) -> None:
+        self._gui: "Gui"
+
+    @abstractmethod
+    def get_gui(self) -> "Gui":
+        """Return the Gui instance for this state object.
+
+        Returns:
+            Gui: The Gui instance for this state object.
+        """
+        raise NotImplementedError
+
+    def assign(self, name: str, value: t.Any) -> t.Any:
+        """Assign a value to a state variable.
+
+        This should be used only from within a lambda function used
+        as a callback in a visual element.
+
+        Arguments:
+            name (str): The variable name to assign to.
+            value (Any): The new variable value.
+
+        Returns:
+            Any: The previous value of the variable.
+        """
+        val = attrgetter(name)(self)
+        _attrsetter(self, name, value)
+        return val
+
+    def refresh(self, name: str):
+        """Refresh a state variable.
+
+        This allows to re-sync the user interface with a variable value.
+
+        Arguments:
+            name (str): The variable name to refresh.
+        """
+        val = attrgetter(name)(self)
+        _attrsetter(self, name, val)
+
+    def _set_context(self, gui: "Gui") -> t.ContextManager[None]:
+        return nullcontext()
+
+    def broadcast(self, name: str, value: t.Any):
+        """Update a variable on all clients.
+
+        All connected clients will receive an update of the variable called *name* with the
+        provided value, even if it is not shared.
+
+        Arguments:
+            name (str): The variable name to update.
+            value (Any): The new variable value.
+        """
+        with self._set_context(self._gui):
+            encoded_name = self._gui._bind_var(name)
+            self._gui._broadcast_all_clients(encoded_name, value)
+
+    def __enter__(self):
+        self._gui.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        return self._gui.__exit__(exc_type, exc_value, traceback)
+
+    def set_favicon(self, favicon_path: t.Union[str, Path]):
+        """Change the favicon for the client of this state.
+
+        This function dynamically changes the favicon (the icon associated with the application's
+        pages) of Taipy GUI pages for the specific client of this state.
+
+        Note that the *favicon* parameter to `(Gui.)run()^` can also be used to change
+        the favicon when the application starts.
+
+        Arguments:
+            favicon_path: The path to the image file to use.<br/>
+                This can be expressed as a path name or a URL (relative or not).
+        """
+        self._gui.set_favicon(favicon_path, self)
+
+
+class _GuiState(State):
     __gui_attr = "_gui"
     __gui_attr = "_gui"
     __attrs = (
     __attrs = (
         __gui_attr,
         __gui_attr,
@@ -100,68 +182,66 @@ class State:
     __excluded_attrs = __attrs + __methods + __placeholder_attrs
     __excluded_attrs = __attrs + __methods + __placeholder_attrs
 
 
     def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None:
     def __init__(self, gui: "Gui", var_list: t.Iterable[str], context_list: t.Iterable[str]) -> None:
-        super().__setattr__(State.__attrs[1], list(State.__filter_var_list(var_list, State.__excluded_attrs)))
-        super().__setattr__(State.__attrs[2], list(context_list))
-        super().__setattr__(State.__attrs[0], gui)
-
-    def get_gui(self) -> "Gui":
-        """Return the Gui instance for this state object.
-
-        Returns:
-            Gui: The Gui instance for this state object.
-        """
-        return super().__getattribute__(State.__gui_attr)
+        super().__setattr__(
+            _GuiState.__attrs[1], list(_GuiState.__filter_var_list(var_list, _GuiState.__excluded_attrs))
+        )
+        super().__setattr__(_GuiState.__attrs[2], list(context_list))
+        super().__setattr__(_GuiState.__attrs[0], gui)
+        super().__init__()
 
 
     @staticmethod
     @staticmethod
     def __filter_var_list(var_list: t.Iterable[str], excluded_attrs: t.Iterable[str]) -> t.Iterable[str]:
     def __filter_var_list(var_list: t.Iterable[str], excluded_attrs: t.Iterable[str]) -> t.Iterable[str]:
         return filter(lambda n: n not in excluded_attrs, var_list)
         return filter(lambda n: n not in excluded_attrs, var_list)
 
 
+    def get_gui(self) -> "Gui":
+        return super().__getattribute__(_GuiState.__gui_attr)
+
     def __getattribute__(self, name: str) -> t.Any:
     def __getattribute__(self, name: str) -> t.Any:
         if name == "__class__":
         if name == "__class__":
-            return State
-        if name in State.__methods:
+            return _GuiState
+        if name in _GuiState.__methods:
             return super().__getattribute__(name)
             return super().__getattribute__(name)
         gui: "Gui" = self.get_gui()
         gui: "Gui" = self.get_gui()
-        if name == State.__gui_attr:
+        if name == _GuiState.__gui_attr:
             return gui
             return gui
-        if name in State.__excluded_attrs:
+        if name in _GuiState.__excluded_attrs:
             raise AttributeError(f"Variable '{name}' is protected and is not accessible.")
             raise AttributeError(f"Variable '{name}' is protected and is not accessible.")
         if gui._is_in_brdcst_callback() and (
         if gui._is_in_brdcst_callback() and (
             name not in gui._get_shared_variables() and not gui._bindings()._is_single_client()
             name not in gui._get_shared_variables() and not gui._bindings()._is_single_client()
         ):
         ):
             raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.")
             raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.")
-        if not name.startswith("__") and name not in super().__getattribute__(State.__attrs[1]):
+        if not name.startswith("__") and name not in super().__getattribute__(_GuiState.__attrs[1]):
             raise AttributeError(f"Variable '{name}' is not defined.")
             raise AttributeError(f"Variable '{name}' is not defined.")
         with self._notebook_context(gui), self._set_context(gui):
         with self._notebook_context(gui), self._set_context(gui):
             encoded_name = gui._bind_var(name)
             encoded_name = gui._bind_var(name)
             return getattr(gui._bindings(), encoded_name)
             return getattr(gui._bindings(), encoded_name)
 
 
     def __setattr__(self, name: str, value: t.Any) -> None:
     def __setattr__(self, name: str, value: t.Any) -> None:
-        gui: "Gui" = super().__getattribute__(State.__gui_attr)
+        gui: "Gui" = super().__getattribute__(_GuiState.__gui_attr)
         if gui._is_in_brdcst_callback() and (
         if gui._is_in_brdcst_callback() and (
             name not in gui._get_shared_variables() and not gui._bindings()._is_single_client()
             name not in gui._get_shared_variables() and not gui._bindings()._is_single_client()
         ):
         ):
             raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.")
             raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.")
-        if not name.startswith("__") and name not in super().__getattribute__(State.__attrs[1]):
+        if not name.startswith("__") and name not in super().__getattribute__(_GuiState.__attrs[1]):
             raise AttributeError(f"Variable '{name}' is not accessible.")
             raise AttributeError(f"Variable '{name}' is not accessible.")
         with self._notebook_context(gui), self._set_context(gui):
         with self._notebook_context(gui), self._set_context(gui):
             encoded_name = gui._bind_var(name)
             encoded_name = gui._bind_var(name)
             setattr(gui._bindings(), encoded_name, value)
             setattr(gui._bindings(), encoded_name, value)
 
 
     def __getitem__(self, key: str):
     def __getitem__(self, key: str):
-        context = key if key in super().__getattribute__(State.__attrs[2]) else None
+        context = key if key in super().__getattribute__(_GuiState.__attrs[2]) else None
         if context is None:
         if context is None:
-            gui: "Gui" = super().__getattribute__(State.__gui_attr)
+            gui: "Gui" = super().__getattribute__(_GuiState.__gui_attr)
             page_ctx = gui._get_page_context(key)
             page_ctx = gui._get_page_context(key)
             context = page_ctx if page_ctx is not None else None
             context = page_ctx if page_ctx is not None else None
         if context is None:
         if context is None:
             raise RuntimeError(f"Can't resolve context '{key}' from state object")
             raise RuntimeError(f"Can't resolve context '{key}' from state object")
-        self._set_placeholder(State.__placeholder_attrs[1], context)
+        self._set_placeholder(_GuiState.__placeholder_attrs[1], context)
         return self
         return self
 
 
     def _set_context(self, gui: "Gui") -> t.ContextManager[None]:
     def _set_context(self, gui: "Gui") -> t.ContextManager[None]:
-        if (pl_ctx := self._get_placeholder(State.__placeholder_attrs[1])) is not None:
-            self._set_placeholder(State.__placeholder_attrs[1], None)
+        if (pl_ctx := self._get_placeholder(_GuiState.__placeholder_attrs[1])) is not None:
+            self._set_placeholder(_GuiState.__placeholder_attrs[1], None)
             if pl_ctx != gui._get_locals_context():
             if pl_ctx != gui._get_locals_context():
                 return gui._set_locals_context(pl_ctx)
                 return gui._set_locals_context(pl_ctx)
         if len(inspect.stack()) > 1:
         if len(inspect.stack()) > 1:
@@ -176,7 +256,7 @@ class State:
         return gui.get_flask_app().app_context() if not has_app_context() and _is_in_notebook() else nullcontext()
         return gui.get_flask_app().app_context() if not has_app_context() and _is_in_notebook() else nullcontext()
 
 
     def _get_placeholder(self, name: str):
     def _get_placeholder(self, name: str):
-        if name in State.__placeholder_attrs:
+        if name in _GuiState.__placeholder_attrs:
             try:
             try:
                 return super().__getattribute__(name)
                 return super().__getattribute__(name)
             except AttributeError:
             except AttributeError:
@@ -184,81 +264,16 @@ class State:
         return None
         return None
 
 
     def _set_placeholder(self, name: str, value: t.Any):
     def _set_placeholder(self, name: str, value: t.Any):
-        if name in State.__placeholder_attrs:
+        if name in _GuiState.__placeholder_attrs:
             super().__setattr__(name, value)
             super().__setattr__(name, value)
 
 
     def _get_placeholder_attrs(self):
     def _get_placeholder_attrs(self):
-        return State.__placeholder_attrs
+        return _GuiState.__placeholder_attrs
 
 
     def _add_attribute(self, name: str, default_value: t.Optional[t.Any] = None) -> bool:
     def _add_attribute(self, name: str, default_value: t.Optional[t.Any] = None) -> bool:
-        attrs: t.List[str] = super().__getattribute__(State.__attrs[1])
+        attrs: t.List[str] = super().__getattribute__(_GuiState.__attrs[1])
         if name not in attrs:
         if name not in attrs:
             attrs.append(name)
             attrs.append(name)
-            gui = super().__getattribute__(State.__gui_attr)
+            gui = super().__getattribute__(_GuiState.__gui_attr)
             return gui._bind_var_val(name, default_value)
             return gui._bind_var_val(name, default_value)
         return False
         return False
-
-    def assign(self, name: str, value: t.Any) -> t.Any:
-        """Assign a value to a state variable.
-
-        This should be used only from within a lambda function used
-        as a callback in a visual element.
-
-        Arguments:
-            name (str): The variable name to assign to.
-            value (Any): The new variable value.
-
-        Returns:
-            Any: The previous value of the variable.
-        """
-        val = attrgetter(name)(self)
-        _attrsetter(self, name, value)
-        return val
-
-    def refresh(self, name: str):
-        """Refresh a state variable.
-
-        This allows to re-sync the user interface with a variable value.
-
-        Arguments:
-            name (str): The variable name to refresh.
-        """
-        val = attrgetter(name)(self)
-        _attrsetter(self, name, val)
-
-    def broadcast(self, name: str, value: t.Any):
-        """Update a variable on all clients.
-
-        All connected clients will receive an update of the variable called *name* with the
-        provided value, even if it is not shared.
-
-        Arguments:
-            name (str): The variable name to update.
-            value (Any): The new variable value.
-        """
-        gui: "Gui" = super().__getattribute__(State.__gui_attr)
-        with self._set_context(gui):
-            encoded_name = gui._bind_var(name)
-            gui._broadcast_all_clients(encoded_name, value)
-
-    def __enter__(self):
-        super().__getattribute__(State.__attrs[0]).__enter__()
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        return super().__getattribute__(State.__attrs[0]).__exit__(exc_type, exc_value, traceback)
-
-    def set_favicon(self, favicon_path: t.Union[str, Path]):
-        """Change the favicon for the client of this state.
-
-        This function dynamically changes the favicon (the icon associated with the application's
-        pages) of Taipy GUI pages for the specific client of this state.
-
-        Note that the *favicon* parameter to `(Gui.)run()^` can also be used to change
-        the favicon when the application starts.
-
-        Arguments:
-            favicon_path: The path to the image file to use.<br/>
-                This can be expressed as a path name or a URL (relative or not).
-        """
-        super().__getattribute__(State.__gui_attr).set_favicon(favicon_path, self)

+ 3 - 2
tests/gui/actions/test_download.py

@@ -10,8 +10,9 @@
 # specific language governing permissions and limitations under the License.
 # specific language governing permissions and limitations under the License.
 
 
 import inspect
 import inspect
+import typing as t
 
 
-from flask import g
+from flask import Flask, g
 
 
 from taipy.gui import Gui, Markdown, State, download
 from taipy.gui import Gui, Markdown, State, download
 
 
@@ -30,7 +31,7 @@ def test_download(gui: Gui, helpers):
     gui.run(run_server=False)
     gui.run(run_server=False)
     flask_client = gui._server.test_client()
     flask_client = gui._server.test_client()
     # WS client and emit
     # WS client and emit
-    ws_client = gui._server._ws.test_client(gui._server.get_flask())
+    ws_client = gui._server._ws.test_client(t.cast(Flask, gui._server.get_flask()))
     cid = helpers.create_scope_and_get_sid(gui)
     cid = helpers.create_scope_and_get_sid(gui)
     # Get the jsx once so that the page will be evaluated -> variable will be registered
     # Get the jsx once so that the page will be evaluated -> variable will be registered
     flask_client.get(f"/taipy-jsx/test?client_id={cid}")
     flask_client.get(f"/taipy-jsx/test?client_id={cid}")

+ 102 - 0
tests/gui/mock/test_mock_state.py

@@ -0,0 +1,102 @@
+# Copyright 2021-2024 Avaiga Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+from unittest.mock import Mock
+
+from taipy.gui import Gui, State
+from taipy.gui.mock.mock_state import MockState
+from taipy.gui.utils import _MapDict
+
+
+def test_gui():
+    gui = Gui("")
+    ms = MockState(gui)
+    assert ms.get_gui() is gui
+    assert ms._gui is gui
+
+
+def test_read_attr():
+    gui = Gui("")
+    ms = MockState(gui, a=1)
+    assert ms is not None
+    assert ms.a == 1
+    assert ms.b is None
+
+
+def test_read_context():
+    ms = MockState(Gui(""), a=1)
+    assert ms["b"] is not None
+    assert ms["b"].a == 1
+
+
+def test_write_attr():
+    ms = MockState(Gui(""), a=1)
+    ms.a = 2
+    assert ms.a == 2
+    ms.b = 3
+    assert ms.b == 3
+    ms.a += 1
+    assert ms.a == 3
+
+def test_dict():
+    ms = MockState(Gui(""))
+    a_dict = {"a": 1}
+    ms.d = a_dict
+    assert isinstance(ms.d, _MapDict)
+    assert ms.d._dict is a_dict
+
+
+def test_write_context():
+    ms = MockState(Gui(""), a=1)
+    ms["page"].a = 2
+    assert ms["page"].a == 2
+    ms["page"].b = 3
+    assert ms["page"].b == 3
+
+def test_assign():
+    ms = MockState(Gui(""), a=1)
+    ms.assign("a", 2)
+    assert ms.a == 2
+    ms.assign("b", 1)
+    assert ms.b == 1
+
+def test_refresh():
+    ms = MockState(Gui(""), a=1)
+    ms.refresh("a")
+    assert ms.a == 1
+    ms.a = 2
+    ms.refresh("a")
+    assert ms.a == 2
+
+def test_context_manager():
+    with MockState(Gui(""), a=1) as ms:
+        assert ms is not None
+        ms.a = 2
+    assert ms.a == 2
+
+def test_broadcast():
+    ms = MockState(Gui(""), a=1)
+    ms.broadcast("a", 2)
+
+def test_set_favicon():
+    gui = Gui("")
+    gui.set_favicon = Mock()
+    ms = MockState(gui, a=1)
+    ms.set_favicon("a_path")
+    gui.set_favicon.assert_called_once()
+
+def test_callback():
+    def on_action(state: State):
+        state.assign("a", 2)
+
+    ms = MockState(Gui(""), a=1)
+    on_action(ms)
+    assert ms.a == 2