ソースを参照

rx.App `state` arg should not be required (#1361)

Elijah Ahianyo 1 年間 前
コミット
549ab4e708
5 ファイル変更77 行追加25 行削除
  1. 19 1
      reflex/app.py
  2. 4 0
      reflex/constants.py
  3. 11 0
      tests/conftest.py
  4. 15 13
      tests/middleware/test_hydrate_middleware.py
  5. 28 11
      tests/test_app.py

+ 19 - 1
reflex/app.py

@@ -2,6 +2,7 @@
 
 
 import asyncio
 import asyncio
 import inspect
 import inspect
+import os
 from multiprocessing.pool import ThreadPool
 from multiprocessing.pool import ThreadPool
 from typing import (
 from typing import (
     Any,
     Any,
@@ -42,7 +43,7 @@ from reflex.route import (
     verify_route_validity,
     verify_route_validity,
 )
 )
 from reflex.state import DefaultState, State, StateManager, StateUpdate
 from reflex.state import DefaultState, State, StateManager, StateUpdate
-from reflex.utils import format, types
+from reflex.utils import console, format, types
 
 
 # Define custom types.
 # Define custom types.
 ComponentCallable = Callable[[], Component]
 ComponentCallable = Callable[[], Component]
@@ -100,8 +101,25 @@ class App(Base):
 
 
         Raises:
         Raises:
             ValueError: If the event namespace is not provided in the config.
             ValueError: If the event namespace is not provided in the config.
+                        Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
+                        of the DefaultState and the client app state).
         """
         """
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
+        state_subclasses = State.__subclasses__()
+        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
+
+        # Special case to allow test cases have multiple subclasses of rx.State.
+        if not is_testing_env:
+            # Only the default state and the client state should be allowed as subclasses.
+            if len(state_subclasses) > 2:
+                raise ValueError(
+                    "rx.State has been subclassed multiple times. Only one subclass is allowed"
+                )
+            if self.state != DefaultState:
+                console.deprecate(
+                    "Passing the state as keyword argument to `rx.App` is deprecated."
+                )
+            self.state = state_subclasses[-1]
 
 
         # Get the config
         # Get the config
         config = get_config()
         config = get_config()

+ 4 - 0
reflex/constants.py

@@ -200,6 +200,10 @@ TOKEN_EXPIRATION = 60 * 60
 # The event namespace for websocket
 # The event namespace for websocket
 EVENT_NAMESPACE = get_value("EVENT_NAMESPACE")
 EVENT_NAMESPACE = get_value("EVENT_NAMESPACE")
 
 
+# Testing variables.
+# Testing os env set by pytest when running a test case.
+PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
+
 # Env modes
 # Env modes
 class Env(str, Enum):
 class Env(str, Enum):
     """The environment modes."""
     """The environment modes."""

+ 11 - 0
tests/conftest.py

@@ -9,9 +9,20 @@ import pytest
 
 
 import reflex as rx
 import reflex as rx
 from reflex import constants
 from reflex import constants
+from reflex.app import App
 from reflex.event import EventSpec
 from reflex.event import EventSpec
 
 
 
 
+@pytest.fixture
+def app() -> App:
+    """A base app.
+
+    Returns:
+        The app.
+    """
+    return App()
+
+
 @pytest.fixture(scope="function")
 @pytest.fixture(scope="function")
 def windows_platform() -> Generator:
 def windows_platform() -> Generator:
     """Check if system is windows.
     """Check if system is windows.

+ 15 - 13
tests/middleware/test_hydrate_middleware.py

@@ -78,25 +78,27 @@ def hydrate_middleware() -> HydrateMiddleware:
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
-    "State, expected, event_fixture",
+    "test_state, expected, event_fixture",
     [
     [
         (TestState, {"test_state": {"num": 1}}, "event1"),
         (TestState, {"test_state": {"num": 1}}, "event1"),
         (TestState2, {"test_state2": {"num": 1}}, "event2"),
         (TestState2, {"test_state2": {"num": 1}}, "event2"),
         (TestState3, {"test_state3": {"num": 1}}, "event3"),
         (TestState3, {"test_state3": {"num": 1}}, "event3"),
     ],
     ],
 )
 )
-async def test_preprocess(State, hydrate_middleware, request, event_fixture, expected):
+async def test_preprocess(
+    test_state, hydrate_middleware, request, event_fixture, expected
+):
     """Test that a state hydrate event is processed correctly.
     """Test that a state hydrate event is processed correctly.
 
 
     Args:
     Args:
-        State: state to process event
-        hydrate_middleware: instance of HydrateMiddleware
-        request: pytest fixture request
-        event_fixture: The event fixture(an Event)
-        expected: expected delta
+        test_state: State to process event.
+        hydrate_middleware: Instance of HydrateMiddleware.
+        request: Pytest fixture request.
+        event_fixture: The event fixture(an Event).
+        expected: Expected delta.
     """
     """
-    app = App(state=State, load_events={"index": [State.test_handler]})
-    state = State()
+    app = App(state=test_state, load_events={"index": [test_state.test_handler]})
+    state = test_state()
 
 
     update = await hydrate_middleware.preprocess(
     update = await hydrate_middleware.preprocess(
         app=app, event=request.getfixturevalue(event_fixture), state=state
         app=app, event=request.getfixturevalue(event_fixture), state=state
@@ -120,8 +122,8 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
     """Test that a state hydrate event for multiple on-load events is processed correctly.
     """Test that a state hydrate event for multiple on-load events is processed correctly.
 
 
     Args:
     Args:
-        hydrate_middleware: instance of HydrateMiddleware
-        event1: an Event.
+        hydrate_middleware: Instance of HydrateMiddleware
+        event1: An Event.
     """
     """
     app = App(
     app = App(
         state=TestState,
         state=TestState,
@@ -151,8 +153,8 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
     """Test that app without on_load is processed correctly.
     """Test that app without on_load is processed correctly.
 
 
     Args:
     Args:
-        hydrate_middleware: instance of HydrateMiddleware
-        event1: an Event.
+        hydrate_middleware: Instance of HydrateMiddleware
+        event1: An Event.
     """
     """
     state = TestState()
     state = TestState()
     update = await hydrate_middleware.preprocess(
     update = await hydrate_middleware.preprocess(

+ 28 - 11
tests/test_app.py

@@ -27,16 +27,6 @@ from reflex.utils import format
 from reflex.vars import ComputedVar
 from reflex.vars import ComputedVar
 
 
 
 
-@pytest.fixture
-def app() -> App:
-    """A base app.
-
-    Returns:
-        The app.
-    """
-    return App()
-
-
 @pytest.fixture
 @pytest.fixture
 def index_page():
 def index_page():
     """An index page.
     """An index page.
@@ -79,6 +69,20 @@ def test_state() -> Type[State]:
     return TestState
     return TestState
 
 
 
 
+@pytest.fixture()
+def redundant_test_state() -> Type[State]:
+    """A default state.
+
+    Returns:
+        A default state.
+    """
+
+    class RedundantTestState(State):
+        var: int
+
+    return RedundantTestState
+
+
 @pytest.fixture()
 @pytest.fixture()
 def test_model() -> Type[Model]:
 def test_model() -> Type[Model]:
     """A default model.
     """A default model.
@@ -170,6 +174,19 @@ def test_default_app(app: App):
     assert app.admin_dash is None
     assert app.admin_dash is None
 
 
 
 
+def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
+    """Test that an error is thrown when multiple classes subclass rx.State.
+
+    Args:
+        monkeypatch: Pytest monkeypatch object.
+        test_state: A test state subclassing rx.State.
+        redundant_test_state: Another test state subclassing rx.State.
+    """
+    monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
+    with pytest.raises(ValueError):
+        App()
+
+
 def test_add_page_default_route(app: App, index_page, about_page):
 def test_add_page_default_route(app: App, index_page, about_page):
     """Test adding a page to an app.
     """Test adding a page to an app.
 
 
@@ -708,7 +725,7 @@ class DynamicState(State):
 
 
     There are several counters:
     There are several counters:
       * loaded: counts how many times `on_load` was triggered by the hydrate middleware
       * loaded: counts how many times `on_load` was triggered by the hydrate middleware
-      * counter: counts how many times `on_counter` was triggered by a non-naviagational event
+      * counter: counts how many times `on_counter` was triggered by a non-navigational event
           -> these events should NOT trigger reload or recalculation of router_data dependent vars
           -> these events should NOT trigger reload or recalculation of router_data dependent vars
       * side_effect_counter: counts how many times a computed var was
       * side_effect_counter: counts how many times a computed var was
         recalculated when the dynamic route var was dirty
         recalculated when the dynamic route var was dirty