1
0
Эх сурвалжийг харах

RED-1052/rx.State as Base State (#2146)

Elijah Ahianyo 1 жил өмнө
parent
commit
e3ee98098a
49 өөрчлөгдсөн 356 нэмэгдсэн , 270 устгасан
  1. 1 1
      integration/test_background_task.py
  2. 1 1
      integration/test_call_script.py
  3. 30 24
      integration/test_client_storage.py
  4. 1 1
      integration/test_connection_banner.py
  5. 5 3
      integration/test_dynamic_routes.py
  6. 7 3
      integration/test_event_actions.py
  7. 13 13
      integration/test_event_chain.py
  8. 3 3
      integration/test_form_submit.py
  9. 12 8
      integration/test_input.py
  10. 2 2
      integration/test_login_flow.py
  11. 1 1
      integration/test_radix_themes.py
  12. 1 1
      integration/test_server_side_event.py
  13. 1 1
      integration/test_table.py
  14. 7 7
      integration/test_upload.py
  15. 1 1
      integration/test_var_operations.py
  16. 15 13
      reflex/app.py
  17. 2 1
      reflex/app.pyi
  18. 32 1
      reflex/base.py
  19. 5 5
      reflex/compiler/compiler.py
  20. 4 4
      reflex/compiler/utils.py
  21. 2 0
      reflex/constants/__init__.py
  22. 1 0
      reflex/constants/base.py
  23. 2 2
      reflex/event.py
  24. 2 2
      reflex/middleware/hydrate_middleware.py
  25. 3 3
      reflex/middleware/middleware.py
  26. 70 42
      reflex/state.py
  27. 6 3
      reflex/testing.py
  28. 2 0
      reflex/utils/prerequisites.py
  29. 7 7
      reflex/vars.py
  30. 3 2
      reflex/vars.pyi
  31. 3 0
      tests/__init__.py
  32. 2 2
      tests/components/base/test_script.py
  33. 5 4
      tests/components/datadisplay/conftest.py
  34. 2 2
      tests/components/datadisplay/test_table.py
  35. 2 1
      tests/components/forms/test_debounce.py
  36. 2 1
      tests/components/layout/test_cond.py
  37. 2 2
      tests/components/layout/test_foreach.py
  38. 3 3
      tests/components/test_component.py
  39. 0 21
      tests/conftest.py
  40. 9 5
      tests/middleware/test_hydrate_middleware.py
  41. 4 2
      tests/states/__init__.py
  42. 4 3
      tests/states/mutation.py
  43. 5 4
      tests/states/upload.py
  44. 27 19
      tests/test_app.py
  45. 2 2
      tests/test_event.py
  46. 35 29
      tests/test_state.py
  47. 5 11
      tests/test_var.py
  48. 0 2
      tests/utils/test_format.py
  49. 2 2
      tests/utils/test_utils.py

+ 1 - 1
integration/test_background_task.py

@@ -93,7 +93,7 @@ def BackgroundTask():
             rx.button("Reset", on_click=State.reset_counter, id="reset"),
         )
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.compile()
 

+ 1 - 1
integration/test_call_script.py

@@ -135,7 +135,7 @@ def CallScript():
             yield rx.call_script("inline_counter = 0; external_counter = 0")
             self.reset()
 
-    app = rx.App(state=CallScriptState)
+    app = rx.App(state=rx.State)
     with open("assets/external.js", "w") as f:
         f.write(external_scripts)
 

+ 30 - 24
integration/test_client_storage.py

@@ -97,7 +97,7 @@ def ClientSide():
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
         )
 
-    app = rx.App(state=ClientSideState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index, route="/foo")
     app.compile()
@@ -263,7 +263,6 @@ async def test_client_side_state(
     state_var_input.send_keys("c7")
     input_value_input.send_keys("c7 value")
     set_sub_state_button.click()
-
     state_var_input.send_keys("l1")
     input_value_input.send_keys("l1 value")
     set_sub_state_button.click()
@@ -276,7 +275,6 @@ async def test_client_side_state(
     state_var_input.send_keys("l4")
     input_value_input.send_keys("l4 value")
     set_sub_state_button.click()
-
     state_var_input.send_keys("c1s")
     input_value_input.send_keys("c1s value")
     set_sub_sub_state_button.click()
@@ -285,28 +283,28 @@ async def test_client_side_state(
     set_sub_sub_state_button.click()
 
     exp_cookies = {
-        "client_side_state.client_side_sub_state.c1": {
+        "state.client_side_state.client_side_sub_state.c1": {
             "domain": "localhost",
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c1",
+            "name": "state.client_side_state.client_side_sub_state.c1",
             "path": "/",
             "sameSite": "Lax",
             "secure": False,
             "value": "c1%20value",
         },
-        "client_side_state.client_side_sub_state.c2": {
+        "state.client_side_state.client_side_sub_state.c2": {
             "domain": "localhost",
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c2",
+            "name": "state.client_side_state.client_side_sub_state.c2",
             "path": "/",
             "sameSite": "Lax",
             "secure": False,
             "value": "c2%20value",
         },
-        "client_side_state.client_side_sub_state.c4": {
+        "state.client_side_state.client_side_sub_state.c4": {
             "domain": "localhost",
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c4",
+            "name": "state.client_side_state.client_side_sub_state.c4",
             "path": "/",
             "sameSite": "Strict",
             "secure": False,
@@ -321,19 +319,19 @@ async def test_client_side_state(
             "secure": False,
             "value": "c6%20value",
         },
-        "client_side_state.client_side_sub_state.c7": {
+        "state.client_side_state.client_side_sub_state.c7": {
             "domain": "localhost",
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c7",
+            "name": "state.client_side_state.client_side_sub_state.c7",
             "path": "/",
             "sameSite": "Lax",
             "secure": False,
             "value": "c7%20value",
         },
-        "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
+        "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
             "domain": "localhost",
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
+            "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
             "path": "/",
             "sameSite": "Lax",
             "secure": False,
@@ -354,40 +352,45 @@ async def test_client_side_state(
     input_value_input.send_keys("c3 value")
     set_sub_state_button.click()
     AppHarness._poll_for(
-        lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
+        lambda: "state.client_side_state.client_side_sub_state.c3"
+        in cookie_info_map(driver)
     )
-    c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
+    c3_cookie = cookie_info_map(driver)[
+        "state.client_side_state.client_side_sub_state.c3"
+    ]
     assert c3_cookie.pop("expiry") is not None
     assert c3_cookie == {
         "domain": "localhost",
         "httpOnly": False,
-        "name": "client_side_state.client_side_sub_state.c3",
+        "name": "state.client_side_state.client_side_sub_state.c3",
         "path": "/",
         "sameSite": "Lax",
         "secure": False,
         "value": "c3%20value",
     }
     time.sleep(2)  # wait for c3 to expire
-    assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
+    assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(
+        driver
+    )
 
     local_storage_items = local_storage.items()
     local_storage_items.pop("chakra-ui-color-mode", None)
     assert (
-        local_storage_items.pop("client_side_state.client_side_sub_state.l1")
+        local_storage_items.pop("state.client_side_state.client_side_sub_state.l1")
         == "l1 value"
     )
     assert (
-        local_storage_items.pop("client_side_state.client_side_sub_state.l2")
+        local_storage_items.pop("state.client_side_state.client_side_sub_state.l2")
         == "l2 value"
     )
     assert local_storage_items.pop("l3") == "l3 value"
     assert (
-        local_storage_items.pop("client_side_state.client_side_sub_state.l4")
+        local_storage_items.pop("state.client_side_state.client_side_sub_state.l4")
         == "l4 value"
     )
     assert (
         local_storage_items.pop(
-            "client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
+            "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
         )
         == "l1s value"
     )
@@ -482,12 +485,15 @@ async def test_client_side_state(
 
     # make sure c5 cookie shows up on the `/foo` route
     AppHarness._poll_for(
-        lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver)
+        lambda: "state.client_side_state.client_side_sub_state.c5"
+        in cookie_info_map(driver)
     )
-    assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
+    assert cookie_info_map(driver)[
+        "state.client_side_state.client_side_sub_state.c5"
+    ] == {
         "domain": "localhost",
         "httpOnly": False,
-        "name": "client_side_state.client_side_sub_state.c5",
+        "name": "state.client_side_state.client_side_sub_state.c5",
         "path": "/foo/",
         "sameSite": "Lax",
         "secure": False,

+ 1 - 1
integration/test_connection_banner.py

@@ -19,7 +19,7 @@ def ConnectionBanner():
     def index():
         return rx.text("Hello World")
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.compile()
 

+ 5 - 3
integration/test_dynamic_routes.py

@@ -56,7 +56,7 @@ def DynamicRoute():
     def redirect_page():
         return rx.fragment(rx.text("redirecting..."))
 
-    app = rx.App(state=DynamicState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load)  # type: ignore
     app.add_page(index, route="/static/x", on_load=DynamicState.on_load)  # type: ignore
@@ -143,10 +143,12 @@ def poll_for_order(
             return await dynamic_route.get_state(token)
 
         async def _check():
-            return (await _backend_state()).order == exp_order
+            return (await _backend_state()).substates[
+                "dynamic_state"
+            ].order == exp_order
 
         await AppHarness._poll_for_async(_check)
-        assert (await _backend_state()).order == exp_order
+        assert (await _backend_state()).substates["dynamic_state"].order == exp_order
 
     return _poll_for_order
 

+ 7 - 3
integration/test_event_actions.py

@@ -130,7 +130,7 @@ def TestEventAction():
             on_click=EventActionState.on_click("outer"),  # type: ignore
         )
 
-    app = rx.App(state=EventActionState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.compile()
 
@@ -211,10 +211,14 @@ def poll_for_order(
             return await event_action.get_state(token)
 
         async def _check():
-            return (await _backend_state()).order == exp_order
+            return (await _backend_state()).substates[
+                "event_action_state"
+            ].order == exp_order
 
         await AppHarness._poll_for_async(_check)
-        assert (await _backend_state()).order == exp_order
+        assert (await _backend_state()).substates[
+            "event_action_state"
+        ].order == exp_order
 
     return _poll_for_order
 

+ 13 - 13
integration/test_event_chain.py

@@ -122,7 +122,7 @@ def EventChain():
             time.sleep(0.5)
             self.interim_value = "final"
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
 
     token_input = rx.input(
         value=State.router.session.client_token, is_read_only=True, id="token"
@@ -401,12 +401,12 @@ async def test_event_chain_click(
     btn.click()
 
     async def _has_all_events():
-        return len((await event_chain.get_state(token)).event_order) == len(
-            exp_event_order
-        )
+        return len(
+            (await event_chain.get_state(token)).substates["state"].event_order
+        ) == len(exp_event_order)
 
     await AppHarness._poll_for_async(_has_all_events)
-    event_order = (await event_chain.get_state(token)).event_order
+    event_order = (await event_chain.get_state(token)).substates["state"].event_order
     assert event_order == exp_event_order
 
 
@@ -453,12 +453,12 @@ async def test_event_chain_on_load(
     token = assert_token(event_chain, driver)
 
     async def _has_all_events():
-        return len((await event_chain.get_state(token)).event_order) == len(
-            exp_event_order
-        )
+        return len(
+            (await event_chain.get_state(token)).substates["state"].event_order
+        ) == len(exp_event_order)
 
     await AppHarness._poll_for_async(_has_all_events)
-    backend_state = await event_chain.get_state(token)
+    backend_state = (await event_chain.get_state(token)).substates["state"]
     assert backend_state.event_order == exp_event_order
     assert backend_state.is_hydrated is True
 
@@ -529,12 +529,12 @@ async def test_event_chain_on_mount(
     unmount_button.click()
 
     async def _has_all_events():
-        return len((await event_chain.get_state(token)).event_order) == len(
-            exp_event_order
-        )
+        return len(
+            (await event_chain.get_state(token)).substates["state"].event_order
+        ) == len(exp_event_order)
 
     await AppHarness._poll_for_async(_has_all_events)
-    event_order = (await event_chain.get_state(token)).event_order
+    event_order = (await event_chain.get_state(token)).substates["state"].event_order
     assert event_order == exp_event_order
 
 

+ 3 - 3
integration/test_form_submit.py

@@ -22,7 +22,7 @@ def FormSubmit():
         def form_submit(self, form_data: dict):
             self.form_data = form_data
 
-    app = rx.App(state=FormState)
+    app = rx.App(state=rx.State)
 
     @app.add_page
     def index():
@@ -75,7 +75,7 @@ def FormSubmitName():
         def form_submit(self, form_data: dict):
             self.form_data = form_data
 
-    app = rx.App(state=FormState)
+    app = rx.App(state=rx.State)
 
     @app.add_page
     def index():
@@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness):
     submit_input.click()
 
     async def get_form_data():
-        return (await form_submit.get_state(token)).form_data
+        return (await form_submit.get_state(token)).substates["form_state"].form_data
 
     # wait for the form data to arrive at the backend
     form_data = await AppHarness._poll_for_async(get_form_data)

+ 12 - 8
integration/test_input.py

@@ -16,7 +16,7 @@ def FullyControlledInput():
     class State(rx.State):
         text: str = "initial"
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
 
     @app.add_page
     def index():
@@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("foo")
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "ifoonitial"
-    assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
+    assert (await fully_controlled_input.get_state(token)).substates[
+        "state"
+    ].text == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
 
     # clear the input on the backend
     async with fully_controlled_input.modify_state(token) as state:
-        state.text = ""
-    assert (await fully_controlled_input.get_state(token)).text == ""
+        state.substates["state"].text = ""
+    assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
     assert (
         fully_controlled_input.poll_for_value(
             debounce_input, exp_not_equal="ifoonitial"
@@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("getting testing done")
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "getting testing done"
-    assert (
-        await fully_controlled_input.get_state(token)
-    ).text == "getting testing done"
+    assert (await fully_controlled_input.get_state(token)).substates[
+        "state"
+    ].text == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
 
     # type into the on_change input
@@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
-    assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
+    assert (await fully_controlled_input.get_state(token)).substates[
+        "state"
+    ].text == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
 
     clear_button.click()

+ 2 - 2
integration/test_login_flow.py

@@ -42,7 +42,7 @@ def LoginSample():
             rx.button("Do it", on_click=State.login, id="doit"),
         )
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(login)
     app.compile()
@@ -137,6 +137,6 @@ def test_login_flow(
     logout_button = driver.find_element(By.ID, "logout")
     logout_button.click()
 
-    assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "")
+    assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "")
     with pytest.raises(NoSuchElementException):
         driver.find_element(By.ID, "auth-token")

+ 1 - 1
integration/test_radix_themes.py

@@ -81,7 +81,7 @@ def RadixThemesApp():
         )
 
     app = rx.App(
-        state=State,
+        state=rx.State,
         theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
     )
     app.add_page(index)

+ 1 - 1
integration/test_server_side_event.py

@@ -33,7 +33,7 @@ def ServerSideEvent():
         def set_value_return_c(self):
             return rx.set_value("c", "")
 
-    app = rx.App(state=SSState)
+    app = rx.App(state=rx.State)
 
     @app.add_page
     def index():

+ 1 - 1
integration/test_table.py

@@ -26,7 +26,7 @@ def Table():
 
         caption: str = "random caption"
 
-    app = rx.App(state=TableState)
+    app = rx.App(state=rx.State)
 
     @app.add_page
     def index():

+ 7 - 7
integration/test_upload.py

@@ -113,7 +113,7 @@ def UploadFile():
             ),
         )
 
-    app = rx.App(state=UploadState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.compile()
 
@@ -192,7 +192,7 @@ async def test_upload_file(
 
     # look up the backend state and assert on uploaded contents
     async def get_file_data():
-        return (await upload_file.get_state(token))._file_data
+        return (await upload_file.get_state(token)).substates["upload_state"]._file_data
 
     file_data = await AppHarness._poll_for_async(get_file_data)
     assert isinstance(file_data, dict)
@@ -205,8 +205,8 @@ async def test_upload_file(
     state = await upload_file.get_state(token)
     if secondary:
         # only the secondary form tracks progress and chain events
-        assert state.event_order.count("upload_progress") == 1
-        assert state.event_order.count("chain_event") == 1
+        assert state.substates["upload_state"].event_order.count("upload_progress") == 1
+        assert state.substates["upload_state"].event_order.count("chain_event") == 1
 
 
 @pytest.mark.asyncio
@@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
 
     # look up the backend state and assert on uploaded contents
     async def get_file_data():
-        return (await upload_file.get_state(token))._file_data
+        return (await upload_file.get_state(token)).substates["upload_state"]._file_data
 
     file_data = await AppHarness._poll_for_async(get_file_data)
     assert isinstance(file_data, dict)
@@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
 
     # look up the backend state and assert on progress
     state = await upload_file.get_state(token)
-    assert state.progress_dicts
-    assert exp_name not in state._file_data
+    assert state.substates["upload_state"].progress_dicts
+    assert exp_name not in state.substates["upload_state"]._file_data
 
     target_file.unlink()

+ 1 - 1
integration/test_var_operations.py

@@ -30,7 +30,7 @@ def VarOperations():
         dict2: dict = {3: 4}
         html_str: str = "<div>hello</div>"
 
-    app = rx.App(state=VarOperationState)
+    app = rx.App(state=rx.State)
 
     @app.add_page
     def index():

+ 15 - 13
reflex/app.py

@@ -57,6 +57,7 @@ from reflex.route import (
     verify_route_validity,
 )
 from reflex.state import (
+    BaseState,
     RouterData,
     State,
     StateManager,
@@ -98,7 +99,7 @@ class App(Base):
     socket_app: Optional[ASGIApp] = None
 
     # The state class to use for the app.
-    state: Optional[Type[State]] = None
+    state: Optional[Type[BaseState]] = None
 
     # Class to manage many client states.
     _state_manager: Optional[StateManager] = None
@@ -149,25 +150,24 @@ class App(Base):
                 "`connect_error_component` is deprecated, use `overlay_component` instead"
             )
         super().__init__(*args, **kwargs)
-        state_subclasses = State.__subclasses__()
-        inferred_state = state_subclasses[-1] if state_subclasses else None
+        state_subclasses = BaseState.__subclasses__()
         is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
 
-        # Special case to allow test cases have multiple subclasses of rx.State.
+        # Special case to allow test cases have multiple subclasses of rx.BaseState.
         if not is_testing_env:
-            # Only one State class is allowed.
+            # Only one Base State class is allowed.
             if len(state_subclasses) > 1:
                 raise ValueError(
-                    "rx.State has been subclassed multiple times. Only one subclass is allowed"
+                    "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
                 )
 
             # verify that provided state is valid
-            if self.state and inferred_state and self.state is not inferred_state:
+            if self.state and self.state is not State:
                 console.warn(
                     f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
-                    f" Defaulting to root state: ({inferred_state.__name__})"
+                    f" Defaulting to root state: ({State.__name__})"
                 )
-            self.state = inferred_state
+            self.state = State
         # Get the config
         config = get_config()
 
@@ -265,7 +265,7 @@ class App(Base):
             raise ValueError("The state manager has not been initialized.")
         return self._state_manager
 
-    async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
+    async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
         """Preprocess the event.
 
         This is where middleware can modify the event before it is processed.
@@ -290,7 +290,7 @@ class App(Base):
                 return out  # type: ignore
 
     async def postprocess(
-        self, state: State, event: Event, update: StateUpdate
+        self, state: BaseState, event: Event, update: StateUpdate
     ) -> StateUpdate:
         """Postprocess the event.
 
@@ -764,7 +764,7 @@ class App(Base):
                 future.result()
 
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state out of band.
 
         Args:
@@ -792,7 +792,9 @@ class App(Base):
                     sid=state.router.session.session_id,
                 )
 
-    def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
+    def _process_background(
+        self, state: BaseState, event: Event
+    ) -> asyncio.Task | None:
         """Process an event in the background and emit updates as they arrive.
 
         Args:

+ 2 - 1
reflex/app.pyi

@@ -33,6 +33,7 @@ from reflex.route import (
 )
 from reflex.state import (
     State as State,
+    BaseState as BaseState,
     StateManager as StateManager,
     StateUpdate as StateUpdate,
 )
@@ -69,7 +70,7 @@ class App(Base):
     api: FastAPI
     sio: Optional[AsyncServer]
     socket_app: Optional[ASGIApp]
-    state: Type[State]
+    state: Type[BaseState]
     state_manager: StateManager
     style: ComponentStyle
     middleware: List[Middleware]

+ 32 - 1
reflex/base.py

@@ -1,11 +1,42 @@
 """Define the base Reflex class."""
 from __future__ import annotations
 
-from typing import Any
+import os
+from typing import Any, List, Type
 
 import pydantic
+from pydantic import BaseModel
 from pydantic.fields import ModelField
 
+from reflex import constants
+
+
+def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
+    """Ensure that the field's name does not shadow an existing attribute of the model.
+
+    Args:
+        bases: List of base models to check for shadowed attrs.
+        field_name: name of attribute
+
+    Raises:
+        NameError: If state var field shadows another in its parent state
+    """
+    reload = os.getenv(constants.RELOAD_CONFIG) == "True"
+    for base in bases:
+        try:
+            if not reload and getattr(base, field_name, None):
+                pass
+        except TypeError as te:
+            raise NameError(
+                f'State var "{field_name}" in {base} has been shadowed by a substate var; '
+                f'use a different field name instead".'
+            ) from te
+
+
+# monkeypatch pydantic validate_field_name method to skip validating
+# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
+pydantic.main.validate_field_name = validate_field_name  # type: ignore
+
 
 class Base(pydantic.BaseModel):
     """The base class subclassed by all Reflex classes.

+ 5 - 5
reflex/compiler/compiler.py

@@ -15,7 +15,7 @@ from reflex.components.component import (
     StatefulComponent,
 )
 from reflex.config import get_config
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.utils.imports import ImportVar
 
 
@@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str:
     return templates.THEME.render(theme=theme)
 
 
-def _compile_contexts(state: Optional[Type[State]]) -> str:
+def _compile_contexts(state: Optional[Type[BaseState]]) -> str:
     """Compile the initial state and contexts.
 
     Args:
@@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str:
 
 def _compile_page(
     component: Component,
-    state: Type[State],
+    state: Type[BaseState],
 ) -> str:
     """Compile the component given the app state.
 
@@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
     return output_path, code
 
 
-def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
+def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]:
     """Compile the initial state / context.
 
     Args:
@@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
 
 
 def compile_page(
-    path: str, component: Component, state: Type[State]
+    path: str, component: Component, state: Type[BaseState]
 ) -> tuple[str, str]:
     """Compile a single page.
 

+ 4 - 4
reflex/compiler/utils.py

@@ -21,7 +21,7 @@ from reflex.components.base import (
     Title,
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
-from reflex.state import Cookie, LocalStorage, State
+from reflex.state import BaseState, Cookie, LocalStorage
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 
@@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
     }
 
 
-def compile_state(state: Type[State]) -> dict:
+def compile_state(state: Type[BaseState]) -> dict:
     """Compile the state of the app.
 
     Args:
@@ -170,7 +170,7 @@ def _compile_client_storage_field(
 
 
 def _compile_client_storage_recursive(
-    state: Type[State],
+    state: Type[BaseState],
 ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
     """Compile the client-side storage for the given state recursively.
 
@@ -208,7 +208,7 @@ def _compile_client_storage_recursive(
     return cookies, local_storage
 
 
-def compile_client_storage(state: Type[State]) -> dict[str, dict]:
+def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
     """Compile the client-side storage for the given state.
 
     Args:

+ 2 - 0
reflex/constants/__init__.py

@@ -6,6 +6,7 @@ from .base import (
     LOCAL_STORAGE,
     POLLING_MAX_HTTP_BUFFER_SIZE,
     PYTEST_CURRENT_TEST,
+    RELOAD_CONFIG,
     SKIP_COMPILE_ENV_VAR,
     ColorMode,
     Dirs,
@@ -85,6 +86,7 @@ __ALL__ = [
     PYTEST_CURRENT_TEST,
     PRODUCTION_BACKEND_URL,
     Reflex,
+    RELOAD_CONFIG,
     RequirementsTxt,
     RouteArgType,
     RouteRegex,

+ 1 - 0
reflex/constants/base.py

@@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
 # Testing variables.
 # Testing os env set by pytest when running a test case.
 PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
+RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"

+ 2 - 2
reflex/event.py

@@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 
 if TYPE_CHECKING:
-    from reflex.state import State
+    from reflex.state import BaseState
 
 
 class Event(Base):
@@ -64,7 +64,7 @@ def background(fn):
 
 
 def _no_chain_background_task(
-    state_cls: Type["State"], name: str, fn: Callable
+    state_cls: Type["BaseState"], name: str, fn: Callable
 ) -> Callable:
     """Protect against directly chaining a background task from another event handler.
 

+ 2 - 2
reflex/middleware/hydrate_middleware.py

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
 from reflex import constants
 from reflex.event import Event, fix_events, get_hydrate_event
 from reflex.middleware.middleware import Middleware
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
 from reflex.utils import format
 
 if TYPE_CHECKING:
@@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware):
     """Middleware to handle initial app hydration."""
 
     async def preprocess(
-        self, app: App, state: State, event: Event
+        self, app: App, state: BaseState, event: Event
     ) -> Optional[StateUpdate]:
         """Preprocess the event.
 

+ 3 - 3
reflex/middleware/middleware.py

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
 
 from reflex.base import Base
 from reflex.event import Event
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
 
 if TYPE_CHECKING:
     from reflex.app import App
@@ -16,7 +16,7 @@ class Middleware(Base, ABC):
     """Middleware to preprocess and postprocess requests."""
 
     async def preprocess(
-        self, app: App, state: State, event: Event
+        self, app: App, state: BaseState, event: Event
     ) -> Optional[StateUpdate]:
         """Preprocess the event.
 
@@ -31,7 +31,7 @@ class Middleware(Base, ABC):
         return None
 
     async def postprocess(
-        self, app: App, state: State, event: Event, update: StateUpdate
+        self, app: App, state: BaseState, event: Event, update: StateUpdate
     ) -> StateUpdate:
         """Postprocess the event.
 

+ 70 - 42
reflex/state.py

@@ -7,6 +7,7 @@ import copy
 import functools
 import inspect
 import json
+import os
 import traceback
 import urllib.parse
 import uuid
@@ -81,7 +82,7 @@ class HeaderData(Base):
 class PageData(Base):
     """An object containing page data."""
 
-    host: str = ""  #  repeated with self.headers.origin (remove or keep the duplicate?)
+    host: str = ""  # repeated with self.headers.origin (remove or keep the duplicate?)
     path: str = ""
     raw_path: str = ""
     full_path: str = ""
@@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = {
 }
 
 
-class State(Base, ABC, extra=pydantic.Extra.allow):
+class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
     # A map from the var name to the var.
@@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The event handlers.
     event_handlers: ClassVar[Dict[str, EventHandler]] = {}
 
+    # A set of subclassses of this class.
+    class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
+
     # Mapping of var name to set of computed variables that depend on it
     _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
 
@@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     _always_dirty_substates: ClassVar[Set[str]] = set()
 
     # The parent state.
-    parent_state: Optional[State] = None
+    parent_state: Optional[BaseState] = None
 
     # The substates of the state.
-    substates: Dict[str, State] = {}
+    substates: Dict[str, BaseState] = {}
 
     # The set of dirty vars.
     dirty_vars: Set[str] = set()
@@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The router data for the current page
     router: RouterData = RouterData()
 
-    # The hydrated bool.
-    is_hydrated: bool = False
-
-    def __init__(self, *args, parent_state: State | None = None, **kwargs):
+    def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
         """Initialize the state.
 
         Args:
@@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             parent_state: The parent state.
             **kwargs: The kwargs to pass to the Pydantic init method.
 
-        Raises:
-            ValueError: If a substate class shadows another.
         """
         kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
 
         # Setup the substates.
         for substate in self.get_substates():
-            substate_name = substate.get_name()
-            if substate_name in self.substates:
-                raise ValueError(
-                    f"The substate class '{substate_name}' has been defined multiple times. Shadowing "
-                    f"substate classes is not allowed."
-                )
-            self.substates[substate_name] = substate(parent_state=self)
+            self.substates[substate.get_name()] = substate(parent_state=self)
         # Convert the event handlers to functions.
         self._init_event_handlers()
 
         # Create a fresh copy of the backend variables for this instance
         self._backend_vars = copy.deepcopy(self.backend_vars)
 
-    def _init_event_handlers(self, state: State | None = None):
+    def _init_event_handlers(self, state: BaseState | None = None):
         """Initialize event handlers.
 
         Allow event handlers to be called directly on the instance. This is
@@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
         Args:
             **kwargs: The kwargs to pass to the pydantic init_subclass method.
+
+        Raises:
+            ValueError: If a substate class shadows another.
         """
+        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
         super().__init_subclass__(**kwargs)
         # Event handlers should not shadow builtin state methods.
         cls._check_overridden_methods()
 
+        # Reset subclass tracking for this class.
+        cls.class_subclasses = set()
+
         # Get the parent vars.
         parent_state = cls.get_parent_state()
         if parent_state is not None:
             cls.inherited_vars = parent_state.vars
             cls.inherited_backend_vars = parent_state.backend_vars
 
+            # Check if another substate class with the same name has already been defined.
+            if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
+                if is_testing_env:
+                    # Clear existing subclass with same name when app is reloaded via
+                    # utils.prerequisites.get_app(reload=True)
+                    parent_state.class_subclasses = set(
+                        c
+                        for c in parent_state.class_subclasses
+                        if c.__name__ != cls.__name__
+                    )
+                else:
+                    # During normal operation, subclasses cannot have the same name, even if they are
+                    # defined in different modules.
+                    raise ValueError(
+                        f"The substate class '{cls.__name__}' has been defined multiple times. "
+                        "Shadowing substate classes is not allowed."
+                    )
+            # Track this new subclass in the parent state's subclasses set.
+            parent_state.class_subclasses.add(cls)
+
         cls.new_backend_vars = {
             name: value
             for name, value in cls.__dict__.items()
@@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
     @classmethod
     @functools.lru_cache()
-    def get_parent_state(cls) -> Type[State] | None:
+    def get_parent_state(cls) -> Type[BaseState] | None:
         """Get the parent state.
 
         Returns:
@@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         parent_states = [
             base
             for base in cls.__bases__
-            if types._issubclass(base, State) and base is not State
+            if types._issubclass(base, BaseState) and base is not BaseState
         ]
         assert len(parent_states) < 2, "Only one parent state is allowed."
         return parent_states[0] if len(parent_states) == 1 else None  # type: ignore
 
     @classmethod
-    @functools.lru_cache()
-    def get_substates(cls) -> set[Type[State]]:
+    def get_substates(cls) -> set[Type[BaseState]]:
         """Get the substates of the state.
 
         Returns:
             The substates of the state.
         """
-        return set(cls.__subclasses__())
+        return cls.class_subclasses
 
     @classmethod
     @functools.lru_cache()
@@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
     @classmethod
     @functools.lru_cache()
-    def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
+    def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
         """Get the class substate.
 
         Args:
@@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         """
         return {
             func[0]: func[1]
-            for func in inspect.getmembers(State, predicate=inspect.isfunction)
+            for func in inspect.getmembers(BaseState, predicate=inspect.isfunction)
             if not func[0].startswith("__")
         }
 
@@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         for substate in self.substates.values():
             substate._reset_client_storage()
 
-    def get_substate(self, path: Sequence[str]) -> State | None:
+    def get_substate(self, path: Sequence[str]) -> BaseState | None:
         """Get the substate.
 
         Args:
@@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
     def _get_event_handler(
         self, event: Event
-    ) -> tuple[State | StateProxy, EventHandler]:
+    ) -> tuple[BaseState | StateProxy, EventHandler]:
         """Get the event handler for the given event.
 
         Args:
@@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         )
 
     async def _process_event(
-        self, handler: EventHandler, state: State | StateProxy, payload: Dict
+        self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
     ) -> AsyncIterator[StateUpdate]:
         """Process event.
 
@@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             d.update(substate_d)
         return d
 
-    async def __aenter__(self) -> State:
+    async def __aenter__(self) -> BaseState:
         """Enter the async context manager protocol.
 
         This should not be used for the State class, but exists for
@@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         pass
 
 
+class State(BaseState):
+    """The app Base State."""
+
+    # The hydrated bool.
+    is_hydrated: bool = False
+
+
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
 
@@ -1455,10 +1481,10 @@ class StateManager(Base, ABC):
     """A class to manage many client states."""
 
     # The state class to use.
-    state: Type[State]
+    state: Type[BaseState]
 
     @classmethod
-    def create(cls, state: Type[State]):
+    def create(cls, state: Type[BaseState]):
         """Create a new state manager.
 
         Args:
@@ -1473,7 +1499,7 @@ class StateManager(Base, ABC):
         return StateManagerMemory(state=state)
 
     @abstractmethod
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
 
         Args:
@@ -1485,7 +1511,7 @@ class StateManager(Base, ABC):
         pass
 
     @abstractmethod
-    async def set_state(self, token: str, state: State):
+    async def set_state(self, token: str, state: BaseState):
         """Set the state for a token.
 
         Args:
@@ -1496,7 +1522,7 @@ class StateManager(Base, ABC):
 
     @abstractmethod
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
 
         Args:
@@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager):
     """A state manager that stores states in memory."""
 
     # The mapping of client ids to states.
-    states: Dict[str, State] = {}
+    states: Dict[str, BaseState] = {}
 
     # The mutex ensures the dict of mutexes is updated exclusively
     _state_manager_lock = asyncio.Lock()
@@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager):
             "_states_locks": {"exclude": True},
         }
 
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
 
         Args:
@@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager):
             self.states[token] = self.state()
         return self.states[token]
 
-    async def set_state(self, token: str, state: State):
+    async def set_state(self, token: str, state: BaseState):
         """Set the state for a token.
 
         Args:
@@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager):
         pass
 
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
 
         Args:
@@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager):
         b"evicted",
     }
 
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
 
         Args:
@@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager):
             return await self.get_state(token)
         return cloudpickle.loads(redis_state)
 
-    async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
+    async def set_state(
+        self, token: str, state: BaseState, lock_id: bytes | None = None
+    ):
         """Set the state for a token.
 
         Args:
@@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager):
         await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
 
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
 
         Args:
@@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy):
 
     __mutable_types__ = (list, dict, set, Base)
 
-    def __init__(self, wrapped: Any, state: State, field_name: str):
+    def __init__(self, wrapped: Any, state: BaseState, field_name: str):
         """Create a proxy for a mutable object that tracks changes.
 
         Args:

+ 6 - 3
reflex/testing.py

@@ -38,7 +38,7 @@ import reflex.utils.build
 import reflex.utils.exec
 import reflex.utils.prerequisites
 import reflex.utils.processes
-from reflex.state import State, StateManagerMemory, StateManagerRedis
+from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
 
 try:
     from selenium import webdriver  # pyright: ignore [reportMissingImports]
@@ -162,6 +162,9 @@ class AppHarness:
         with chdir(self.app_path):
             # ensure config and app are reloaded when testing different app
             reflex.config.get_config(reload=True)
+            # reset rx.State subclasses
+            State.class_subclasses.clear()
+            # self.app_module.app.
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
         self.app_instance = self.app_module.app
         if isinstance(self.app_instance.state_manager, StateManagerRedis):
@@ -434,7 +437,7 @@ class AppHarness:
         self._frontends.append(driver)
         return driver
 
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state associated with the given token.
 
         Args:
@@ -561,7 +564,7 @@ class AppHarness:
             )
         return element.get_attribute("value")
 
-    def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]:
+    def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
         """Poll app state_manager for any connected clients.
 
         Args:

+ 2 - 0
reflex/utils/prerequisites.py

@@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType:
     Returns:
         The app based on the default config.
     """
+    os.environ[constants.RELOAD_CONFIG] = str(reload)
     config = get_config()
     module = ".".join([config.app_name, config.app_name])
     sys.path.insert(0, os.getcwd())
     app = __import__(module, fromlist=(constants.CompileVars.APP,))
     if reload:
+
         importlib.reload(app)
     return app
 

+ 7 - 7
reflex/vars.py

@@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types
 from reflex.utils.imports import ImportDict, ImportVar
 
 if TYPE_CHECKING:
-    from reflex.state import State
+    from reflex.state import BaseState
 
 # Set of unique variable names.
 USED_VARIABLES = set()
@@ -1472,7 +1472,7 @@ class Var:
             )
         )
 
-    def _var_set_state(self, state: Type[State] | str) -> Any:
+    def _var_set_state(self, state: Type[BaseState] | str) -> Any:
         """Set the state of the var.
 
         Args:
@@ -1604,14 +1604,14 @@ class BaseVar(Var):
             return setter
         return ".".join((self._var_data.state, setter))
 
-    def get_setter(self) -> Callable[[State, Any], None]:
+    def get_setter(self) -> Callable[[BaseState, Any], None]:
         """Get the var's setter function.
 
         Returns:
             A function that that creates a setter for the var.
         """
 
-        def setter(state: State, value: Any):
+        def setter(state: BaseState, value: Any):
             """Get the setter for the var.
 
             Args:
@@ -1643,9 +1643,9 @@ class ComputedVar(Var, property):
 
     def __init__(
         self,
-        fget: Callable[[State], Any],
-        fset: Callable[[State, Any], None] | None = None,
-        fdel: Callable[[State], Any] | None = None,
+        fget: Callable[[BaseState], Any],
+        fset: Callable[[BaseState, Any], None] | None = None,
+        fdel: Callable[[BaseState], Any] | None = None,
         doc: str | None = None,
         **kwargs,
     ):

+ 3 - 2
reflex/vars.pyi

@@ -5,6 +5,7 @@ from _typeshed import Incomplete
 from reflex import constants as constants
 from reflex.base import Base as Base
 from reflex.state import State as State
+from reflex.state import BaseState as BaseState
 from reflex.utils import console as console, format as format, types as types
 from reflex.utils.imports import ImportVar
 from types import FunctionType
@@ -110,7 +111,7 @@ class Var:
     def as_ref(self) -> Var: ...
     @property
     def _var_full_name(self) -> str: ...
-    def _var_set_state(self, state: Type[State] | str) -> Any: ...
+    def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
 
 @dataclass(eq=False)
 class BaseVar(Var):
@@ -123,7 +124,7 @@ class BaseVar(Var):
     def __hash__(self) -> int: ...
     def get_default_value(self) -> Any: ...
     def get_setter_name(self, include_state: bool = ...) -> str: ...
-    def get_setter(self) -> Callable[[State, Any], None]: ...
+    def get_setter(self) -> Callable[[BaseState, Any], None]: ...
 
 @dataclass(init=False)
 class ComputedVar(Var):

+ 3 - 0
tests/__init__.py

@@ -1 +1,4 @@
 """Root directory for tests."""
+import os
+
+from reflex import constants

+ 2 - 2
tests/components/base/test_script.py

@@ -2,7 +2,7 @@
 import pytest
 
 from reflex.components.base.script import Script
-from reflex.state import State
+from reflex.state import BaseState
 
 
 def test_script_inline():
@@ -31,7 +31,7 @@ def test_script_neither():
         Script.create()
 
 
-class EvState(State):
+class EvState(BaseState):
     """State for testing event handlers."""
 
     def on_ready(self):

+ 5 - 4
tests/components/datadisplay/conftest.py

@@ -5,6 +5,7 @@ import pandas as pd
 import pytest
 
 import reflex as rx
+from reflex.state import BaseState
 
 
 @pytest.fixture
@@ -18,7 +19,7 @@ def data_table_state(request):
         The data table state class.
     """
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         data = request.param["data"]
         columns = ["column1", "column2"]
 
@@ -33,7 +34,7 @@ def data_table_state2():
         The data table state class.
     """
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         _data = pd.DataFrame()
 
         @rx.var
@@ -51,7 +52,7 @@ def data_table_state3():
         The data table state class.
     """
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         _data: List = []
         _columns: List = ["col1", "col2"]
 
@@ -74,7 +75,7 @@ def data_table_state4():
         The data table state class.
     """
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         _data: List = []
         _columns: List = ["col1", "col2"]
 

+ 2 - 2
tests/components/datadisplay/test_table.py

@@ -4,12 +4,12 @@ from typing import List, Tuple
 import pytest
 
 from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
-from reflex.state import State
+from reflex.state import BaseState
 
 PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
 
 
-class TableState(State):
+class TableState(BaseState):
     """Test State class."""
 
     rows_List_List_str: List[List[str]] = [["random", "row"]]

+ 2 - 1
tests/components/forms/test_debounce.py

@@ -3,6 +3,7 @@
 import pytest
 
 import reflex as rx
+from reflex.state import BaseState
 from reflex.vars import BaseVar
 
 
@@ -24,7 +25,7 @@ def test_render_many_child():
         _ = rx.debounce_input("foo", "bar").render()
 
 
-class S(rx.State):
+class S(BaseState):
     """Example state for debounce tests."""
 
     value: str = ""

+ 2 - 1
tests/components/layout/test_cond.py

@@ -15,12 +15,13 @@ from reflex.components.layout.responsive import (
     tablet_only,
 )
 from reflex.components.typography.text import Text
+from reflex.state import BaseState
 from reflex.vars import Var
 
 
 @pytest.fixture
 def cond_state(request):
-    class CondState(rx.State):
+    class CondState(BaseState):
         value: request.param["value_type"] = request.param["value"]  # noqa
 
     return CondState

+ 2 - 2
tests/components/layout/test_foreach.py

@@ -4,10 +4,10 @@ import pytest
 
 from reflex.components import box, foreach, text
 from reflex.components.layout import Foreach
-from reflex.state import State
+from reflex.state import BaseState
 
 
-class ForEachState(State):
+class ForEachState(BaseState):
     """A state for testing the ForEach component."""
 
     colors_list: List[str] = ["red", "yellow"]

+ 3 - 3
tests/components/test_component.py

@@ -14,7 +14,7 @@ from reflex.components.component import (
 from reflex.components.layout.box import Box
 from reflex.constants import EventTriggers
 from reflex.event import EventChain, EventHandler
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils.imports import ImportVar
@@ -23,7 +23,7 @@ from reflex.vars import Var, VarData
 
 @pytest.fixture
 def test_state():
-    class TestState(State):
+    class TestState(BaseState):
         num: int
 
         def do_something(self):
@@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2):
     )
 
 
-class C1State(State):
+class C1State(BaseState):
     """State for testing C1 component."""
 
     def mock_handler(self, _e, _bravo, _charlie):

+ 0 - 21
tests/conftest.py

@@ -8,7 +8,6 @@ from typing import Dict, Generator
 
 import pytest
 
-import reflex as rx
 from reflex.app import App
 from reflex.event import EventSpec
 
@@ -225,23 +224,3 @@ def token() -> str:
         A fresh/unique token string.
     """
     return str(uuid.uuid4())
-
-
-@pytest.fixture
-def duplicate_substate():
-    """Create a Test state that has duplicate child substates.
-
-    Returns:
-        The test state.
-    """
-
-    class TestState(rx.State):
-        pass
-
-    class ChildTestState(TestState):  # type: ignore # noqa
-        pass
-
-    class ChildTestState(TestState):  # type: ignore # noqa
-        pass
-
-    return TestState

+ 9 - 5
tests/middleware/test_hydrate_middleware.py

@@ -2,13 +2,14 @@ from typing import Any, Dict
 
 import pytest
 
+from reflex import constants
 from reflex.app import App
 from reflex.constants import CompileVars
 from reflex.middleware.hydrate_middleware import HydrateMiddleware
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
 
 
-def exp_is_hydrated(state: State) -> Dict[str, Any]:
+def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
     """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
 
     Args:
@@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
     return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
 
 
-class TestState(State):
+class TestState(BaseState):
     """A test state with no return in handler."""
 
     __test__ = False
@@ -32,7 +33,7 @@ class TestState(State):
         self.num += 1
 
 
-class TestState2(State):
+class TestState2(BaseState):
     """A test state with return in handler."""
 
     __test__ = False
@@ -54,7 +55,7 @@ class TestState2(State):
         self.name = "random"
 
 
-class TestState3(State):
+class TestState3(BaseState):
     """A test state with async handler."""
 
     __test__ = False
@@ -97,6 +98,9 @@ async def test_preprocess(
         event_fixture: The event fixture(an Event).
         expected: Expected delta.
     """
+    test_state.add_var(
+        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
+    )
     app = App(state=test_state, load_events={"index": [test_state.test_handler]})
     state = test_state()
 

+ 4 - 2
tests/states/__init__.py

@@ -1,10 +1,12 @@
-"""Common rx.State subclasses for use in tests."""
+"""Common rx.BaseState subclasses for use in tests."""
 import reflex as rx
+from reflex.state import BaseState
 
 from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
 from .upload import (
     ChildFileUploadState,
     FileStateBase1,
+    FileStateBase2,
     FileUploadState,
     GrandChildFileUploadState,
     SubUploadState,
@@ -12,7 +14,7 @@ from .upload import (
 )
 
 
-class GenState(rx.State):
+class GenState(BaseState):
     """A state with event handlers that generate multiple updates."""
 
     value: int

+ 4 - 3
tests/states/mutation.py

@@ -3,9 +3,10 @@
 from typing import Dict, List, Set, Union
 
 import reflex as rx
+from reflex.state import BaseState
 
 
-class DictMutationTestState(rx.State):
+class DictMutationTestState(BaseState):
     """A state for testing ReflexDict mutation."""
 
     # plain dict
@@ -62,7 +63,7 @@ class DictMutationTestState(rx.State):
         self.friend_in_nested_dict["friend"]["age"] = 30
 
 
-class ListMutationTestState(rx.State):
+class ListMutationTestState(BaseState):
     """A state for testing ReflexList mutation."""
 
     # plain list
@@ -144,7 +145,7 @@ class CustomVar(rx.Base):
     custom: OtherBase = OtherBase()
 
 
-class MutableTestState(rx.State):
+class MutableTestState(BaseState):
     """A test state."""
 
     array: List[Union[str, List, Dict[str, str]]] = [

+ 5 - 4
tests/states/upload.py

@@ -3,9 +3,10 @@ from pathlib import Path
 from typing import ClassVar, List
 
 import reflex as rx
+from reflex.state import BaseState, State
 
 
-class UploadState(rx.State):
+class UploadState(BaseState):
     """The base state for uploading a file."""
 
     async def handle_upload1(self, files: List[rx.UploadFile]):
@@ -17,7 +18,7 @@ class UploadState(rx.State):
         pass
 
 
-class BaseState(rx.State):
+class BaseState(BaseState):
     """The test base state."""
 
     pass
@@ -37,7 +38,7 @@ class SubUploadState(BaseState):
         pass
 
 
-class FileUploadState(rx.State):
+class FileUploadState(State):
     """The base state for uploading a file."""
 
     img_list: List[str]
@@ -79,7 +80,7 @@ class FileUploadState(rx.State):
         pass
 
 
-class FileStateBase1(rx.State):
+class FileStateBase1(State):
     """The base state for a child FileUploadState."""
 
     pass

+ 27 - 19
tests/test_app.py

@@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
-from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import ComputedVar
@@ -43,7 +43,7 @@ from .states import (
 )
 
 
-class EmptyState(State):
+class EmptyState(BaseState):
     """An empty state."""
 
     pass
@@ -77,14 +77,14 @@ def about_page():
     return about
 
 
-class ATestState(State):
+class ATestState(BaseState):
     """A simple state for testing."""
 
     var: int
 
 
 @pytest.fixture()
-def test_state() -> Type[State]:
+def test_state() -> Type[BaseState]:
     """A default state.
 
     Returns:
@@ -94,14 +94,14 @@ def test_state() -> Type[State]:
 
 
 @pytest.fixture()
-def redundant_test_state() -> Type[State]:
+def redundant_test_state() -> Type[BaseState]:
     """A default state.
 
     Returns:
         A default state.
     """
 
-    class RedundantTestState(State):
+    class RedundantTestState(BaseState):
         var: int
 
     return RedundantTestState
@@ -198,12 +198,12 @@ def test_default_app(app: App):
 
 
 def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
-    """Test that an error is thrown when multiple classes subclass rx.State.
+    """Test that an error is thrown when multiple classes subclass rx.BaseState.
 
     Args:
         monkeypatch: Pytest monkeypatch object.
-        test_state: A test state subclassing rx.State.
-        redundant_test_state: Another test state subclassing rx.State.
+        test_state: A test state subclassing rx.BaseState.
+        redundant_test_state: Another test state subclassing rx.BaseState.
     """
     monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
     with pytest.raises(ValueError):
@@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list(
     [
         (
             FileUploadState,
-            {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
+            {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
         ),
         (
             ChildFileUploadState,
             {
-                "file_state_base1.child_file_upload_state": {
+                "state.file_state_base1.child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
             },
@@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list(
         (
             GrandChildFileUploadState,
             {
-                "file_state_base1.file_state_base2.grand_child_file_upload_state": {
+                "state.file_state_base1.file_state_base2.grand_child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
             },
         ),
     ],
 )
-async def test_upload_file(tmp_path, state, delta, token: str):
+async def test_upload_file(tmp_path, state, delta, token: str, mocker):
     """Test that file upload works correctly.
 
     Args:
@@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str):
         state: The state class.
         delta: Expected delta
         token: a Token.
+        mocker: pytest mocker object.
     """
+    mocker.patch(
+        "reflex.state.State.class_subclasses",
+        {state if state is FileUploadState else FileStateBase1},
+    )
     state._tmp_path = tmp_path
     # The App state must be the "root" of the state tree
-    app = App(state=state if state is FileUploadState else FileStateBase1)
+    app = App(state=State)
     app.event_namespace.emit = AsyncMock()  # type: ignore
     current_state = await app.state_manager.get_state(token)
     data = b"This is binary data"
@@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
     request_mock = unittest.mock.Mock()
     request_mock.headers = {
         "reflex-client-token": token,
-        "reflex-event-handler": f"{state_name}.multi_handle_upload",
+        "reflex-event-handler": f"state.{state_name}.multi_handle_upload",
     }
 
     file1 = UploadFile(
@@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token):
         await app.state_manager.redis.close()
 
 
-class DynamicState(State):
+class DynamicState(BaseState):
     """State class for testing dynamic route var.
 
     This is defined at module level because event handlers cannot be addressed
@@ -891,9 +896,7 @@ class DynamicState(State):
 
 @pytest.mark.asyncio
 async def test_dynamic_route_var_route_change_completed_on_load(
-    index_page,
-    windows_platform: bool,
-    token: str,
+    index_page, windows_platform: bool, token: str, mocker
 ):
     """Create app with dynamic route var, and simulate navigation.
 
@@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         index_page: The index page.
         windows_platform: Whether the system is windows.
         token: a Token.
+        mocker: pytest mocker object.
     """
+    mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
+    DynamicState.add_var(
+        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
+    )
     arg_name = "dynamic"
     route = f"/test/[{arg_name}]"
     if windows_platform:

+ 2 - 2
tests/test_event.py

@@ -4,7 +4,7 @@ import pytest
 
 from reflex import event
 from reflex.event import Event, EventHandler, EventSpec, fix_events
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.utils import format
 from reflex.vars import Var
 
@@ -303,7 +303,7 @@ def test_event_actions():
 
 
 def test_event_actions_on_state():
-    class EventActionState(State):
+    class EventActionState(BaseState):
         def handler(self):
             pass
 

+ 35 - 29
tests/test_state.py

@@ -19,11 +19,11 @@ from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
 from reflex.state import (
+    BaseState,
     ImmutableStateError,
     LockExpiredError,
     MutableProxy,
     RouterData,
-    State,
     StateManager,
     StateManagerMemory,
     StateManagerRedis,
@@ -75,7 +75,7 @@ class Object(Base):
     prop2: str = "hello"
 
 
-class TestState(State):
+class TestState(BaseState):
     """A test state."""
 
     # Set this class as not test one
@@ -148,7 +148,7 @@ class GrandchildState(ChildState):
         pass
 
 
-class DateTimeState(State):
+class DateTimeState(BaseState):
     """A State with some datetime fields."""
 
     d: datetime.date = datetime.date.fromisoformat("1989-11-09")
@@ -253,7 +253,6 @@ def test_class_vars(test_state):
     """
     cls = type(test_state)
     assert set(cls.vars.keys()) == {
-        CompileVars.IS_HYDRATED,  # added by hydrate_middleware to all State
         "router",
         "num1",
         "num2",
@@ -641,7 +640,6 @@ def test_reset(test_state, child_state):
         "obj",
         "upper",
         "complex",
-        "is_hydrated",
         "fig",
         "key",
         "sum",
@@ -837,7 +835,7 @@ def test_get_query_params(test_state):
 
 
 def test_add_var():
-    class DynamicState(State):
+    class DynamicState(BaseState):
         pass
 
     ds1 = DynamicState()
@@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state):
     assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
 
 
-class InterdependentState(State):
+class InterdependentState(BaseState):
     """A state with 3 vars and 3 computed vars.
 
     x: a variable that no computed var depends on
@@ -915,7 +913,7 @@ class InterdependentState(State):
 
 
 @pytest.fixture
-def interdependent_state() -> State:
+def interdependent_state() -> BaseState:
     """A state with varying dependency between vars.
 
     Returns:
@@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state):
 def test_child_state():
     """Test that the child state computed vars can reference parent state vars."""
 
-    class MainState(State):
+    class MainState(BaseState):
         v: int = 2
 
     class ChildState(MainState):
@@ -1006,7 +1004,7 @@ def test_child_state():
 def test_conditional_computed_vars():
     """Test that computed vars can have conditionals."""
 
-    class MainState(State):
+    class MainState(BaseState):
         flag: bool = False
         t1: str = "a"
         t2: str = "b"
@@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state):
 def test_event_handlers_call_other_handlers():
     """Test that event handlers can call other event handlers."""
 
-    class MainState(State):
+    class MainState(BaseState):
         v: int = 0
 
         def set_v(self, v: int):
@@ -1077,7 +1075,7 @@ def test_computed_var_cached():
     """Test that a ComputedVar doesn't recalculate when accessed."""
     comp_v_calls = 0
 
-    class ComputedState(State):
+    class ComputedState(BaseState):
         v: int = 0
 
         @rx.cached_var
@@ -1102,7 +1100,7 @@ def test_computed_var_cached():
 def test_computed_var_cached_depends_on_non_cached():
     """Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
 
-    class ComputedState(State):
+    class ComputedState(BaseState):
         v: int = 0
 
         @rx.var
@@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached():
     """Child state cached_var that depends on parent state un cached var is always recalculated."""
     counter = 0
 
-    class ParentState(State):
+    class ParentState(BaseState):
         @rx.var
         def no_cache_v(self) -> int:
             nonlocal counter
@@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached():
     dict1 = ps.dict()
     assert dict1[ps.get_full_name()] == {
         "no_cache_v": 1,
-        CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
     }
     assert dict1[cs.get_full_name()] == {"dep_v": 2}
     dict2 = ps.dict()
     assert dict2[ps.get_full_name()] == {
         "no_cache_v": 3,
-        CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
     }
     assert dict2[cs.get_full_name()] == {"dep_v": 4}
     dict3 = ps.dict()
     assert dict3[ps.get_full_name()] == {
         "no_cache_v": 5,
-        CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
     }
     assert dict3[cs.get_full_name()] == {"dep_v": 6}
@@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
     """
     counter = 0
 
-    class HandlerState(State):
+    class HandlerState(BaseState):
         x: int = 42
 
         def handler(self):
@@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
 def test_computed_var_dependencies():
     """Test that a ComputedVar correctly tracks its dependencies."""
 
-    class ComputedState(State):
+    class ComputedState(BaseState):
         v: int = 0
         w: int = 0
         x: int = 0
@@ -1293,7 +1288,7 @@ def test_computed_var_dependencies():
 def test_backend_method():
     """A method with leading underscore should be callable from event handler."""
 
-    class BackendMethodState(State):
+    class BackendMethodState(BaseState):
         def _be_method(self):
             return True
 
@@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow():
     """Test that an error is thrown when an event handler shadows a state method."""
     with pytest.raises(NameError) as err:
 
-        class InvalidTest(rx.State):
+        class InvalidTest(BaseState):
             def reset(self):
                 pass
 
@@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow():
 def test_state_with_invalid_yield():
     """Test that an error is thrown when a state yields an invalid value."""
 
-    class StateWithInvalidYield(rx.State):
+    class StateWithInvalidYield(BaseState):
         """A state that yields an invalid value."""
 
         def invalid_handler(self):
@@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert mcall.kwargs["to"] == grandchild_state.get_sid()
 
 
-class BackgroundTaskState(State):
+class BackgroundTaskState(BaseState):
     """A state with a background task."""
 
     order: List[str] = []
@@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func):
         assert not isinstance(var_copy, MutableProxy)
 
 
-def test_duplicate_substate_class(duplicate_substate):
+def test_duplicate_substate_class(mocker):
+    mocker.patch("reflex.state.os.environ", {})
     with pytest.raises(ValueError):
-        duplicate_substate()
+
+        class TestState(BaseState):
+            pass
+
+        class ChildTestState(TestState):  # type: ignore # noqa
+            pass
+
+        class ChildTestState(TestState):  # type: ignore # noqa
+            pass
+
+        return TestState
 
 
 class Foo(Base):
@@ -2206,7 +2212,7 @@ class Foo(Base):
 def test_json_dumps_with_mutables():
     """Test that json.dumps works with Base vars inside mutable types."""
 
-    class MutableContainsBase(State):
+    class MutableContainsBase(BaseState):
         items: List[Foo] = [Foo()]
 
     dict_val = MutableContainsBase().dict()
@@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables():
     f_formatted_router = str(formatted_router).replace("'", '"')
     assert (
         val
-        == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
+        == f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}'
     )
 
 
@@ -2225,7 +2231,7 @@ def test_reset_with_mutables():
     default = [[0, 0], [0, 1], [1, 1]]
     copied_default = copy.deepcopy(default)
 
-    class MutableResetState(State):
+    class MutableResetState(BaseState):
         items: List[List[int]] = default
 
     instance = MutableResetState()
@@ -2273,7 +2279,7 @@ class Custom3(Base):
 def test_state_union_optional():
     """Test that state can be defined with Union and Optional vars."""
 
-    class UnionState(State):
+    class UnionState(BaseState):
         int_float: Union[int, float] = 0
         opt_int: Optional[int]
         c3: Optional[Custom3]

+ 5 - 11
tests/test_var.py

@@ -6,7 +6,7 @@ import pytest
 from pandas import DataFrame
 
 from reflex.base import Base
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.vars import (
     BaseVar,
     ComputedVar,
@@ -24,12 +24,6 @@ test_vars = [
 ]
 
 
-class BaseState(State):
-    """A Test State."""
-
-    val: str = "key"
-
-
 @pytest.fixture
 def TestObj():
     class TestObj(Base):
@@ -41,7 +35,7 @@ def TestObj():
 
 @pytest.fixture
 def ParentState(TestObj):
-    class ParentState(State):
+    class ParentState(BaseState):
         foo: int
         bar: int
 
@@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj):
 
 @pytest.fixture
 def StateWithAnyVar(TestObj):
-    class StateWithAnyVar(State):
+    class StateWithAnyVar(BaseState):
         @ComputedVar
         def var_without_annotation(self) -> typing.Any:
             return TestObj
@@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj):
 
 @pytest.fixture
 def StateWithCorrectVarAnnotation():
-    class StateWithCorrectVarAnnotation(State):
+    class StateWithCorrectVarAnnotation(BaseState):
         @ComputedVar
         def var_with_annotation(self) -> str:
             return "Correct annotation"
@@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation():
 
 @pytest.fixture
 def StateWithWrongVarAnnotation(TestObj):
-    class StateWithWrongVarAnnotation(State):
+    class StateWithWrongVarAnnotation(BaseState):
         @ComputedVar
         def var_with_annotation(self) -> str:
             return TestObj

+ 0 - 2
tests/utils/test_format.py

@@ -528,7 +528,6 @@ formatted_router = {
                     },
                     "dt": "1989-11-09 18:53:00+01:00",
                     "fig": [],
-                    "is_hydrated": False,
                     "key": "",
                     "map_key": "a",
                     "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
@@ -553,7 +552,6 @@ formatted_router = {
                 DateTimeState.get_full_name(): {
                     "d": "1989-11-09",
                     "dt": "1989-11-09 18:53:00+01:00",
-                    "is_hydrated": False,
                     "t": "18:53:00+01:00",
                     "td": "11 days, 0:11:00",
                     "router": formatted_router,

+ 2 - 2
tests/utils/test_utils.py

@@ -10,7 +10,7 @@ from packaging import version
 from reflex import constants
 from reflex.base import Base
 from reflex.event import EventHandler
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.utils import (
     build,
     prerequisites,
@@ -43,7 +43,7 @@ V056 = version.parse("0.5.6")
 VMAXPLUS1 = version.parse(get_above_max_version())
 
 
-class ExampleTestState(State):
+class ExampleTestState(BaseState):
     """Test state class."""
 
     def test_event_handler(self):