Elijah Ahianyo 1 год назад
Родитель
Сommit
b652d40ee5

+ 2 - 0
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -8,7 +8,9 @@
 
 {% block export %}
 export default function Component() {
+{% if state_name %}
   const {{state_name}} = useContext(StateContext)
+{% endif %}
   const {{const.router}} = useRouter()
   const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
   const focusRef = useRef();

+ 15 - 0
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -1,14 +1,29 @@
 import { createContext, useState } from "react"
 import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
 
+{% if initial_state %}
 export const initialState = {{ initial_state|json_dumps }}
+{% else %}
+export const initialState = {}
+{% endif %}
+
 export const ColorModeContext = createContext(null);
 export const StateContext = createContext(null);
 export const EventLoopContext = createContext(null);
+{% if client_storage %}
 export const clientStorage = {{ client_storage|json_dumps }}
+{% else %}
+export const clientStorage = {}
+{% endif %}
+
+{% if state_name %}
 export const initialEvents = () => [
     Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
 ]
+{% else %}
+export const initialEvents = () => []
+{% endif %}
+
 export const isDevMode = {{ is_dev_mode|json_dumps }}
 
 export function EventLoopProvider({ children }) {

+ 30 - 18
reflex/app.py

@@ -53,11 +53,9 @@ from reflex.route import (
     verify_route_validity,
 )
 from reflex.state import (
-    DefaultState,
     RouterData,
     State,
     StateManager,
-    StateManagerMemory,
     StateUpdate,
 )
 from reflex.utils import console, format, prerequisites, types
@@ -96,10 +94,10 @@ class App(Base):
     socket_app: Optional[ASGIApp] = None
 
     # The state class to use for the app.
-    state: Type[State] = DefaultState
+    state: Optional[Type[State]] = None
 
     # Class to manage many client states.
-    state_manager: StateManager = StateManagerMemory(state=DefaultState)
+    _state_manager: Optional[StateManager] = None
 
     # The styling to apply to each component.
     style: ComponentStyle = {}
@@ -148,19 +146,19 @@ class App(Base):
             )
         super().__init__(*args, **kwargs)
         state_subclasses = State.__subclasses__()
-        inferred_state = state_subclasses[-1]
+        inferred_state = state_subclasses[-1] if state_subclasses else None
         is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
 
         # Special case to allow test cases have multiple subclasses of rx.State.
         if not is_testing_env:
-            # Only the default state and the client state should be allowed as subclasses.
-            if len(state_subclasses) > 2:
+            # Only one State class is allowed.
+            if len(state_subclasses) > 1:
                 raise ValueError(
                     "rx.State has been subclassed multiple times. Only one subclass is allowed"
                 )
 
             # verify that provided state is valid
-            if self.state not in [DefaultState, inferred_state]:
+            if self.state and inferred_state and self.state is not inferred_state:
                 console.warn(
                     f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
                     f" Defaulting to root state: ({inferred_state.__name__})"
@@ -172,15 +170,15 @@ class App(Base):
         # Add middleware.
         self.middleware.append(HydrateMiddleware())
 
-        # Set up the state manager.
-        self.state_manager = StateManager.create(state=self.state)
-
         # Set up the API.
         self.api = FastAPI()
         self.add_cors()
         self.add_default_endpoints()
 
-        if self.state is not DefaultState:
+        if self.state:
+            # Set up the state manager.
+            self._state_manager = StateManager.create(state=self.state)
+
             # Set up the Socket.IO AsyncServer.
             self.sio = AsyncServer(
                 async_mode="asgi",
@@ -212,10 +210,7 @@ class App(Base):
         self.setup_admin_dash()
 
         # If a State is not used and no overlay_component is specified, do not render the connection modal
-        if (
-            self.state is DefaultState
-            and self.overlay_component is default_overlay_component
-        ):
+        if self.state is None and self.overlay_component is default_overlay_component:
             self.overlay_component = None
 
     def __repr__(self) -> str:
@@ -224,7 +219,7 @@ class App(Base):
         Returns:
             The string representation of the app.
         """
-        return f"<App state={self.state.__name__}>"
+        return f"<App state={self.state.__name__ if self.state else None}>"
 
     def __call__(self) -> FastAPI:
         """Run the backend api instance.
@@ -252,6 +247,20 @@ class App(Base):
             allow_origins=["*"],
         )
 
+    @property
+    def state_manager(self) -> StateManager:
+        """Get the state manager.
+
+        Returns:
+            The initialized state manager.
+
+        Raises:
+            ValueError: if the state has not been initialized.
+        """
+        if self._state_manager is None:
+            raise ValueError("The state manager has not been initialized.")
+        return self._state_manager
+
     async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
         """Preprocess the event.
 
@@ -385,7 +394,8 @@ class App(Base):
         verify_route_validity(route)
 
         # Apply dynamic args to the route.
-        self.state.setup_dynamic_args(get_route_args(route))
+        if self.state:
+            self.state.setup_dynamic_args(get_route_args(route))
 
         # Generate the component if it is a callable.
         component = self._generate_component(component)
@@ -715,6 +725,7 @@ class App(Base):
         """
         if self.event_namespace is None:
             raise RuntimeError("App has not been initialized yet.")
+
         # Get exclusive access to the state.
         async with self.state_manager.modify_state(token) as state:
             # No other event handler can modify the state while in this context.
@@ -862,6 +873,7 @@ def upload(app: App):
         for file in files:
             assert file.filename is not None
             file.filename = file.filename.split(":")[-1]
+
         # Get the state for the session.
         async with app.state_manager.modify_state(token) as state:
             # get the current session ID

+ 0 - 1
reflex/app.pyi

@@ -32,7 +32,6 @@ from reflex.route import (
     verify_route_validity as verify_route_validity,
 )
 from reflex.state import (
-    DefaultState as DefaultState,
     State as State,
     StateManager as StateManager,
     StateUpdate as StateUpdate,

+ 16 - 9
reflex/compiler/compiler.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import os
 from pathlib import Path
-from typing import Type
+from typing import Optional, Type
 
 from reflex import constants
 from reflex.compiler import templates, utils
@@ -89,7 +89,7 @@ def _compile_theme(theme: dict) -> str:
     return templates.THEME.render(theme=theme)
 
 
-def _compile_contexts(state: Type[State]) -> str:
+def _compile_contexts(state: Optional[Type[State]]) -> str:
     """Compile the initial state and contexts.
 
     Args:
@@ -98,11 +98,16 @@ def _compile_contexts(state: Type[State]) -> str:
     Returns:
         The compiled context file.
     """
-    return templates.CONTEXT.render(
-        initial_state=utils.compile_state(state),
-        state_name=state.get_name(),
-        client_storage=utils.compile_client_storage(state),
-        is_dev_mode=os.environ.get("REFLEX_ENV_MODE", "dev") == "dev",
+    is_dev_mode = os.environ.get("REFLEX_ENV_MODE", "dev") == "dev"
+    return (
+        templates.CONTEXT.render(
+            initial_state=utils.compile_state(state),
+            state_name=state.get_name(),
+            client_storage=utils.compile_client_storage(state),
+            is_dev_mode=is_dev_mode,
+        )
+        if state
+        else templates.CONTEXT.render(is_dev_mode=is_dev_mode)
     )
 
 
@@ -125,13 +130,15 @@ def _compile_page(
     imports = utils.compile_imports(imports)
 
     # Compile the code to render the component.
+    kwargs = {"state_name": state.get_name()} if state else {}
+
     return templates.PAGE.render(
         imports=imports,
         dynamic_imports=component.get_dynamic_imports(),
         custom_codes=component.get_custom_code(),
-        state_name=state.get_name(),
         hooks=component.get_hooks(),
         render=component.render(),
+        **kwargs,
     )
 
 
@@ -296,7 +303,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
     return output_path, code
 
 
-def compile_contexts(state: Type[State]) -> tuple[str, str]:
+def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
     """Compile the initial state / context.
 
     Args:

+ 1 - 7
reflex/state.py

@@ -1368,12 +1368,6 @@ class StateProxy(wrapt.ObjectProxy):
         super().__setattr__(name, value)
 
 
-class DefaultState(State):
-    """The default empty state."""
-
-    pass
-
-
 class StateUpdate(Base):
     """A state update sent to the frontend."""
 
@@ -1394,7 +1388,7 @@ class StateManager(Base, ABC):
     state: Type[State]
 
     @classmethod
-    def create(cls, state: Type[State] = DefaultState):
+    def create(cls, state: Type[State]):
         """Create a new state manager.
 
         Args:

+ 2 - 2
reflex/testing.py

@@ -495,13 +495,13 @@ class AppHarness:
         if isinstance(self.state_manager, StateManagerRedis):
             # Temporarily replace the app's state manager with our own, since
             # the redis connection is on the backend_thread event loop
-            self.app_instance.state_manager = self.state_manager
+            self.app_instance._state_manager = self.state_manager
         try:
             async with self.app_instance.modify_state(token) as state:
                 yield state
         finally:
             if isinstance(self.state_manager, StateManagerRedis):
-                self.app_instance.state_manager = app_state_manager
+                self.app_instance._state_manager = app_state_manager
                 await self.state_manager.redis.close()
 
     def poll_for_content(

+ 12 - 8
tests/test_app.py

@@ -25,7 +25,6 @@ from reflex import AdminDash, constants
 from reflex.app import (
     App,
     ComponentCallable,
-    DefaultState,
     default_overlay_component,
     process,
     upload,
@@ -49,6 +48,12 @@ from .states import (
 )
 
 
+class EmptyState(State):
+    """An empty state."""
+
+    pass
+
+
 @pytest.fixture
 def index_page():
     """An index page.
@@ -192,7 +197,6 @@ def test_default_app(app: App):
     Args:
         app: The app to test.
     """
-    assert app.state() == DefaultState()
     assert app.middleware == [HydrateMiddleware()]
     assert app.style == Style()
     assert app.admin_dash is None
@@ -240,14 +244,14 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
     assert set(app.pages.keys()) == {"test"}
 
 
-def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool):
+def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
     """Test adding a page with dynamic route variable to an app.
 
     Args:
-        app: The app to test.
         index_page: The index page.
         windows_platform: Whether the system is windows.
     """
+    app = App(state=EmptyState)
     route = "/test/[dynamic]"
     if windows_platform:
         route.lstrip("/").replace("/", "\\")
@@ -255,7 +259,7 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool
     app.add_page(index_page, route=route)
     assert set(app.pages.keys()) == {"test/[dynamic]"}
     assert "dynamic" in app.state.computed_vars
-    assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == {
+    assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
         constants.ROUTER
     }
     assert constants.ROUTER in app.state().computed_var_dependencies
@@ -1093,9 +1097,9 @@ async def test_process_events(mocker, token: str):
 @pytest.mark.parametrize(
     ("state", "overlay_component", "exp_page_child"),
     [
-        (DefaultState, default_overlay_component, None),
-        (DefaultState, None, None),
-        (DefaultState, Text.create("foo"), Text),
+        (None, default_overlay_component, None),
+        (None, None, None),
+        (None, Text.create("foo"), Text),
         (State, default_overlay_component, Fragment),
         (State, None, None),
         (State, Text.create("foo"), Text),

+ 4 - 1
tests/test_testing.py

@@ -16,7 +16,10 @@ def test_app_harness(tmp_path):
     def BasicApp():
         import reflex as rx
 
-        app = rx.App()
+        class State(rx.State):
+            pass
+
+        app = rx.App(state=State)
         app.add_page(lambda: rx.text("Basic App"), route="/", title="index")
         app.compile()