Bläddra i källkod

Track state usage (#2441)

* rebase

* pass include_children kwarg in radix FormRoot

* respect include_children

* ruff fixes

* readd statemanager init, run pyi gen

* minor performance imporovements, fix for state changes

* fix pyi and pyright

* pass include_children for chakra

* remove old state detection

* add test for unused states in stateless app

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
benedikt-bartscher 1 år sedan
förälder
incheckning
19a5cdd408

+ 3 - 0
integration/test_tailwind.py

@@ -28,6 +28,9 @@ def TailwindApp(
     import reflex as rx
     import reflex.components.radix.themes as rdxt
 
+    class UnusedState(rx.State):
+        pass
+
     def index():
         return rx.el.div(
             rx.chakra.text(paragraph_text, class_name=paragraph_class_name),

+ 95 - 53
reflex/app.py

@@ -91,6 +91,12 @@ def default_overlay_component() -> Component:
     return Fragment.create(connection_pulser(), connection_modal())
 
 
+class OverlayFragment(Fragment):
+    """Alias for Fragment, used to wrap the overlay_component."""
+
+    pass
+
+
 class App(Base):
     """A Reflex application."""
 
@@ -159,7 +165,7 @@ class App(Base):
 
         Raises:
             ValueError: If the event namespace is not provided in the config.
-                        Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
+                        Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
                         of the DefaultState and the client app state).
         """
         if "connect_error_component" in kwargs:
@@ -167,12 +173,12 @@ class App(Base):
                 "`connect_error_component` is deprecated, use `overlay_component` instead"
             )
         super().__init__(*args, **kwargs)
-        state_subclasses = BaseState.__subclasses__()
+        base_state_subclasses = BaseState.__subclasses__()
 
         # Special case to allow test cases have multiple subclasses of rx.BaseState.
         if not is_testing_env():
             # Only one Base State class is allowed.
-            if len(state_subclasses) > 1:
+            if len(base_state_subclasses) > 1:
                 raise ValueError(
                     "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
                 )
@@ -184,12 +190,6 @@ class App(Base):
                     deprecation_version="0.3.5",
                     removal_version="0.5.0",
                 )
-            # 2 substates are built-in and not considered when determining if app is stateless.
-            if len(State.class_subclasses) > 2:
-                self.state = State
-        # Get the config
-        config = get_config()
-
         # Add middleware.
         self.middleware.append(HydrateMiddleware())
 
@@ -198,45 +198,60 @@ class App(Base):
         self.add_cors()
         self.add_default_endpoints()
 
-        if self.state:
-            # Set up the state manager.
-            self._state_manager = StateManager.create(state=self.state)
-
-            # Set up the Socket.IO AsyncServer.
-            self.sio = AsyncServer(
-                async_mode="asgi",
-                cors_allowed_origins=(
-                    "*"
-                    if config.cors_allowed_origins == ["*"]
-                    else config.cors_allowed_origins
-                ),
-                cors_credentials=True,
-                max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
-                ping_interval=constants.Ping.INTERVAL,
-                ping_timeout=constants.Ping.TIMEOUT,
-            )
+        self.setup_state()
 
-            # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
-            self.socket_app = ASGIApp(self.sio, socketio_path="")
-            namespace = config.get_event_namespace()
+        # Set up the admin dash.
+        self.setup_admin_dash()
 
-            if not namespace:
-                raise ValueError("event namespace must be provided in the config.")
+    def enable_state(self) -> None:
+        """Enable state for the app."""
+        if not self.state:
+            self.state = State
+            self.setup_state()
 
-            # Create the event namespace and attach the main app. Not related to any paths.
-            self.event_namespace = EventNamespace(namespace, self)
+    def setup_state(self) -> None:
+        """Set up the state for the app.
 
-            # Register the event namespace with the socket.
-            self.sio.register_namespace(self.event_namespace)
-            # Mount the socket app with the API.
-            self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
+        Raises:
+            ValueError: If the event namespace is not provided in the config.
+                        If the state has not been enabled.
+        """
+        if not self.state:
+            return
 
-        # Set up the admin dash.
-        self.setup_admin_dash()
+        config = get_config()
 
-        # If a State is not used and no overlay_component is specified, do not render the connection modal
-        if self.state is None and self.overlay_component is default_overlay_component:
-            self.overlay_component = None
+        # Set up the state manager.
+        self._state_manager = StateManager.create(state=self.state)
+
+        # Set up the Socket.IO AsyncServer.
+        self.sio = AsyncServer(
+            async_mode="asgi",
+            cors_allowed_origins=(
+                "*"
+                if config.cors_allowed_origins == ["*"]
+                else config.cors_allowed_origins
+            ),
+            cors_credentials=True,
+            max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
+            ping_interval=constants.Ping.INTERVAL,
+            ping_timeout=constants.Ping.TIMEOUT,
+        )
+
+        # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
+        self.socket_app = ASGIApp(self.sio, socketio_path="")
+        namespace = config.get_event_namespace()
+
+        if not namespace:
+            raise ValueError("event namespace must be provided in the config.")
+
+        # Create the event namespace and attach the main app. Not related to any paths.
+        self.event_namespace = EventNamespace(namespace, self)
+
+        # Register the event namespace with the socket.
+        self.sio.register_namespace(self.event_namespace)
+        # Mount the socket app with the API.
+        self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
 
     def __repr__(self) -> str:
         """Get the string representation of the app.
@@ -430,21 +445,24 @@ class App(Base):
         # Check if the route given is valid
         verify_route_validity(route)
 
-        # Apply dynamic args to the route.
-        if self.state:
-            self.state.setup_dynamic_args(get_route_args(route))
+        # Setup dynamic args for the route.
+        # this state assignment is only required for tests using the deprecated state kwarg for App
+        state = self.state if self.state else State
+        state.setup_dynamic_args(get_route_args(route))
 
         # Generate the component if it is a callable.
         component = self._generate_component(component)
 
-        # Wrap the component in a fragment with optional overlay.
-        if self.overlay_component is not None:
-            component = Fragment.create(
-                self._generate_component(self.overlay_component),
-                component,
-            )
-        else:
-            component = Fragment.create(component)
+        if self.state is None:
+            for var in component._get_vars(include_children=True):
+                if not var._var_data:
+                    continue
+                if not var._var_data.state:
+                    continue
+                self.enable_state()
+                break
+
+        component = OverlayFragment.create(component)
 
         # Add meta information to the component.
         compiler_utils.add_meta(
@@ -649,6 +667,28 @@ class App(Base):
         # By default, compile the app.
         return True
 
+    def _add_overlay_to_component(self, component: Component) -> Component:
+        if self.overlay_component is None:
+            return component
+
+        children = component.children
+        overlay_component = self._generate_component(self.overlay_component)
+
+        if children[0] == overlay_component:
+            return component
+
+        # recreate OverlayFragment with overlay_component as first child
+        component = OverlayFragment.create(overlay_component, *children)
+
+        return component
+
+    def _setup_overlay_component(self):
+        """If a State is not used and no overlay_component is specified, do not render the connection modal."""
+        if self.state is None and self.overlay_component is default_overlay_component:
+            self.overlay_component = None
+        for k, component in self.pages.items():
+            self.pages[k] = self._add_overlay_to_component(component)
+
     def compile(self):
         """compile_() is the new function for performing compilation.
         Reflex framework will call it automatically as needed.
@@ -682,6 +722,8 @@ class App(Base):
         if not self._should_compile():
             return
 
+        self._setup_overlay_component()
+
         # Create a progress bar.
         progress = Progress(
             *Progress.get_default_columns()[:-1],

+ 6 - 0
reflex/app.pyi

@@ -64,6 +64,11 @@ Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
 
 def default_overlay_component() -> Component: ...
 
+class OverlayFragment(Fragment):
+    @overload
+    @classmethod
+    def create(cls, *children, **props) -> "OverlayFragment": ...  # type: ignore
+
 class App(Base):
     pages: Dict[str, Component]
     stylesheets: List[str]
@@ -122,6 +127,7 @@ class App(Base):
     def compile(self) -> None: ...
     def compile_(self) -> None: ...
     def modify_state(self, token: str) -> AsyncContextManager[State]: ...
+    def _setup_overlay_component(self) -> None: ...
     def _process_background(
         self, state: State, event: Event
     ) -> asyncio.Task | None: ...

+ 4 - 1
reflex/components/base/bare.py

@@ -33,9 +33,12 @@ class Bare(Component):
     def _render(self) -> Tag:
         return Tagless(contents=str(self.contents))
 
-    def _get_vars(self) -> Iterator[Var]:
+    def _get_vars(self, include_children: bool = False) -> Iterator[Var]:
         """Walk all Vars used in this component.
 
+        Args:
+            include_children: Whether to include Vars from children.
+
         Yields:
             The contents if it is a Var, otherwise nothing.
         """

+ 12 - 1
reflex/components/component.py

@@ -812,9 +812,12 @@ class Component(BaseComponent, ABC):
                         event_args.extend(args)
                 yield event_trigger, event_args
 
-    def _get_vars(self) -> list[Var]:
+    def _get_vars(self, include_children: bool = False) -> list[Var]:
         """Walk all Vars used in this component.
 
+        Args:
+            include_children: Whether to include Vars from children.
+
         Returns:
             Each var referenced by the component (props, styles, event handlers).
         """
@@ -860,6 +863,14 @@ class Component(BaseComponent, ABC):
                 var = Var.create_safe(comp_prop)
                 if var._var_data is not None:
                     vars.append(var)
+
+        # Get Vars associated with children.
+        if include_children:
+            for child in self.children:
+                if not isinstance(child, Component):
+                    continue
+                vars.extend(child._get_vars(include_children=include_children))
+
         return vars
 
     def _get_custom_code(self) -> str | None:

+ 2 - 2
reflex/components/el/elements/forms.py

@@ -222,8 +222,8 @@ class Form(BaseHTML):
                 )._replace(merge_var_data=ref_var._var_data)
         return form_refs
 
-    def _get_vars(self) -> Iterator[Var]:
-        yield from super()._get_vars()
+    def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
+        yield from super()._get_vars(include_children=include_children)
         yield from self._get_form_refs().values()
 
     def _exclude_props(self) -> list[str]:

+ 9 - 0
reflex/config.py

@@ -219,6 +219,15 @@ class Config(Base):
         self._non_default_attributes.update(kwargs)
         self._replace_defaults(**kwargs)
 
+    @property
+    def module(self) -> str:
+        """Get the module name of the app.
+
+        Returns:
+            The module name.
+        """
+        return ".".join([self.app_name, self.app_name])
+
     @staticmethod
     def check_deprecated_values(**kwargs):
         """Check for deprecated config values.

+ 2 - 0
reflex/config.pyi

@@ -99,6 +99,8 @@ class Config(Base):
         gunicorn_worker_class: Optional[str] = None,
         **kwargs
     ) -> None: ...
+    @property
+    def module(self) -> str: ...
     @staticmethod
     def check_deprecated_values(**kwargs) -> None: ...
     def update_from_env(self) -> None: ...

+ 16 - 0
reflex/state.py

@@ -2918,3 +2918,19 @@ def code_uses_state_contexts(javascript_code: str) -> bool:
         True if the code attempts to access a member of StateContexts.
     """
     return bool("useContext(StateContexts" in javascript_code)
+
+
+def reload_state_module(
+    module: str,
+    state: Type[BaseState] = State,
+) -> None:
+    """Reset rx.State subclasses to avoid conflict when reloading.
+
+    Args:
+        module: The module to reload.
+        state: Recursive argument for the state class to reload.
+    """
+    for subclass in tuple(state.class_subclasses):
+        reload_state_module(module=module, state=subclass)
+        if subclass.__module__ == module and module is not None:
+            state.class_subclasses.remove(subclass)

+ 0 - 1
reflex/testing.py

@@ -232,7 +232,6 @@ class AppHarness:
             State.get_class_substate.cache_clear()
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
-            # self.app_module.app.
             self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
         self.app_instance = self.app_module.app
         if isinstance(self.app_instance._state_manager, StateManagerRedis):

+ 5 - 5
reflex/utils/prerequisites.py

@@ -185,16 +185,16 @@ def get_app(reload: bool = False) -> ModuleType:
             "Cannot get the app module because `app_name` is not set in rxconfig! "
             "If this error occurs in a reflex test case, ensure that `get_app` is mocked."
         )
-    module = ".".join([config.app_name, config.app_name])
+    module = config.module
     sys.path.insert(0, os.getcwd())
     app = __import__(module, fromlist=(constants.CompileVars.APP,))
+
     if reload:
-        from reflex.state import State
+        from reflex.state import reload_state_module
 
         # Reset rx.State subclasses to avoid conflict when reloading.
-        for subclass in tuple(State.class_subclasses):
-            if subclass.__module__ == module:
-                State.class_subclasses.remove(subclass)
+        reload_state_module(module=module)
+
         # Reload the app module.
         importlib.reload(app)
 

+ 7 - 2
tests/test_app.py

@@ -20,6 +20,7 @@ from reflex import AdminDash, constants
 from reflex.app import (
     App,
     ComponentCallable,
+    OverlayFragment,
     default_overlay_component,
     process,
     upload,
@@ -1182,12 +1183,13 @@ def test_overlay_component(
         exp_page_child: The type of the expected child in the page fragment.
     """
     app = App(state=state, overlay_component=overlay_component)
+    app._setup_overlay_component()
     if exp_page_child is None:
         assert app.overlay_component is None
-    elif isinstance(exp_page_child, Fragment):
+    elif isinstance(exp_page_child, OverlayFragment):
         assert app.overlay_component is not None
         generated_component = app._generate_component(app.overlay_component)  # type: ignore
-        assert isinstance(generated_component, Fragment)
+        assert isinstance(generated_component, OverlayFragment)
         assert isinstance(
             generated_component.children[0],
             Cond,  # ConnectionModal is a Cond under the hood
@@ -1200,7 +1202,10 @@ def test_overlay_component(
         )
 
     app.add_page(rx.box("Index"), route="/test")
+    # overlay components are wrapped during compile only
+    app._setup_overlay_component()
     page = app.pages["test"]
+
     if exp_page_child is not None:
         assert len(page.children) == 3
         children_types = (type(child) for child in page.children)