Bladeren bron

RED-1052/rx.State as Base State (#2146)

Elijah Ahianyo 1 jaar geleden
bovenliggende
commit
e3ee98098a
49 gewijzigde bestanden met toevoegingen van 356 en 270 verwijderingen
  1. 1 1
      integration/test_background_task.py
  2. 1 1
      integration/test_call_script.py
  3. 30 24
      integration/test_client_storage.py
  4. 1 1
      integration/test_connection_banner.py
  5. 5 3
      integration/test_dynamic_routes.py
  6. 7 3
      integration/test_event_actions.py
  7. 13 13
      integration/test_event_chain.py
  8. 3 3
      integration/test_form_submit.py
  9. 12 8
      integration/test_input.py
  10. 2 2
      integration/test_login_flow.py
  11. 1 1
      integration/test_radix_themes.py
  12. 1 1
      integration/test_server_side_event.py
  13. 1 1
      integration/test_table.py
  14. 7 7
      integration/test_upload.py
  15. 1 1
      integration/test_var_operations.py
  16. 15 13
      reflex/app.py
  17. 2 1
      reflex/app.pyi
  18. 32 1
      reflex/base.py
  19. 5 5
      reflex/compiler/compiler.py
  20. 4 4
      reflex/compiler/utils.py
  21. 2 0
      reflex/constants/__init__.py
  22. 1 0
      reflex/constants/base.py
  23. 2 2
      reflex/event.py
  24. 2 2
      reflex/middleware/hydrate_middleware.py
  25. 3 3
      reflex/middleware/middleware.py
  26. 70 42
      reflex/state.py
  27. 6 3
      reflex/testing.py
  28. 2 0
      reflex/utils/prerequisites.py
  29. 7 7
      reflex/vars.py
  30. 3 2
      reflex/vars.pyi
  31. 3 0
      tests/__init__.py
  32. 2 2
      tests/components/base/test_script.py
  33. 5 4
      tests/components/datadisplay/conftest.py
  34. 2 2
      tests/components/datadisplay/test_table.py
  35. 2 1
      tests/components/forms/test_debounce.py
  36. 2 1
      tests/components/layout/test_cond.py
  37. 2 2
      tests/components/layout/test_foreach.py
  38. 3 3
      tests/components/test_component.py
  39. 0 21
      tests/conftest.py
  40. 9 5
      tests/middleware/test_hydrate_middleware.py
  41. 4 2
      tests/states/__init__.py
  42. 4 3
      tests/states/mutation.py
  43. 5 4
      tests/states/upload.py
  44. 27 19
      tests/test_app.py
  45. 2 2
      tests/test_event.py
  46. 35 29
      tests/test_state.py
  47. 5 11
      tests/test_var.py
  48. 0 2
      tests/utils/test_format.py
  49. 2 2
      tests/utils/test_utils.py

+ 1 - 1
integration/test_background_task.py

@@ -93,7 +93,7 @@ def BackgroundTask():
             rx.button("Reset", on_click=State.reset_counter, id="reset"),
             rx.button("Reset", on_click=State.reset_counter, id="reset"),
         )
         )
 
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.compile()
     app.compile()
 
 

+ 1 - 1
integration/test_call_script.py

@@ -135,7 +135,7 @@ def CallScript():
             yield rx.call_script("inline_counter = 0; external_counter = 0")
             yield rx.call_script("inline_counter = 0; external_counter = 0")
             self.reset()
             self.reset()
 
 
-    app = rx.App(state=CallScriptState)
+    app = rx.App(state=rx.State)
     with open("assets/external.js", "w") as f:
     with open("assets/external.js", "w") as f:
         f.write(external_scripts)
         f.write(external_scripts)
 
 

+ 30 - 24
integration/test_client_storage.py

@@ -97,7 +97,7 @@ def ClientSide():
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
             rx.box(ClientSideSubSubState.l1s, id="l1s"),
         )
         )
 
 
-    app = rx.App(state=ClientSideState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.add_page(index, route="/foo")
     app.add_page(index, route="/foo")
     app.compile()
     app.compile()
@@ -263,7 +263,6 @@ async def test_client_side_state(
     state_var_input.send_keys("c7")
     state_var_input.send_keys("c7")
     input_value_input.send_keys("c7 value")
     input_value_input.send_keys("c7 value")
     set_sub_state_button.click()
     set_sub_state_button.click()
-
     state_var_input.send_keys("l1")
     state_var_input.send_keys("l1")
     input_value_input.send_keys("l1 value")
     input_value_input.send_keys("l1 value")
     set_sub_state_button.click()
     set_sub_state_button.click()
@@ -276,7 +275,6 @@ async def test_client_side_state(
     state_var_input.send_keys("l4")
     state_var_input.send_keys("l4")
     input_value_input.send_keys("l4 value")
     input_value_input.send_keys("l4 value")
     set_sub_state_button.click()
     set_sub_state_button.click()
-
     state_var_input.send_keys("c1s")
     state_var_input.send_keys("c1s")
     input_value_input.send_keys("c1s value")
     input_value_input.send_keys("c1s value")
     set_sub_sub_state_button.click()
     set_sub_sub_state_button.click()
@@ -285,28 +283,28 @@ async def test_client_side_state(
     set_sub_sub_state_button.click()
     set_sub_sub_state_button.click()
 
 
     exp_cookies = {
     exp_cookies = {
-        "client_side_state.client_side_sub_state.c1": {
+        "state.client_side_state.client_side_sub_state.c1": {
             "domain": "localhost",
             "domain": "localhost",
             "httpOnly": False,
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c1",
+            "name": "state.client_side_state.client_side_sub_state.c1",
             "path": "/",
             "path": "/",
             "sameSite": "Lax",
             "sameSite": "Lax",
             "secure": False,
             "secure": False,
             "value": "c1%20value",
             "value": "c1%20value",
         },
         },
-        "client_side_state.client_side_sub_state.c2": {
+        "state.client_side_state.client_side_sub_state.c2": {
             "domain": "localhost",
             "domain": "localhost",
             "httpOnly": False,
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c2",
+            "name": "state.client_side_state.client_side_sub_state.c2",
             "path": "/",
             "path": "/",
             "sameSite": "Lax",
             "sameSite": "Lax",
             "secure": False,
             "secure": False,
             "value": "c2%20value",
             "value": "c2%20value",
         },
         },
-        "client_side_state.client_side_sub_state.c4": {
+        "state.client_side_state.client_side_sub_state.c4": {
             "domain": "localhost",
             "domain": "localhost",
             "httpOnly": False,
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c4",
+            "name": "state.client_side_state.client_side_sub_state.c4",
             "path": "/",
             "path": "/",
             "sameSite": "Strict",
             "sameSite": "Strict",
             "secure": False,
             "secure": False,
@@ -321,19 +319,19 @@ async def test_client_side_state(
             "secure": False,
             "secure": False,
             "value": "c6%20value",
             "value": "c6%20value",
         },
         },
-        "client_side_state.client_side_sub_state.c7": {
+        "state.client_side_state.client_side_sub_state.c7": {
             "domain": "localhost",
             "domain": "localhost",
             "httpOnly": False,
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.c7",
+            "name": "state.client_side_state.client_side_sub_state.c7",
             "path": "/",
             "path": "/",
             "sameSite": "Lax",
             "sameSite": "Lax",
             "secure": False,
             "secure": False,
             "value": "c7%20value",
             "value": "c7%20value",
         },
         },
-        "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
+        "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
             "domain": "localhost",
             "domain": "localhost",
             "httpOnly": False,
             "httpOnly": False,
-            "name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
+            "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
             "path": "/",
             "path": "/",
             "sameSite": "Lax",
             "sameSite": "Lax",
             "secure": False,
             "secure": False,
@@ -354,40 +352,45 @@ async def test_client_side_state(
     input_value_input.send_keys("c3 value")
     input_value_input.send_keys("c3 value")
     set_sub_state_button.click()
     set_sub_state_button.click()
     AppHarness._poll_for(
     AppHarness._poll_for(
-        lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
+        lambda: "state.client_side_state.client_side_sub_state.c3"
+        in cookie_info_map(driver)
     )
     )
-    c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
+    c3_cookie = cookie_info_map(driver)[
+        "state.client_side_state.client_side_sub_state.c3"
+    ]
     assert c3_cookie.pop("expiry") is not None
     assert c3_cookie.pop("expiry") is not None
     assert c3_cookie == {
     assert c3_cookie == {
         "domain": "localhost",
         "domain": "localhost",
         "httpOnly": False,
         "httpOnly": False,
-        "name": "client_side_state.client_side_sub_state.c3",
+        "name": "state.client_side_state.client_side_sub_state.c3",
         "path": "/",
         "path": "/",
         "sameSite": "Lax",
         "sameSite": "Lax",
         "secure": False,
         "secure": False,
         "value": "c3%20value",
         "value": "c3%20value",
     }
     }
     time.sleep(2)  # wait for c3 to expire
     time.sleep(2)  # wait for c3 to expire
-    assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
+    assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(
+        driver
+    )
 
 
     local_storage_items = local_storage.items()
     local_storage_items = local_storage.items()
     local_storage_items.pop("chakra-ui-color-mode", None)
     local_storage_items.pop("chakra-ui-color-mode", None)
     assert (
     assert (
-        local_storage_items.pop("client_side_state.client_side_sub_state.l1")
+        local_storage_items.pop("state.client_side_state.client_side_sub_state.l1")
         == "l1 value"
         == "l1 value"
     )
     )
     assert (
     assert (
-        local_storage_items.pop("client_side_state.client_side_sub_state.l2")
+        local_storage_items.pop("state.client_side_state.client_side_sub_state.l2")
         == "l2 value"
         == "l2 value"
     )
     )
     assert local_storage_items.pop("l3") == "l3 value"
     assert local_storage_items.pop("l3") == "l3 value"
     assert (
     assert (
-        local_storage_items.pop("client_side_state.client_side_sub_state.l4")
+        local_storage_items.pop("state.client_side_state.client_side_sub_state.l4")
         == "l4 value"
         == "l4 value"
     )
     )
     assert (
     assert (
         local_storage_items.pop(
         local_storage_items.pop(
-            "client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
+            "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
         )
         )
         == "l1s value"
         == "l1s value"
     )
     )
@@ -482,12 +485,15 @@ async def test_client_side_state(
 
 
     # make sure c5 cookie shows up on the `/foo` route
     # make sure c5 cookie shows up on the `/foo` route
     AppHarness._poll_for(
     AppHarness._poll_for(
-        lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver)
+        lambda: "state.client_side_state.client_side_sub_state.c5"
+        in cookie_info_map(driver)
     )
     )
-    assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
+    assert cookie_info_map(driver)[
+        "state.client_side_state.client_side_sub_state.c5"
+    ] == {
         "domain": "localhost",
         "domain": "localhost",
         "httpOnly": False,
         "httpOnly": False,
-        "name": "client_side_state.client_side_sub_state.c5",
+        "name": "state.client_side_state.client_side_sub_state.c5",
         "path": "/foo/",
         "path": "/foo/",
         "sameSite": "Lax",
         "sameSite": "Lax",
         "secure": False,
         "secure": False,

+ 1 - 1
integration/test_connection_banner.py

@@ -19,7 +19,7 @@ def ConnectionBanner():
     def index():
     def index():
         return rx.text("Hello World")
         return rx.text("Hello World")
 
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.compile()
     app.compile()
 
 

+ 5 - 3
integration/test_dynamic_routes.py

@@ -56,7 +56,7 @@ def DynamicRoute():
     def redirect_page():
     def redirect_page():
         return rx.fragment(rx.text("redirecting..."))
         return rx.fragment(rx.text("redirecting..."))
 
 
-    app = rx.App(state=DynamicState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load)  # type: ignore
     app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load)  # type: ignore
     app.add_page(index, route="/static/x", on_load=DynamicState.on_load)  # type: ignore
     app.add_page(index, route="/static/x", on_load=DynamicState.on_load)  # type: ignore
@@ -143,10 +143,12 @@ def poll_for_order(
             return await dynamic_route.get_state(token)
             return await dynamic_route.get_state(token)
 
 
         async def _check():
         async def _check():
-            return (await _backend_state()).order == exp_order
+            return (await _backend_state()).substates[
+                "dynamic_state"
+            ].order == exp_order
 
 
         await AppHarness._poll_for_async(_check)
         await AppHarness._poll_for_async(_check)
-        assert (await _backend_state()).order == exp_order
+        assert (await _backend_state()).substates["dynamic_state"].order == exp_order
 
 
     return _poll_for_order
     return _poll_for_order
 
 

+ 7 - 3
integration/test_event_actions.py

@@ -130,7 +130,7 @@ def TestEventAction():
             on_click=EventActionState.on_click("outer"),  # type: ignore
             on_click=EventActionState.on_click("outer"),  # type: ignore
         )
         )
 
 
-    app = rx.App(state=EventActionState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.compile()
     app.compile()
 
 
@@ -211,10 +211,14 @@ def poll_for_order(
             return await event_action.get_state(token)
             return await event_action.get_state(token)
 
 
         async def _check():
         async def _check():
-            return (await _backend_state()).order == exp_order
+            return (await _backend_state()).substates[
+                "event_action_state"
+            ].order == exp_order
 
 
         await AppHarness._poll_for_async(_check)
         await AppHarness._poll_for_async(_check)
-        assert (await _backend_state()).order == exp_order
+        assert (await _backend_state()).substates[
+            "event_action_state"
+        ].order == exp_order
 
 
     return _poll_for_order
     return _poll_for_order
 
 

+ 13 - 13
integration/test_event_chain.py

@@ -122,7 +122,7 @@ def EventChain():
             time.sleep(0.5)
             time.sleep(0.5)
             self.interim_value = "final"
             self.interim_value = "final"
 
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
 
 
     token_input = rx.input(
     token_input = rx.input(
         value=State.router.session.client_token, is_read_only=True, id="token"
         value=State.router.session.client_token, is_read_only=True, id="token"
@@ -401,12 +401,12 @@ async def test_event_chain_click(
     btn.click()
     btn.click()
 
 
     async def _has_all_events():
     async def _has_all_events():
-        return len((await event_chain.get_state(token)).event_order) == len(
-            exp_event_order
-        )
+        return len(
+            (await event_chain.get_state(token)).substates["state"].event_order
+        ) == len(exp_event_order)
 
 
     await AppHarness._poll_for_async(_has_all_events)
     await AppHarness._poll_for_async(_has_all_events)
-    event_order = (await event_chain.get_state(token)).event_order
+    event_order = (await event_chain.get_state(token)).substates["state"].event_order
     assert event_order == exp_event_order
     assert event_order == exp_event_order
 
 
 
 
@@ -453,12 +453,12 @@ async def test_event_chain_on_load(
     token = assert_token(event_chain, driver)
     token = assert_token(event_chain, driver)
 
 
     async def _has_all_events():
     async def _has_all_events():
-        return len((await event_chain.get_state(token)).event_order) == len(
-            exp_event_order
-        )
+        return len(
+            (await event_chain.get_state(token)).substates["state"].event_order
+        ) == len(exp_event_order)
 
 
     await AppHarness._poll_for_async(_has_all_events)
     await AppHarness._poll_for_async(_has_all_events)
-    backend_state = await event_chain.get_state(token)
+    backend_state = (await event_chain.get_state(token)).substates["state"]
     assert backend_state.event_order == exp_event_order
     assert backend_state.event_order == exp_event_order
     assert backend_state.is_hydrated is True
     assert backend_state.is_hydrated is True
 
 
@@ -529,12 +529,12 @@ async def test_event_chain_on_mount(
     unmount_button.click()
     unmount_button.click()
 
 
     async def _has_all_events():
     async def _has_all_events():
-        return len((await event_chain.get_state(token)).event_order) == len(
-            exp_event_order
-        )
+        return len(
+            (await event_chain.get_state(token)).substates["state"].event_order
+        ) == len(exp_event_order)
 
 
     await AppHarness._poll_for_async(_has_all_events)
     await AppHarness._poll_for_async(_has_all_events)
-    event_order = (await event_chain.get_state(token)).event_order
+    event_order = (await event_chain.get_state(token)).substates["state"].event_order
     assert event_order == exp_event_order
     assert event_order == exp_event_order
 
 
 
 

+ 3 - 3
integration/test_form_submit.py

@@ -22,7 +22,7 @@ def FormSubmit():
         def form_submit(self, form_data: dict):
         def form_submit(self, form_data: dict):
             self.form_data = form_data
             self.form_data = form_data
 
 
-    app = rx.App(state=FormState)
+    app = rx.App(state=rx.State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
@@ -75,7 +75,7 @@ def FormSubmitName():
         def form_submit(self, form_data: dict):
         def form_submit(self, form_data: dict):
             self.form_data = form_data
             self.form_data = form_data
 
 
-    app = rx.App(state=FormState)
+    app = rx.App(state=rx.State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
@@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness):
     submit_input.click()
     submit_input.click()
 
 
     async def get_form_data():
     async def get_form_data():
-        return (await form_submit.get_state(token)).form_data
+        return (await form_submit.get_state(token)).substates["form_state"].form_data
 
 
     # wait for the form data to arrive at the backend
     # wait for the form data to arrive at the backend
     form_data = await AppHarness._poll_for_async(get_form_data)
     form_data = await AppHarness._poll_for_async(get_form_data)

+ 12 - 8
integration/test_input.py

@@ -16,7 +16,7 @@ def FullyControlledInput():
     class State(rx.State):
     class State(rx.State):
         text: str = "initial"
         text: str = "initial"
 
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():
@@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("foo")
     debounce_input.send_keys("foo")
     time.sleep(0.5)
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "ifoonitial"
     assert debounce_input.get_attribute("value") == "ifoonitial"
-    assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
+    assert (await fully_controlled_input.get_state(token)).substates[
+        "state"
+    ].text == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
     assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
 
 
     # clear the input on the backend
     # clear the input on the backend
     async with fully_controlled_input.modify_state(token) as state:
     async with fully_controlled_input.modify_state(token) as state:
-        state.text = ""
-    assert (await fully_controlled_input.get_state(token)).text == ""
+        state.substates["state"].text = ""
+    assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
     assert (
     assert (
         fully_controlled_input.poll_for_value(
         fully_controlled_input.poll_for_value(
             debounce_input, exp_not_equal="ifoonitial"
             debounce_input, exp_not_equal="ifoonitial"
@@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     debounce_input.send_keys("getting testing done")
     debounce_input.send_keys("getting testing done")
     time.sleep(0.5)
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "getting testing done"
     assert debounce_input.get_attribute("value") == "getting testing done"
-    assert (
-        await fully_controlled_input.get_state(token)
-    ).text == "getting testing done"
+    assert (await fully_controlled_input.get_state(token)).substates[
+        "state"
+    ].text == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
     assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
 
 
     # type into the on_change input
     # type into the on_change input
@@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
     time.sleep(0.5)
     time.sleep(0.5)
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert debounce_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
     assert on_change_input.get_attribute("value") == "overwrite the state"
-    assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
+    assert (await fully_controlled_input.get_state(token)).substates[
+        "state"
+    ].text == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
     assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
 
 
     clear_button.click()
     clear_button.click()

+ 2 - 2
integration/test_login_flow.py

@@ -42,7 +42,7 @@ def LoginSample():
             rx.button("Do it", on_click=State.login, id="doit"),
             rx.button("Do it", on_click=State.login, id="doit"),
         )
         )
 
 
-    app = rx.App(state=State)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.add_page(login)
     app.add_page(login)
     app.compile()
     app.compile()
@@ -137,6 +137,6 @@ def test_login_flow(
     logout_button = driver.find_element(By.ID, "logout")
     logout_button = driver.find_element(By.ID, "logout")
     logout_button.click()
     logout_button.click()
 
 
-    assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "")
+    assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "")
     with pytest.raises(NoSuchElementException):
     with pytest.raises(NoSuchElementException):
         driver.find_element(By.ID, "auth-token")
         driver.find_element(By.ID, "auth-token")

+ 1 - 1
integration/test_radix_themes.py

@@ -81,7 +81,7 @@ def RadixThemesApp():
         )
         )
 
 
     app = rx.App(
     app = rx.App(
-        state=State,
+        state=rx.State,
         theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
         theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
     )
     )
     app.add_page(index)
     app.add_page(index)

+ 1 - 1
integration/test_server_side_event.py

@@ -33,7 +33,7 @@ def ServerSideEvent():
         def set_value_return_c(self):
         def set_value_return_c(self):
             return rx.set_value("c", "")
             return rx.set_value("c", "")
 
 
-    app = rx.App(state=SSState)
+    app = rx.App(state=rx.State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():

+ 1 - 1
integration/test_table.py

@@ -26,7 +26,7 @@ def Table():
 
 
         caption: str = "random caption"
         caption: str = "random caption"
 
 
-    app = rx.App(state=TableState)
+    app = rx.App(state=rx.State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():

+ 7 - 7
integration/test_upload.py

@@ -113,7 +113,7 @@ def UploadFile():
             ),
             ),
         )
         )
 
 
-    app = rx.App(state=UploadState)
+    app = rx.App(state=rx.State)
     app.add_page(index)
     app.add_page(index)
     app.compile()
     app.compile()
 
 
@@ -192,7 +192,7 @@ async def test_upload_file(
 
 
     # look up the backend state and assert on uploaded contents
     # look up the backend state and assert on uploaded contents
     async def get_file_data():
     async def get_file_data():
-        return (await upload_file.get_state(token))._file_data
+        return (await upload_file.get_state(token)).substates["upload_state"]._file_data
 
 
     file_data = await AppHarness._poll_for_async(get_file_data)
     file_data = await AppHarness._poll_for_async(get_file_data)
     assert isinstance(file_data, dict)
     assert isinstance(file_data, dict)
@@ -205,8 +205,8 @@ async def test_upload_file(
     state = await upload_file.get_state(token)
     state = await upload_file.get_state(token)
     if secondary:
     if secondary:
         # only the secondary form tracks progress and chain events
         # only the secondary form tracks progress and chain events
-        assert state.event_order.count("upload_progress") == 1
-        assert state.event_order.count("chain_event") == 1
+        assert state.substates["upload_state"].event_order.count("upload_progress") == 1
+        assert state.substates["upload_state"].event_order.count("chain_event") == 1
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
 
 
     # look up the backend state and assert on uploaded contents
     # look up the backend state and assert on uploaded contents
     async def get_file_data():
     async def get_file_data():
-        return (await upload_file.get_state(token))._file_data
+        return (await upload_file.get_state(token)).substates["upload_state"]._file_data
 
 
     file_data = await AppHarness._poll_for_async(get_file_data)
     file_data = await AppHarness._poll_for_async(get_file_data)
     assert isinstance(file_data, dict)
     assert isinstance(file_data, dict)
@@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
 
 
     # look up the backend state and assert on progress
     # look up the backend state and assert on progress
     state = await upload_file.get_state(token)
     state = await upload_file.get_state(token)
-    assert state.progress_dicts
-    assert exp_name not in state._file_data
+    assert state.substates["upload_state"].progress_dicts
+    assert exp_name not in state.substates["upload_state"]._file_data
 
 
     target_file.unlink()
     target_file.unlink()

+ 1 - 1
integration/test_var_operations.py

@@ -30,7 +30,7 @@ def VarOperations():
         dict2: dict = {3: 4}
         dict2: dict = {3: 4}
         html_str: str = "<div>hello</div>"
         html_str: str = "<div>hello</div>"
 
 
-    app = rx.App(state=VarOperationState)
+    app = rx.App(state=rx.State)
 
 
     @app.add_page
     @app.add_page
     def index():
     def index():

+ 15 - 13
reflex/app.py

@@ -57,6 +57,7 @@ from reflex.route import (
     verify_route_validity,
     verify_route_validity,
 )
 )
 from reflex.state import (
 from reflex.state import (
+    BaseState,
     RouterData,
     RouterData,
     State,
     State,
     StateManager,
     StateManager,
@@ -98,7 +99,7 @@ class App(Base):
     socket_app: Optional[ASGIApp] = None
     socket_app: Optional[ASGIApp] = None
 
 
     # The state class to use for the app.
     # The state class to use for the app.
-    state: Optional[Type[State]] = None
+    state: Optional[Type[BaseState]] = None
 
 
     # Class to manage many client states.
     # Class to manage many client states.
     _state_manager: Optional[StateManager] = None
     _state_manager: Optional[StateManager] = None
@@ -149,25 +150,24 @@ class App(Base):
                 "`connect_error_component` is deprecated, use `overlay_component` instead"
                 "`connect_error_component` is deprecated, use `overlay_component` instead"
             )
             )
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
-        state_subclasses = State.__subclasses__()
-        inferred_state = state_subclasses[-1] if state_subclasses else None
+        state_subclasses = BaseState.__subclasses__()
         is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
         is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
 
 
-        # Special case to allow test cases have multiple subclasses of rx.State.
+        # Special case to allow test cases have multiple subclasses of rx.BaseState.
         if not is_testing_env:
         if not is_testing_env:
-            # Only one State class is allowed.
+            # Only one Base State class is allowed.
             if len(state_subclasses) > 1:
             if len(state_subclasses) > 1:
                 raise ValueError(
                 raise ValueError(
-                    "rx.State has been subclassed multiple times. Only one subclass is allowed"
+                    "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
                 )
                 )
 
 
             # verify that provided state is valid
             # verify that provided state is valid
-            if self.state and inferred_state and self.state is not inferred_state:
+            if self.state and self.state is not State:
                 console.warn(
                 console.warn(
                     f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
                     f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
-                    f" Defaulting to root state: ({inferred_state.__name__})"
+                    f" Defaulting to root state: ({State.__name__})"
                 )
                 )
-            self.state = inferred_state
+            self.state = State
         # Get the config
         # Get the config
         config = get_config()
         config = get_config()
 
 
@@ -265,7 +265,7 @@ class App(Base):
             raise ValueError("The state manager has not been initialized.")
             raise ValueError("The state manager has not been initialized.")
         return self._state_manager
         return self._state_manager
 
 
-    async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
+    async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
         """Preprocess the event.
         """Preprocess the event.
 
 
         This is where middleware can modify the event before it is processed.
         This is where middleware can modify the event before it is processed.
@@ -290,7 +290,7 @@ class App(Base):
                 return out  # type: ignore
                 return out  # type: ignore
 
 
     async def postprocess(
     async def postprocess(
-        self, state: State, event: Event, update: StateUpdate
+        self, state: BaseState, event: Event, update: StateUpdate
     ) -> StateUpdate:
     ) -> StateUpdate:
         """Postprocess the event.
         """Postprocess the event.
 
 
@@ -764,7 +764,7 @@ class App(Base):
                 future.result()
                 future.result()
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state out of band.
         """Modify the state out of band.
 
 
         Args:
         Args:
@@ -792,7 +792,9 @@ class App(Base):
                     sid=state.router.session.session_id,
                     sid=state.router.session.session_id,
                 )
                 )
 
 
-    def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
+    def _process_background(
+        self, state: BaseState, event: Event
+    ) -> asyncio.Task | None:
         """Process an event in the background and emit updates as they arrive.
         """Process an event in the background and emit updates as they arrive.
 
 
         Args:
         Args:

+ 2 - 1
reflex/app.pyi

@@ -33,6 +33,7 @@ from reflex.route import (
 )
 )
 from reflex.state import (
 from reflex.state import (
     State as State,
     State as State,
+    BaseState as BaseState,
     StateManager as StateManager,
     StateManager as StateManager,
     StateUpdate as StateUpdate,
     StateUpdate as StateUpdate,
 )
 )
@@ -69,7 +70,7 @@ class App(Base):
     api: FastAPI
     api: FastAPI
     sio: Optional[AsyncServer]
     sio: Optional[AsyncServer]
     socket_app: Optional[ASGIApp]
     socket_app: Optional[ASGIApp]
-    state: Type[State]
+    state: Type[BaseState]
     state_manager: StateManager
     state_manager: StateManager
     style: ComponentStyle
     style: ComponentStyle
     middleware: List[Middleware]
     middleware: List[Middleware]

+ 32 - 1
reflex/base.py

@@ -1,11 +1,42 @@
 """Define the base Reflex class."""
 """Define the base Reflex class."""
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any
+import os
+from typing import Any, List, Type
 
 
 import pydantic
 import pydantic
+from pydantic import BaseModel
 from pydantic.fields import ModelField
 from pydantic.fields import ModelField
 
 
+from reflex import constants
+
+
+def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
+    """Ensure that the field's name does not shadow an existing attribute of the model.
+
+    Args:
+        bases: List of base models to check for shadowed attrs.
+        field_name: name of attribute
+
+    Raises:
+        NameError: If state var field shadows another in its parent state
+    """
+    reload = os.getenv(constants.RELOAD_CONFIG) == "True"
+    for base in bases:
+        try:
+            if not reload and getattr(base, field_name, None):
+                pass
+        except TypeError as te:
+            raise NameError(
+                f'State var "{field_name}" in {base} has been shadowed by a substate var; '
+                f'use a different field name instead".'
+            ) from te
+
+
+# monkeypatch pydantic validate_field_name method to skip validating
+# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
+pydantic.main.validate_field_name = validate_field_name  # type: ignore
+
 
 
 class Base(pydantic.BaseModel):
 class Base(pydantic.BaseModel):
     """The base class subclassed by all Reflex classes.
     """The base class subclassed by all Reflex classes.

+ 5 - 5
reflex/compiler/compiler.py

@@ -15,7 +15,7 @@ from reflex.components.component import (
     StatefulComponent,
     StatefulComponent,
 )
 )
 from reflex.config import get_config
 from reflex.config import get_config
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
 
 
 
 
@@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str:
     return templates.THEME.render(theme=theme)
     return templates.THEME.render(theme=theme)
 
 
 
 
-def _compile_contexts(state: Optional[Type[State]]) -> str:
+def _compile_contexts(state: Optional[Type[BaseState]]) -> str:
     """Compile the initial state and contexts.
     """Compile the initial state and contexts.
 
 
     Args:
     Args:
@@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str:
 
 
 def _compile_page(
 def _compile_page(
     component: Component,
     component: Component,
-    state: Type[State],
+    state: Type[BaseState],
 ) -> str:
 ) -> str:
     """Compile the component given the app state.
     """Compile the component given the app state.
 
 
@@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
     return output_path, code
     return output_path, code
 
 
 
 
-def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
+def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]:
     """Compile the initial state / context.
     """Compile the initial state / context.
 
 
     Args:
     Args:
@@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
 
 
 
 
 def compile_page(
 def compile_page(
-    path: str, component: Component, state: Type[State]
+    path: str, component: Component, state: Type[BaseState]
 ) -> tuple[str, str]:
 ) -> tuple[str, str]:
     """Compile a single page.
     """Compile a single page.
 
 

+ 4 - 4
reflex/compiler/utils.py

@@ -21,7 +21,7 @@ from reflex.components.base import (
     Title,
     Title,
 )
 )
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.components.component import Component, ComponentStyle, CustomComponent
-from reflex.state import Cookie, LocalStorage, State
+from reflex.state import BaseState, Cookie, LocalStorage
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils import console, format, imports, path_ops
 
 
@@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
     }
     }
 
 
 
 
-def compile_state(state: Type[State]) -> dict:
+def compile_state(state: Type[BaseState]) -> dict:
     """Compile the state of the app.
     """Compile the state of the app.
 
 
     Args:
     Args:
@@ -170,7 +170,7 @@ def _compile_client_storage_field(
 
 
 
 
 def _compile_client_storage_recursive(
 def _compile_client_storage_recursive(
-    state: Type[State],
+    state: Type[BaseState],
 ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
 ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
     """Compile the client-side storage for the given state recursively.
     """Compile the client-side storage for the given state recursively.
 
 
@@ -208,7 +208,7 @@ def _compile_client_storage_recursive(
     return cookies, local_storage
     return cookies, local_storage
 
 
 
 
-def compile_client_storage(state: Type[State]) -> dict[str, dict]:
+def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
     """Compile the client-side storage for the given state.
     """Compile the client-side storage for the given state.
 
 
     Args:
     Args:

+ 2 - 0
reflex/constants/__init__.py

@@ -6,6 +6,7 @@ from .base import (
     LOCAL_STORAGE,
     LOCAL_STORAGE,
     POLLING_MAX_HTTP_BUFFER_SIZE,
     POLLING_MAX_HTTP_BUFFER_SIZE,
     PYTEST_CURRENT_TEST,
     PYTEST_CURRENT_TEST,
+    RELOAD_CONFIG,
     SKIP_COMPILE_ENV_VAR,
     SKIP_COMPILE_ENV_VAR,
     ColorMode,
     ColorMode,
     Dirs,
     Dirs,
@@ -85,6 +86,7 @@ __ALL__ = [
     PYTEST_CURRENT_TEST,
     PYTEST_CURRENT_TEST,
     PRODUCTION_BACKEND_URL,
     PRODUCTION_BACKEND_URL,
     Reflex,
     Reflex,
+    RELOAD_CONFIG,
     RequirementsTxt,
     RequirementsTxt,
     RouteArgType,
     RouteArgType,
     RouteRegex,
     RouteRegex,

+ 1 - 0
reflex/constants/base.py

@@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
 # Testing variables.
 # Testing variables.
 # Testing os env set by pytest when running a test case.
 # Testing os env set by pytest when running a test case.
 PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
 PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
+RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"

+ 2 - 2
reflex/event.py

@@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
-    from reflex.state import State
+    from reflex.state import BaseState
 
 
 
 
 class Event(Base):
 class Event(Base):
@@ -64,7 +64,7 @@ def background(fn):
 
 
 
 
 def _no_chain_background_task(
 def _no_chain_background_task(
-    state_cls: Type["State"], name: str, fn: Callable
+    state_cls: Type["BaseState"], name: str, fn: Callable
 ) -> Callable:
 ) -> Callable:
     """Protect against directly chaining a background task from another event handler.
     """Protect against directly chaining a background task from another event handler.
 
 

+ 2 - 2
reflex/middleware/hydrate_middleware.py

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
 from reflex import constants
 from reflex import constants
 from reflex.event import Event, fix_events, get_hydrate_event
 from reflex.event import Event, fix_events, get_hydrate_event
 from reflex.middleware.middleware import Middleware
 from reflex.middleware.middleware import Middleware
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
 from reflex.utils import format
 from reflex.utils import format
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware):
     """Middleware to handle initial app hydration."""
     """Middleware to handle initial app hydration."""
 
 
     async def preprocess(
     async def preprocess(
-        self, app: App, state: State, event: Event
+        self, app: App, state: BaseState, event: Event
     ) -> Optional[StateUpdate]:
     ) -> Optional[StateUpdate]:
         """Preprocess the event.
         """Preprocess the event.
 
 

+ 3 - 3
reflex/middleware/middleware.py

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.event import Event
 from reflex.event import Event
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.app import App
     from reflex.app import App
@@ -16,7 +16,7 @@ class Middleware(Base, ABC):
     """Middleware to preprocess and postprocess requests."""
     """Middleware to preprocess and postprocess requests."""
 
 
     async def preprocess(
     async def preprocess(
-        self, app: App, state: State, event: Event
+        self, app: App, state: BaseState, event: Event
     ) -> Optional[StateUpdate]:
     ) -> Optional[StateUpdate]:
         """Preprocess the event.
         """Preprocess the event.
 
 
@@ -31,7 +31,7 @@ class Middleware(Base, ABC):
         return None
         return None
 
 
     async def postprocess(
     async def postprocess(
-        self, app: App, state: State, event: Event, update: StateUpdate
+        self, app: App, state: BaseState, event: Event, update: StateUpdate
     ) -> StateUpdate:
     ) -> StateUpdate:
         """Postprocess the event.
         """Postprocess the event.
 
 

+ 70 - 42
reflex/state.py

@@ -7,6 +7,7 @@ import copy
 import functools
 import functools
 import inspect
 import inspect
 import json
 import json
+import os
 import traceback
 import traceback
 import urllib.parse
 import urllib.parse
 import uuid
 import uuid
@@ -81,7 +82,7 @@ class HeaderData(Base):
 class PageData(Base):
 class PageData(Base):
     """An object containing page data."""
     """An object containing page data."""
 
 
-    host: str = ""  #  repeated with self.headers.origin (remove or keep the duplicate?)
+    host: str = ""  # repeated with self.headers.origin (remove or keep the duplicate?)
     path: str = ""
     path: str = ""
     raw_path: str = ""
     raw_path: str = ""
     full_path: str = ""
     full_path: str = ""
@@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = {
 }
 }
 
 
 
 
-class State(Base, ABC, extra=pydantic.Extra.allow):
+class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
     """The state of the app."""
 
 
     # A map from the var name to the var.
     # A map from the var name to the var.
@@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The event handlers.
     # The event handlers.
     event_handlers: ClassVar[Dict[str, EventHandler]] = {}
     event_handlers: ClassVar[Dict[str, EventHandler]] = {}
 
 
+    # A set of subclassses of this class.
+    class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
+
     # Mapping of var name to set of computed variables that depend on it
     # Mapping of var name to set of computed variables that depend on it
     _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
     _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
 
 
@@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     _always_dirty_substates: ClassVar[Set[str]] = set()
     _always_dirty_substates: ClassVar[Set[str]] = set()
 
 
     # The parent state.
     # The parent state.
-    parent_state: Optional[State] = None
+    parent_state: Optional[BaseState] = None
 
 
     # The substates of the state.
     # The substates of the state.
-    substates: Dict[str, State] = {}
+    substates: Dict[str, BaseState] = {}
 
 
     # The set of dirty vars.
     # The set of dirty vars.
     dirty_vars: Set[str] = set()
     dirty_vars: Set[str] = set()
@@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # The router data for the current page
     # The router data for the current page
     router: RouterData = RouterData()
     router: RouterData = RouterData()
 
 
-    # The hydrated bool.
-    is_hydrated: bool = False
-
-    def __init__(self, *args, parent_state: State | None = None, **kwargs):
+    def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
         """Initialize the state.
         """Initialize the state.
 
 
         Args:
         Args:
@@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             parent_state: The parent state.
             parent_state: The parent state.
             **kwargs: The kwargs to pass to the Pydantic init method.
             **kwargs: The kwargs to pass to the Pydantic init method.
 
 
-        Raises:
-            ValueError: If a substate class shadows another.
         """
         """
         kwargs["parent_state"] = parent_state
         kwargs["parent_state"] = parent_state
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
         # Setup the substates.
         # Setup the substates.
         for substate in self.get_substates():
         for substate in self.get_substates():
-            substate_name = substate.get_name()
-            if substate_name in self.substates:
-                raise ValueError(
-                    f"The substate class '{substate_name}' has been defined multiple times. Shadowing "
-                    f"substate classes is not allowed."
-                )
-            self.substates[substate_name] = substate(parent_state=self)
+            self.substates[substate.get_name()] = substate(parent_state=self)
         # Convert the event handlers to functions.
         # Convert the event handlers to functions.
         self._init_event_handlers()
         self._init_event_handlers()
 
 
         # Create a fresh copy of the backend variables for this instance
         # Create a fresh copy of the backend variables for this instance
         self._backend_vars = copy.deepcopy(self.backend_vars)
         self._backend_vars = copy.deepcopy(self.backend_vars)
 
 
-    def _init_event_handlers(self, state: State | None = None):
+    def _init_event_handlers(self, state: BaseState | None = None):
         """Initialize event handlers.
         """Initialize event handlers.
 
 
         Allow event handlers to be called directly on the instance. This is
         Allow event handlers to be called directly on the instance. This is
@@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
 
         Args:
         Args:
             **kwargs: The kwargs to pass to the pydantic init_subclass method.
             **kwargs: The kwargs to pass to the pydantic init_subclass method.
+
+        Raises:
+            ValueError: If a substate class shadows another.
         """
         """
+        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
         super().__init_subclass__(**kwargs)
         super().__init_subclass__(**kwargs)
         # Event handlers should not shadow builtin state methods.
         # Event handlers should not shadow builtin state methods.
         cls._check_overridden_methods()
         cls._check_overridden_methods()
 
 
+        # Reset subclass tracking for this class.
+        cls.class_subclasses = set()
+
         # Get the parent vars.
         # Get the parent vars.
         parent_state = cls.get_parent_state()
         parent_state = cls.get_parent_state()
         if parent_state is not None:
         if parent_state is not None:
             cls.inherited_vars = parent_state.vars
             cls.inherited_vars = parent_state.vars
             cls.inherited_backend_vars = parent_state.backend_vars
             cls.inherited_backend_vars = parent_state.backend_vars
 
 
+            # Check if another substate class with the same name has already been defined.
+            if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
+                if is_testing_env:
+                    # Clear existing subclass with same name when app is reloaded via
+                    # utils.prerequisites.get_app(reload=True)
+                    parent_state.class_subclasses = set(
+                        c
+                        for c in parent_state.class_subclasses
+                        if c.__name__ != cls.__name__
+                    )
+                else:
+                    # During normal operation, subclasses cannot have the same name, even if they are
+                    # defined in different modules.
+                    raise ValueError(
+                        f"The substate class '{cls.__name__}' has been defined multiple times. "
+                        "Shadowing substate classes is not allowed."
+                    )
+            # Track this new subclass in the parent state's subclasses set.
+            parent_state.class_subclasses.add(cls)
+
         cls.new_backend_vars = {
         cls.new_backend_vars = {
             name: value
             name: value
             for name, value in cls.__dict__.items()
             for name, value in cls.__dict__.items()
@@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
 
     @classmethod
     @classmethod
     @functools.lru_cache()
     @functools.lru_cache()
-    def get_parent_state(cls) -> Type[State] | None:
+    def get_parent_state(cls) -> Type[BaseState] | None:
         """Get the parent state.
         """Get the parent state.
 
 
         Returns:
         Returns:
@@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         parent_states = [
         parent_states = [
             base
             base
             for base in cls.__bases__
             for base in cls.__bases__
-            if types._issubclass(base, State) and base is not State
+            if types._issubclass(base, BaseState) and base is not BaseState
         ]
         ]
         assert len(parent_states) < 2, "Only one parent state is allowed."
         assert len(parent_states) < 2, "Only one parent state is allowed."
         return parent_states[0] if len(parent_states) == 1 else None  # type: ignore
         return parent_states[0] if len(parent_states) == 1 else None  # type: ignore
 
 
     @classmethod
     @classmethod
-    @functools.lru_cache()
-    def get_substates(cls) -> set[Type[State]]:
+    def get_substates(cls) -> set[Type[BaseState]]:
         """Get the substates of the state.
         """Get the substates of the state.
 
 
         Returns:
         Returns:
             The substates of the state.
             The substates of the state.
         """
         """
-        return set(cls.__subclasses__())
+        return cls.class_subclasses
 
 
     @classmethod
     @classmethod
     @functools.lru_cache()
     @functools.lru_cache()
@@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
 
     @classmethod
     @classmethod
     @functools.lru_cache()
     @functools.lru_cache()
-    def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
+    def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
         """Get the class substate.
         """Get the class substate.
 
 
         Args:
         Args:
@@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         """
         """
         return {
         return {
             func[0]: func[1]
             func[0]: func[1]
-            for func in inspect.getmembers(State, predicate=inspect.isfunction)
+            for func in inspect.getmembers(BaseState, predicate=inspect.isfunction)
             if not func[0].startswith("__")
             if not func[0].startswith("__")
         }
         }
 
 
@@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         for substate in self.substates.values():
         for substate in self.substates.values():
             substate._reset_client_storage()
             substate._reset_client_storage()
 
 
-    def get_substate(self, path: Sequence[str]) -> State | None:
+    def get_substate(self, path: Sequence[str]) -> BaseState | None:
         """Get the substate.
         """Get the substate.
 
 
         Args:
         Args:
@@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
 
 
     def _get_event_handler(
     def _get_event_handler(
         self, event: Event
         self, event: Event
-    ) -> tuple[State | StateProxy, EventHandler]:
+    ) -> tuple[BaseState | StateProxy, EventHandler]:
         """Get the event handler for the given event.
         """Get the event handler for the given event.
 
 
         Args:
         Args:
@@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         )
         )
 
 
     async def _process_event(
     async def _process_event(
-        self, handler: EventHandler, state: State | StateProxy, payload: Dict
+        self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
     ) -> AsyncIterator[StateUpdate]:
     ) -> AsyncIterator[StateUpdate]:
         """Process event.
         """Process event.
 
 
@@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             d.update(substate_d)
             d.update(substate_d)
         return d
         return d
 
 
-    async def __aenter__(self) -> State:
+    async def __aenter__(self) -> BaseState:
         """Enter the async context manager protocol.
         """Enter the async context manager protocol.
 
 
         This should not be used for the State class, but exists for
         This should not be used for the State class, but exists for
@@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         pass
         pass
 
 
 
 
+class State(BaseState):
+    """The app Base State."""
+
+    # The hydrated bool.
+    is_hydrated: bool = False
+
+
 class StateProxy(wrapt.ObjectProxy):
 class StateProxy(wrapt.ObjectProxy):
     """Proxy of a state instance to control mutability of vars for a background task.
     """Proxy of a state instance to control mutability of vars for a background task.
 
 
@@ -1455,10 +1481,10 @@ class StateManager(Base, ABC):
     """A class to manage many client states."""
     """A class to manage many client states."""
 
 
     # The state class to use.
     # The state class to use.
-    state: Type[State]
+    state: Type[BaseState]
 
 
     @classmethod
     @classmethod
-    def create(cls, state: Type[State]):
+    def create(cls, state: Type[BaseState]):
         """Create a new state manager.
         """Create a new state manager.
 
 
         Args:
         Args:
@@ -1473,7 +1499,7 @@ class StateManager(Base, ABC):
         return StateManagerMemory(state=state)
         return StateManagerMemory(state=state)
 
 
     @abstractmethod
     @abstractmethod
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
         """Get the state for a token.
 
 
         Args:
         Args:
@@ -1485,7 +1511,7 @@ class StateManager(Base, ABC):
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    async def set_state(self, token: str, state: State):
+    async def set_state(self, token: str, state: BaseState):
         """Set the state for a token.
         """Set the state for a token.
 
 
         Args:
         Args:
@@ -1496,7 +1522,7 @@ class StateManager(Base, ABC):
 
 
     @abstractmethod
     @abstractmethod
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
         """Modify the state for a token while holding exclusive lock.
 
 
         Args:
         Args:
@@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager):
     """A state manager that stores states in memory."""
     """A state manager that stores states in memory."""
 
 
     # The mapping of client ids to states.
     # The mapping of client ids to states.
-    states: Dict[str, State] = {}
+    states: Dict[str, BaseState] = {}
 
 
     # The mutex ensures the dict of mutexes is updated exclusively
     # The mutex ensures the dict of mutexes is updated exclusively
     _state_manager_lock = asyncio.Lock()
     _state_manager_lock = asyncio.Lock()
@@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager):
             "_states_locks": {"exclude": True},
             "_states_locks": {"exclude": True},
         }
         }
 
 
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
         """Get the state for a token.
 
 
         Args:
         Args:
@@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager):
             self.states[token] = self.state()
             self.states[token] = self.state()
         return self.states[token]
         return self.states[token]
 
 
-    async def set_state(self, token: str, state: State):
+    async def set_state(self, token: str, state: BaseState):
         """Set the state for a token.
         """Set the state for a token.
 
 
         Args:
         Args:
@@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager):
         pass
         pass
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
         """Modify the state for a token while holding exclusive lock.
 
 
         Args:
         Args:
@@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager):
         b"evicted",
         b"evicted",
     }
     }
 
 
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state for a token.
         """Get the state for a token.
 
 
         Args:
         Args:
@@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager):
             return await self.get_state(token)
             return await self.get_state(token)
         return cloudpickle.loads(redis_state)
         return cloudpickle.loads(redis_state)
 
 
-    async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
+    async def set_state(
+        self, token: str, state: BaseState, lock_id: bytes | None = None
+    ):
         """Set the state for a token.
         """Set the state for a token.
 
 
         Args:
         Args:
@@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager):
         await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
         await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[State]:
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state for a token while holding exclusive lock.
         """Modify the state for a token while holding exclusive lock.
 
 
         Args:
         Args:
@@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy):
 
 
     __mutable_types__ = (list, dict, set, Base)
     __mutable_types__ = (list, dict, set, Base)
 
 
-    def __init__(self, wrapped: Any, state: State, field_name: str):
+    def __init__(self, wrapped: Any, state: BaseState, field_name: str):
         """Create a proxy for a mutable object that tracks changes.
         """Create a proxy for a mutable object that tracks changes.
 
 
         Args:
         Args:

+ 6 - 3
reflex/testing.py

@@ -38,7 +38,7 @@ import reflex.utils.build
 import reflex.utils.exec
 import reflex.utils.exec
 import reflex.utils.prerequisites
 import reflex.utils.prerequisites
 import reflex.utils.processes
 import reflex.utils.processes
-from reflex.state import State, StateManagerMemory, StateManagerRedis
+from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
 
 
 try:
 try:
     from selenium import webdriver  # pyright: ignore [reportMissingImports]
     from selenium import webdriver  # pyright: ignore [reportMissingImports]
@@ -162,6 +162,9 @@ class AppHarness:
         with chdir(self.app_path):
         with chdir(self.app_path):
             # ensure config and app are reloaded when testing different app
             # ensure config and app are reloaded when testing different app
             reflex.config.get_config(reload=True)
             reflex.config.get_config(reload=True)
+            # reset rx.State subclasses
+            State.class_subclasses.clear()
+            # self.app_module.app.
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
         self.app_instance = self.app_module.app
         self.app_instance = self.app_module.app
         if isinstance(self.app_instance.state_manager, StateManagerRedis):
         if isinstance(self.app_instance.state_manager, StateManagerRedis):
@@ -434,7 +437,7 @@ class AppHarness:
         self._frontends.append(driver)
         self._frontends.append(driver)
         return driver
         return driver
 
 
-    async def get_state(self, token: str) -> State:
+    async def get_state(self, token: str) -> BaseState:
         """Get the state associated with the given token.
         """Get the state associated with the given token.
 
 
         Args:
         Args:
@@ -561,7 +564,7 @@ class AppHarness:
             )
             )
         return element.get_attribute("value")
         return element.get_attribute("value")
 
 
-    def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]:
+    def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
         """Poll app state_manager for any connected clients.
         """Poll app state_manager for any connected clients.
 
 
         Args:
         Args:

+ 2 - 0
reflex/utils/prerequisites.py

@@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType:
     Returns:
     Returns:
         The app based on the default config.
         The app based on the default config.
     """
     """
+    os.environ[constants.RELOAD_CONFIG] = str(reload)
     config = get_config()
     config = get_config()
     module = ".".join([config.app_name, config.app_name])
     module = ".".join([config.app_name, config.app_name])
     sys.path.insert(0, os.getcwd())
     sys.path.insert(0, os.getcwd())
     app = __import__(module, fromlist=(constants.CompileVars.APP,))
     app = __import__(module, fromlist=(constants.CompileVars.APP,))
     if reload:
     if reload:
+
         importlib.reload(app)
         importlib.reload(app)
     return app
     return app
 
 

+ 7 - 7
reflex/vars.py

@@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.imports import ImportDict, ImportVar
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
-    from reflex.state import State
+    from reflex.state import BaseState
 
 
 # Set of unique variable names.
 # Set of unique variable names.
 USED_VARIABLES = set()
 USED_VARIABLES = set()
@@ -1472,7 +1472,7 @@ class Var:
             )
             )
         )
         )
 
 
-    def _var_set_state(self, state: Type[State] | str) -> Any:
+    def _var_set_state(self, state: Type[BaseState] | str) -> Any:
         """Set the state of the var.
         """Set the state of the var.
 
 
         Args:
         Args:
@@ -1604,14 +1604,14 @@ class BaseVar(Var):
             return setter
             return setter
         return ".".join((self._var_data.state, setter))
         return ".".join((self._var_data.state, setter))
 
 
-    def get_setter(self) -> Callable[[State, Any], None]:
+    def get_setter(self) -> Callable[[BaseState, Any], None]:
         """Get the var's setter function.
         """Get the var's setter function.
 
 
         Returns:
         Returns:
             A function that that creates a setter for the var.
             A function that that creates a setter for the var.
         """
         """
 
 
-        def setter(state: State, value: Any):
+        def setter(state: BaseState, value: Any):
             """Get the setter for the var.
             """Get the setter for the var.
 
 
             Args:
             Args:
@@ -1643,9 +1643,9 @@ class ComputedVar(Var, property):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        fget: Callable[[State], Any],
-        fset: Callable[[State, Any], None] | None = None,
-        fdel: Callable[[State], Any] | None = None,
+        fget: Callable[[BaseState], Any],
+        fset: Callable[[BaseState, Any], None] | None = None,
+        fdel: Callable[[BaseState], Any] | None = None,
         doc: str | None = None,
         doc: str | None = None,
         **kwargs,
         **kwargs,
     ):
     ):

+ 3 - 2
reflex/vars.pyi

@@ -5,6 +5,7 @@ from _typeshed import Incomplete
 from reflex import constants as constants
 from reflex import constants as constants
 from reflex.base import Base as Base
 from reflex.base import Base as Base
 from reflex.state import State as State
 from reflex.state import State as State
+from reflex.state import BaseState as BaseState
 from reflex.utils import console as console, format as format, types as types
 from reflex.utils import console as console, format as format, types as types
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
 from types import FunctionType
 from types import FunctionType
@@ -110,7 +111,7 @@ class Var:
     def as_ref(self) -> Var: ...
     def as_ref(self) -> Var: ...
     @property
     @property
     def _var_full_name(self) -> str: ...
     def _var_full_name(self) -> str: ...
-    def _var_set_state(self, state: Type[State] | str) -> Any: ...
+    def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
 
 
 @dataclass(eq=False)
 @dataclass(eq=False)
 class BaseVar(Var):
 class BaseVar(Var):
@@ -123,7 +124,7 @@ class BaseVar(Var):
     def __hash__(self) -> int: ...
     def __hash__(self) -> int: ...
     def get_default_value(self) -> Any: ...
     def get_default_value(self) -> Any: ...
     def get_setter_name(self, include_state: bool = ...) -> str: ...
     def get_setter_name(self, include_state: bool = ...) -> str: ...
-    def get_setter(self) -> Callable[[State, Any], None]: ...
+    def get_setter(self) -> Callable[[BaseState, Any], None]: ...
 
 
 @dataclass(init=False)
 @dataclass(init=False)
 class ComputedVar(Var):
 class ComputedVar(Var):

+ 3 - 0
tests/__init__.py

@@ -1 +1,4 @@
 """Root directory for tests."""
 """Root directory for tests."""
+import os
+
+from reflex import constants

+ 2 - 2
tests/components/base/test_script.py

@@ -2,7 +2,7 @@
 import pytest
 import pytest
 
 
 from reflex.components.base.script import Script
 from reflex.components.base.script import Script
-from reflex.state import State
+from reflex.state import BaseState
 
 
 
 
 def test_script_inline():
 def test_script_inline():
@@ -31,7 +31,7 @@ def test_script_neither():
         Script.create()
         Script.create()
 
 
 
 
-class EvState(State):
+class EvState(BaseState):
     """State for testing event handlers."""
     """State for testing event handlers."""
 
 
     def on_ready(self):
     def on_ready(self):

+ 5 - 4
tests/components/datadisplay/conftest.py

@@ -5,6 +5,7 @@ import pandas as pd
 import pytest
 import pytest
 
 
 import reflex as rx
 import reflex as rx
+from reflex.state import BaseState
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -18,7 +19,7 @@ def data_table_state(request):
         The data table state class.
         The data table state class.
     """
     """
 
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         data = request.param["data"]
         data = request.param["data"]
         columns = ["column1", "column2"]
         columns = ["column1", "column2"]
 
 
@@ -33,7 +34,7 @@ def data_table_state2():
         The data table state class.
         The data table state class.
     """
     """
 
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         _data = pd.DataFrame()
         _data = pd.DataFrame()
 
 
         @rx.var
         @rx.var
@@ -51,7 +52,7 @@ def data_table_state3():
         The data table state class.
         The data table state class.
     """
     """
 
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         _data: List = []
         _data: List = []
         _columns: List = ["col1", "col2"]
         _columns: List = ["col1", "col2"]
 
 
@@ -74,7 +75,7 @@ def data_table_state4():
         The data table state class.
         The data table state class.
     """
     """
 
 
-    class DataTableState(rx.State):
+    class DataTableState(BaseState):
         _data: List = []
         _data: List = []
         _columns: List = ["col1", "col2"]
         _columns: List = ["col1", "col2"]
 
 

+ 2 - 2
tests/components/datadisplay/test_table.py

@@ -4,12 +4,12 @@ from typing import List, Tuple
 import pytest
 import pytest
 
 
 from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
 from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
-from reflex.state import State
+from reflex.state import BaseState
 
 
 PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
 PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
 
 
 
 
-class TableState(State):
+class TableState(BaseState):
     """Test State class."""
     """Test State class."""
 
 
     rows_List_List_str: List[List[str]] = [["random", "row"]]
     rows_List_List_str: List[List[str]] = [["random", "row"]]

+ 2 - 1
tests/components/forms/test_debounce.py

@@ -3,6 +3,7 @@
 import pytest
 import pytest
 
 
 import reflex as rx
 import reflex as rx
+from reflex.state import BaseState
 from reflex.vars import BaseVar
 from reflex.vars import BaseVar
 
 
 
 
@@ -24,7 +25,7 @@ def test_render_many_child():
         _ = rx.debounce_input("foo", "bar").render()
         _ = rx.debounce_input("foo", "bar").render()
 
 
 
 
-class S(rx.State):
+class S(BaseState):
     """Example state for debounce tests."""
     """Example state for debounce tests."""
 
 
     value: str = ""
     value: str = ""

+ 2 - 1
tests/components/layout/test_cond.py

@@ -15,12 +15,13 @@ from reflex.components.layout.responsive import (
     tablet_only,
     tablet_only,
 )
 )
 from reflex.components.typography.text import Text
 from reflex.components.typography.text import Text
+from reflex.state import BaseState
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 
 
 @pytest.fixture
 @pytest.fixture
 def cond_state(request):
 def cond_state(request):
-    class CondState(rx.State):
+    class CondState(BaseState):
         value: request.param["value_type"] = request.param["value"]  # noqa
         value: request.param["value_type"] = request.param["value"]  # noqa
 
 
     return CondState
     return CondState

+ 2 - 2
tests/components/layout/test_foreach.py

@@ -4,10 +4,10 @@ import pytest
 
 
 from reflex.components import box, foreach, text
 from reflex.components import box, foreach, text
 from reflex.components.layout import Foreach
 from reflex.components.layout import Foreach
-from reflex.state import State
+from reflex.state import BaseState
 
 
 
 
-class ForEachState(State):
+class ForEachState(BaseState):
     """A state for testing the ForEach component."""
     """A state for testing the ForEach component."""
 
 
     colors_list: List[str] = ["red", "yellow"]
     colors_list: List[str] = ["red", "yellow"]

+ 3 - 3
tests/components/test_component.py

@@ -14,7 +14,7 @@ from reflex.components.component import (
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
 from reflex.constants import EventTriggers
 from reflex.constants import EventTriggers
 from reflex.event import EventChain, EventHandler
 from reflex.event import EventChain, EventHandler
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils import imports
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
@@ -23,7 +23,7 @@ from reflex.vars import Var, VarData
 
 
 @pytest.fixture
 @pytest.fixture
 def test_state():
 def test_state():
-    class TestState(State):
+    class TestState(BaseState):
         num: int
         num: int
 
 
         def do_something(self):
         def do_something(self):
@@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2):
     )
     )
 
 
 
 
-class C1State(State):
+class C1State(BaseState):
     """State for testing C1 component."""
     """State for testing C1 component."""
 
 
     def mock_handler(self, _e, _bravo, _charlie):
     def mock_handler(self, _e, _bravo, _charlie):

+ 0 - 21
tests/conftest.py

@@ -8,7 +8,6 @@ from typing import Dict, Generator
 
 
 import pytest
 import pytest
 
 
-import reflex as rx
 from reflex.app import App
 from reflex.app import App
 from reflex.event import EventSpec
 from reflex.event import EventSpec
 
 
@@ -225,23 +224,3 @@ def token() -> str:
         A fresh/unique token string.
         A fresh/unique token string.
     """
     """
     return str(uuid.uuid4())
     return str(uuid.uuid4())
-
-
-@pytest.fixture
-def duplicate_substate():
-    """Create a Test state that has duplicate child substates.
-
-    Returns:
-        The test state.
-    """
-
-    class TestState(rx.State):
-        pass
-
-    class ChildTestState(TestState):  # type: ignore # noqa
-        pass
-
-    class ChildTestState(TestState):  # type: ignore # noqa
-        pass
-
-    return TestState

+ 9 - 5
tests/middleware/test_hydrate_middleware.py

@@ -2,13 +2,14 @@ from typing import Any, Dict
 
 
 import pytest
 import pytest
 
 
+from reflex import constants
 from reflex.app import App
 from reflex.app import App
 from reflex.constants import CompileVars
 from reflex.constants import CompileVars
 from reflex.middleware.hydrate_middleware import HydrateMiddleware
 from reflex.middleware.hydrate_middleware import HydrateMiddleware
-from reflex.state import State, StateUpdate
+from reflex.state import BaseState, StateUpdate
 
 
 
 
-def exp_is_hydrated(state: State) -> Dict[str, Any]:
+def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
     """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
     """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
 
 
     Args:
     Args:
@@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
     return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
     return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
 
 
 
 
-class TestState(State):
+class TestState(BaseState):
     """A test state with no return in handler."""
     """A test state with no return in handler."""
 
 
     __test__ = False
     __test__ = False
@@ -32,7 +33,7 @@ class TestState(State):
         self.num += 1
         self.num += 1
 
 
 
 
-class TestState2(State):
+class TestState2(BaseState):
     """A test state with return in handler."""
     """A test state with return in handler."""
 
 
     __test__ = False
     __test__ = False
@@ -54,7 +55,7 @@ class TestState2(State):
         self.name = "random"
         self.name = "random"
 
 
 
 
-class TestState3(State):
+class TestState3(BaseState):
     """A test state with async handler."""
     """A test state with async handler."""
 
 
     __test__ = False
     __test__ = False
@@ -97,6 +98,9 @@ async def test_preprocess(
         event_fixture: The event fixture(an Event).
         event_fixture: The event fixture(an Event).
         expected: Expected delta.
         expected: Expected delta.
     """
     """
+    test_state.add_var(
+        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
+    )
     app = App(state=test_state, load_events={"index": [test_state.test_handler]})
     app = App(state=test_state, load_events={"index": [test_state.test_handler]})
     state = test_state()
     state = test_state()
 
 

+ 4 - 2
tests/states/__init__.py

@@ -1,10 +1,12 @@
-"""Common rx.State subclasses for use in tests."""
+"""Common rx.BaseState subclasses for use in tests."""
 import reflex as rx
 import reflex as rx
+from reflex.state import BaseState
 
 
 from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
 from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
 from .upload import (
 from .upload import (
     ChildFileUploadState,
     ChildFileUploadState,
     FileStateBase1,
     FileStateBase1,
+    FileStateBase2,
     FileUploadState,
     FileUploadState,
     GrandChildFileUploadState,
     GrandChildFileUploadState,
     SubUploadState,
     SubUploadState,
@@ -12,7 +14,7 @@ from .upload import (
 )
 )
 
 
 
 
-class GenState(rx.State):
+class GenState(BaseState):
     """A state with event handlers that generate multiple updates."""
     """A state with event handlers that generate multiple updates."""
 
 
     value: int
     value: int

+ 4 - 3
tests/states/mutation.py

@@ -3,9 +3,10 @@
 from typing import Dict, List, Set, Union
 from typing import Dict, List, Set, Union
 
 
 import reflex as rx
 import reflex as rx
+from reflex.state import BaseState
 
 
 
 
-class DictMutationTestState(rx.State):
+class DictMutationTestState(BaseState):
     """A state for testing ReflexDict mutation."""
     """A state for testing ReflexDict mutation."""
 
 
     # plain dict
     # plain dict
@@ -62,7 +63,7 @@ class DictMutationTestState(rx.State):
         self.friend_in_nested_dict["friend"]["age"] = 30
         self.friend_in_nested_dict["friend"]["age"] = 30
 
 
 
 
-class ListMutationTestState(rx.State):
+class ListMutationTestState(BaseState):
     """A state for testing ReflexList mutation."""
     """A state for testing ReflexList mutation."""
 
 
     # plain list
     # plain list
@@ -144,7 +145,7 @@ class CustomVar(rx.Base):
     custom: OtherBase = OtherBase()
     custom: OtherBase = OtherBase()
 
 
 
 
-class MutableTestState(rx.State):
+class MutableTestState(BaseState):
     """A test state."""
     """A test state."""
 
 
     array: List[Union[str, List, Dict[str, str]]] = [
     array: List[Union[str, List, Dict[str, str]]] = [

+ 5 - 4
tests/states/upload.py

@@ -3,9 +3,10 @@ from pathlib import Path
 from typing import ClassVar, List
 from typing import ClassVar, List
 
 
 import reflex as rx
 import reflex as rx
+from reflex.state import BaseState, State
 
 
 
 
-class UploadState(rx.State):
+class UploadState(BaseState):
     """The base state for uploading a file."""
     """The base state for uploading a file."""
 
 
     async def handle_upload1(self, files: List[rx.UploadFile]):
     async def handle_upload1(self, files: List[rx.UploadFile]):
@@ -17,7 +18,7 @@ class UploadState(rx.State):
         pass
         pass
 
 
 
 
-class BaseState(rx.State):
+class BaseState(BaseState):
     """The test base state."""
     """The test base state."""
 
 
     pass
     pass
@@ -37,7 +38,7 @@ class SubUploadState(BaseState):
         pass
         pass
 
 
 
 
-class FileUploadState(rx.State):
+class FileUploadState(State):
     """The base state for uploading a file."""
     """The base state for uploading a file."""
 
 
     img_list: List[str]
     img_list: List[str]
@@ -79,7 +80,7 @@ class FileUploadState(rx.State):
         pass
         pass
 
 
 
 
-class FileStateBase1(rx.State):
+class FileStateBase1(State):
     """The base state for a child FileUploadState."""
     """The base state for a child FileUploadState."""
 
 
     pass
     pass

+ 27 - 19
tests/test_app.py

@@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.model import Model
-from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
+from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import ComputedVar
 from reflex.vars import ComputedVar
@@ -43,7 +43,7 @@ from .states import (
 )
 )
 
 
 
 
-class EmptyState(State):
+class EmptyState(BaseState):
     """An empty state."""
     """An empty state."""
 
 
     pass
     pass
@@ -77,14 +77,14 @@ def about_page():
     return about
     return about
 
 
 
 
-class ATestState(State):
+class ATestState(BaseState):
     """A simple state for testing."""
     """A simple state for testing."""
 
 
     var: int
     var: int
 
 
 
 
 @pytest.fixture()
 @pytest.fixture()
-def test_state() -> Type[State]:
+def test_state() -> Type[BaseState]:
     """A default state.
     """A default state.
 
 
     Returns:
     Returns:
@@ -94,14 +94,14 @@ def test_state() -> Type[State]:
 
 
 
 
 @pytest.fixture()
 @pytest.fixture()
-def redundant_test_state() -> Type[State]:
+def redundant_test_state() -> Type[BaseState]:
     """A default state.
     """A default state.
 
 
     Returns:
     Returns:
         A default state.
         A default state.
     """
     """
 
 
-    class RedundantTestState(State):
+    class RedundantTestState(BaseState):
         var: int
         var: int
 
 
     return RedundantTestState
     return RedundantTestState
@@ -198,12 +198,12 @@ def test_default_app(app: App):
 
 
 
 
 def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
 def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
-    """Test that an error is thrown when multiple classes subclass rx.State.
+    """Test that an error is thrown when multiple classes subclass rx.BaseState.
 
 
     Args:
     Args:
         monkeypatch: Pytest monkeypatch object.
         monkeypatch: Pytest monkeypatch object.
-        test_state: A test state subclassing rx.State.
-        redundant_test_state: Another test state subclassing rx.State.
+        test_state: A test state subclassing rx.BaseState.
+        redundant_test_state: Another test state subclassing rx.BaseState.
     """
     """
     monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
     monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
@@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list(
     [
     [
         (
         (
             FileUploadState,
             FileUploadState,
-            {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
+            {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
         ),
         ),
         (
         (
             ChildFileUploadState,
             ChildFileUploadState,
             {
             {
-                "file_state_base1.child_file_upload_state": {
+                "state.file_state_base1.child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
                 }
             },
             },
@@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list(
         (
         (
             GrandChildFileUploadState,
             GrandChildFileUploadState,
             {
             {
-                "file_state_base1.file_state_base2.grand_child_file_upload_state": {
+                "state.file_state_base1.file_state_base2.grand_child_file_upload_state": {
                     "img_list": ["image1.jpg", "image2.jpg"]
                     "img_list": ["image1.jpg", "image2.jpg"]
                 }
                 }
             },
             },
         ),
         ),
     ],
     ],
 )
 )
-async def test_upload_file(tmp_path, state, delta, token: str):
+async def test_upload_file(tmp_path, state, delta, token: str, mocker):
     """Test that file upload works correctly.
     """Test that file upload works correctly.
 
 
     Args:
     Args:
@@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str):
         state: The state class.
         state: The state class.
         delta: Expected delta
         delta: Expected delta
         token: a Token.
         token: a Token.
+        mocker: pytest mocker object.
     """
     """
+    mocker.patch(
+        "reflex.state.State.class_subclasses",
+        {state if state is FileUploadState else FileStateBase1},
+    )
     state._tmp_path = tmp_path
     state._tmp_path = tmp_path
     # The App state must be the "root" of the state tree
     # The App state must be the "root" of the state tree
-    app = App(state=state if state is FileUploadState else FileStateBase1)
+    app = App(state=State)
     app.event_namespace.emit = AsyncMock()  # type: ignore
     app.event_namespace.emit = AsyncMock()  # type: ignore
     current_state = await app.state_manager.get_state(token)
     current_state = await app.state_manager.get_state(token)
     data = b"This is binary data"
     data = b"This is binary data"
@@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
     request_mock = unittest.mock.Mock()
     request_mock = unittest.mock.Mock()
     request_mock.headers = {
     request_mock.headers = {
         "reflex-client-token": token,
         "reflex-client-token": token,
-        "reflex-event-handler": f"{state_name}.multi_handle_upload",
+        "reflex-event-handler": f"state.{state_name}.multi_handle_upload",
     }
     }
 
 
     file1 = UploadFile(
     file1 = UploadFile(
@@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token):
         await app.state_manager.redis.close()
         await app.state_manager.redis.close()
 
 
 
 
-class DynamicState(State):
+class DynamicState(BaseState):
     """State class for testing dynamic route var.
     """State class for testing dynamic route var.
 
 
     This is defined at module level because event handlers cannot be addressed
     This is defined at module level because event handlers cannot be addressed
@@ -891,9 +896,7 @@ class DynamicState(State):
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_dynamic_route_var_route_change_completed_on_load(
 async def test_dynamic_route_var_route_change_completed_on_load(
-    index_page,
-    windows_platform: bool,
-    token: str,
+    index_page, windows_platform: bool, token: str, mocker
 ):
 ):
     """Create app with dynamic route var, and simulate navigation.
     """Create app with dynamic route var, and simulate navigation.
 
 
@@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
         index_page: The index page.
         index_page: The index page.
         windows_platform: Whether the system is windows.
         windows_platform: Whether the system is windows.
         token: a Token.
         token: a Token.
+        mocker: pytest mocker object.
     """
     """
+    mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
+    DynamicState.add_var(
+        constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
+    )
     arg_name = "dynamic"
     arg_name = "dynamic"
     route = f"/test/[{arg_name}]"
     route = f"/test/[{arg_name}]"
     if windows_platform:
     if windows_platform:

+ 2 - 2
tests/test_event.py

@@ -4,7 +4,7 @@ import pytest
 
 
 from reflex import event
 from reflex import event
 from reflex.event import Event, EventHandler, EventSpec, fix_events
 from reflex.event import Event, EventHandler, EventSpec, fix_events
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import Var
 from reflex.vars import Var
 
 
@@ -303,7 +303,7 @@ def test_event_actions():
 
 
 
 
 def test_event_actions_on_state():
 def test_event_actions_on_state():
-    class EventActionState(State):
+    class EventActionState(BaseState):
         def handler(self):
         def handler(self):
             pass
             pass
 
 

+ 35 - 29
tests/test_state.py

@@ -19,11 +19,11 @@ from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
 from reflex.event import Event, EventHandler
 from reflex.state import (
 from reflex.state import (
+    BaseState,
     ImmutableStateError,
     ImmutableStateError,
     LockExpiredError,
     LockExpiredError,
     MutableProxy,
     MutableProxy,
     RouterData,
     RouterData,
-    State,
     StateManager,
     StateManager,
     StateManagerMemory,
     StateManagerMemory,
     StateManagerRedis,
     StateManagerRedis,
@@ -75,7 +75,7 @@ class Object(Base):
     prop2: str = "hello"
     prop2: str = "hello"
 
 
 
 
-class TestState(State):
+class TestState(BaseState):
     """A test state."""
     """A test state."""
 
 
     # Set this class as not test one
     # Set this class as not test one
@@ -148,7 +148,7 @@ class GrandchildState(ChildState):
         pass
         pass
 
 
 
 
-class DateTimeState(State):
+class DateTimeState(BaseState):
     """A State with some datetime fields."""
     """A State with some datetime fields."""
 
 
     d: datetime.date = datetime.date.fromisoformat("1989-11-09")
     d: datetime.date = datetime.date.fromisoformat("1989-11-09")
@@ -253,7 +253,6 @@ def test_class_vars(test_state):
     """
     """
     cls = type(test_state)
     cls = type(test_state)
     assert set(cls.vars.keys()) == {
     assert set(cls.vars.keys()) == {
-        CompileVars.IS_HYDRATED,  # added by hydrate_middleware to all State
         "router",
         "router",
         "num1",
         "num1",
         "num2",
         "num2",
@@ -641,7 +640,6 @@ def test_reset(test_state, child_state):
         "obj",
         "obj",
         "upper",
         "upper",
         "complex",
         "complex",
-        "is_hydrated",
         "fig",
         "fig",
         "key",
         "key",
         "sum",
         "sum",
@@ -837,7 +835,7 @@ def test_get_query_params(test_state):
 
 
 
 
 def test_add_var():
 def test_add_var():
-    class DynamicState(State):
+    class DynamicState(BaseState):
         pass
         pass
 
 
     ds1 = DynamicState()
     ds1 = DynamicState()
@@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state):
     assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
     assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
 
 
 
 
-class InterdependentState(State):
+class InterdependentState(BaseState):
     """A state with 3 vars and 3 computed vars.
     """A state with 3 vars and 3 computed vars.
 
 
     x: a variable that no computed var depends on
     x: a variable that no computed var depends on
@@ -915,7 +913,7 @@ class InterdependentState(State):
 
 
 
 
 @pytest.fixture
 @pytest.fixture
-def interdependent_state() -> State:
+def interdependent_state() -> BaseState:
     """A state with varying dependency between vars.
     """A state with varying dependency between vars.
 
 
     Returns:
     Returns:
@@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state):
 def test_child_state():
 def test_child_state():
     """Test that the child state computed vars can reference parent state vars."""
     """Test that the child state computed vars can reference parent state vars."""
 
 
-    class MainState(State):
+    class MainState(BaseState):
         v: int = 2
         v: int = 2
 
 
     class ChildState(MainState):
     class ChildState(MainState):
@@ -1006,7 +1004,7 @@ def test_child_state():
 def test_conditional_computed_vars():
 def test_conditional_computed_vars():
     """Test that computed vars can have conditionals."""
     """Test that computed vars can have conditionals."""
 
 
-    class MainState(State):
+    class MainState(BaseState):
         flag: bool = False
         flag: bool = False
         t1: str = "a"
         t1: str = "a"
         t2: str = "b"
         t2: str = "b"
@@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state):
 def test_event_handlers_call_other_handlers():
 def test_event_handlers_call_other_handlers():
     """Test that event handlers can call other event handlers."""
     """Test that event handlers can call other event handlers."""
 
 
-    class MainState(State):
+    class MainState(BaseState):
         v: int = 0
         v: int = 0
 
 
         def set_v(self, v: int):
         def set_v(self, v: int):
@@ -1077,7 +1075,7 @@ def test_computed_var_cached():
     """Test that a ComputedVar doesn't recalculate when accessed."""
     """Test that a ComputedVar doesn't recalculate when accessed."""
     comp_v_calls = 0
     comp_v_calls = 0
 
 
-    class ComputedState(State):
+    class ComputedState(BaseState):
         v: int = 0
         v: int = 0
 
 
         @rx.cached_var
         @rx.cached_var
@@ -1102,7 +1100,7 @@ def test_computed_var_cached():
 def test_computed_var_cached_depends_on_non_cached():
 def test_computed_var_cached_depends_on_non_cached():
     """Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
     """Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
 
 
-    class ComputedState(State):
+    class ComputedState(BaseState):
         v: int = 0
         v: int = 0
 
 
         @rx.var
         @rx.var
@@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached():
     """Child state cached_var that depends on parent state un cached var is always recalculated."""
     """Child state cached_var that depends on parent state un cached var is always recalculated."""
     counter = 0
     counter = 0
 
 
-    class ParentState(State):
+    class ParentState(BaseState):
         @rx.var
         @rx.var
         def no_cache_v(self) -> int:
         def no_cache_v(self) -> int:
             nonlocal counter
             nonlocal counter
@@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached():
     dict1 = ps.dict()
     dict1 = ps.dict()
     assert dict1[ps.get_full_name()] == {
     assert dict1[ps.get_full_name()] == {
         "no_cache_v": 1,
         "no_cache_v": 1,
-        CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
         "router": formatted_router,
     }
     }
     assert dict1[cs.get_full_name()] == {"dep_v": 2}
     assert dict1[cs.get_full_name()] == {"dep_v": 2}
     dict2 = ps.dict()
     dict2 = ps.dict()
     assert dict2[ps.get_full_name()] == {
     assert dict2[ps.get_full_name()] == {
         "no_cache_v": 3,
         "no_cache_v": 3,
-        CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
         "router": formatted_router,
     }
     }
     assert dict2[cs.get_full_name()] == {"dep_v": 4}
     assert dict2[cs.get_full_name()] == {"dep_v": 4}
     dict3 = ps.dict()
     dict3 = ps.dict()
     assert dict3[ps.get_full_name()] == {
     assert dict3[ps.get_full_name()] == {
         "no_cache_v": 5,
         "no_cache_v": 5,
-        CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
         "router": formatted_router,
     }
     }
     assert dict3[cs.get_full_name()] == {"dep_v": 6}
     assert dict3[cs.get_full_name()] == {"dep_v": 6}
@@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
     """
     """
     counter = 0
     counter = 0
 
 
-    class HandlerState(State):
+    class HandlerState(BaseState):
         x: int = 42
         x: int = 42
 
 
         def handler(self):
         def handler(self):
@@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
 def test_computed_var_dependencies():
 def test_computed_var_dependencies():
     """Test that a ComputedVar correctly tracks its dependencies."""
     """Test that a ComputedVar correctly tracks its dependencies."""
 
 
-    class ComputedState(State):
+    class ComputedState(BaseState):
         v: int = 0
         v: int = 0
         w: int = 0
         w: int = 0
         x: int = 0
         x: int = 0
@@ -1293,7 +1288,7 @@ def test_computed_var_dependencies():
 def test_backend_method():
 def test_backend_method():
     """A method with leading underscore should be callable from event handler."""
     """A method with leading underscore should be callable from event handler."""
 
 
-    class BackendMethodState(State):
+    class BackendMethodState(BaseState):
         def _be_method(self):
         def _be_method(self):
             return True
             return True
 
 
@@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow():
     """Test that an error is thrown when an event handler shadows a state method."""
     """Test that an error is thrown when an event handler shadows a state method."""
     with pytest.raises(NameError) as err:
     with pytest.raises(NameError) as err:
 
 
-        class InvalidTest(rx.State):
+        class InvalidTest(BaseState):
             def reset(self):
             def reset(self):
                 pass
                 pass
 
 
@@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow():
 def test_state_with_invalid_yield():
 def test_state_with_invalid_yield():
     """Test that an error is thrown when a state yields an invalid value."""
     """Test that an error is thrown when a state yields an invalid value."""
 
 
-    class StateWithInvalidYield(rx.State):
+    class StateWithInvalidYield(BaseState):
         """A state that yields an invalid value."""
         """A state that yields an invalid value."""
 
 
         def invalid_handler(self):
         def invalid_handler(self):
@@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     assert mcall.kwargs["to"] == grandchild_state.get_sid()
     assert mcall.kwargs["to"] == grandchild_state.get_sid()
 
 
 
 
-class BackgroundTaskState(State):
+class BackgroundTaskState(BaseState):
     """A state with a background task."""
     """A state with a background task."""
 
 
     order: List[str] = []
     order: List[str] = []
@@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func):
         assert not isinstance(var_copy, MutableProxy)
         assert not isinstance(var_copy, MutableProxy)
 
 
 
 
-def test_duplicate_substate_class(duplicate_substate):
+def test_duplicate_substate_class(mocker):
+    mocker.patch("reflex.state.os.environ", {})
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
-        duplicate_substate()
+
+        class TestState(BaseState):
+            pass
+
+        class ChildTestState(TestState):  # type: ignore # noqa
+            pass
+
+        class ChildTestState(TestState):  # type: ignore # noqa
+            pass
+
+        return TestState
 
 
 
 
 class Foo(Base):
 class Foo(Base):
@@ -2206,7 +2212,7 @@ class Foo(Base):
 def test_json_dumps_with_mutables():
 def test_json_dumps_with_mutables():
     """Test that json.dumps works with Base vars inside mutable types."""
     """Test that json.dumps works with Base vars inside mutable types."""
 
 
-    class MutableContainsBase(State):
+    class MutableContainsBase(BaseState):
         items: List[Foo] = [Foo()]
         items: List[Foo] = [Foo()]
 
 
     dict_val = MutableContainsBase().dict()
     dict_val = MutableContainsBase().dict()
@@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables():
     f_formatted_router = str(formatted_router).replace("'", '"')
     f_formatted_router = str(formatted_router).replace("'", '"')
     assert (
     assert (
         val
         val
-        == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
+        == f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}'
     )
     )
 
 
 
 
@@ -2225,7 +2231,7 @@ def test_reset_with_mutables():
     default = [[0, 0], [0, 1], [1, 1]]
     default = [[0, 0], [0, 1], [1, 1]]
     copied_default = copy.deepcopy(default)
     copied_default = copy.deepcopy(default)
 
 
-    class MutableResetState(State):
+    class MutableResetState(BaseState):
         items: List[List[int]] = default
         items: List[List[int]] = default
 
 
     instance = MutableResetState()
     instance = MutableResetState()
@@ -2273,7 +2279,7 @@ class Custom3(Base):
 def test_state_union_optional():
 def test_state_union_optional():
     """Test that state can be defined with Union and Optional vars."""
     """Test that state can be defined with Union and Optional vars."""
 
 
-    class UnionState(State):
+    class UnionState(BaseState):
         int_float: Union[int, float] = 0
         int_float: Union[int, float] = 0
         opt_int: Optional[int]
         opt_int: Optional[int]
         c3: Optional[Custom3]
         c3: Optional[Custom3]

+ 5 - 11
tests/test_var.py

@@ -6,7 +6,7 @@ import pytest
 from pandas import DataFrame
 from pandas import DataFrame
 
 
 from reflex.base import Base
 from reflex.base import Base
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.vars import (
 from reflex.vars import (
     BaseVar,
     BaseVar,
     ComputedVar,
     ComputedVar,
@@ -24,12 +24,6 @@ test_vars = [
 ]
 ]
 
 
 
 
-class BaseState(State):
-    """A Test State."""
-
-    val: str = "key"
-
-
 @pytest.fixture
 @pytest.fixture
 def TestObj():
 def TestObj():
     class TestObj(Base):
     class TestObj(Base):
@@ -41,7 +35,7 @@ def TestObj():
 
 
 @pytest.fixture
 @pytest.fixture
 def ParentState(TestObj):
 def ParentState(TestObj):
-    class ParentState(State):
+    class ParentState(BaseState):
         foo: int
         foo: int
         bar: int
         bar: int
 
 
@@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj):
 
 
 @pytest.fixture
 @pytest.fixture
 def StateWithAnyVar(TestObj):
 def StateWithAnyVar(TestObj):
-    class StateWithAnyVar(State):
+    class StateWithAnyVar(BaseState):
         @ComputedVar
         @ComputedVar
         def var_without_annotation(self) -> typing.Any:
         def var_without_annotation(self) -> typing.Any:
             return TestObj
             return TestObj
@@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj):
 
 
 @pytest.fixture
 @pytest.fixture
 def StateWithCorrectVarAnnotation():
 def StateWithCorrectVarAnnotation():
-    class StateWithCorrectVarAnnotation(State):
+    class StateWithCorrectVarAnnotation(BaseState):
         @ComputedVar
         @ComputedVar
         def var_with_annotation(self) -> str:
         def var_with_annotation(self) -> str:
             return "Correct annotation"
             return "Correct annotation"
@@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation():
 
 
 @pytest.fixture
 @pytest.fixture
 def StateWithWrongVarAnnotation(TestObj):
 def StateWithWrongVarAnnotation(TestObj):
-    class StateWithWrongVarAnnotation(State):
+    class StateWithWrongVarAnnotation(BaseState):
         @ComputedVar
         @ComputedVar
         def var_with_annotation(self) -> str:
         def var_with_annotation(self) -> str:
             return TestObj
             return TestObj

+ 0 - 2
tests/utils/test_format.py

@@ -528,7 +528,6 @@ formatted_router = {
                     },
                     },
                     "dt": "1989-11-09 18:53:00+01:00",
                     "dt": "1989-11-09 18:53:00+01:00",
                     "fig": [],
                     "fig": [],
-                    "is_hydrated": False,
                     "key": "",
                     "key": "",
                     "map_key": "a",
                     "map_key": "a",
                     "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
                     "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
@@ -553,7 +552,6 @@ formatted_router = {
                 DateTimeState.get_full_name(): {
                 DateTimeState.get_full_name(): {
                     "d": "1989-11-09",
                     "d": "1989-11-09",
                     "dt": "1989-11-09 18:53:00+01:00",
                     "dt": "1989-11-09 18:53:00+01:00",
-                    "is_hydrated": False,
                     "t": "18:53:00+01:00",
                     "t": "18:53:00+01:00",
                     "td": "11 days, 0:11:00",
                     "td": "11 days, 0:11:00",
                     "router": formatted_router,
                     "router": formatted_router,

+ 2 - 2
tests/utils/test_utils.py

@@ -10,7 +10,7 @@ from packaging import version
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.event import EventHandler
 from reflex.event import EventHandler
-from reflex.state import State
+from reflex.state import BaseState
 from reflex.utils import (
 from reflex.utils import (
     build,
     build,
     prerequisites,
     prerequisites,
@@ -43,7 +43,7 @@ V056 = version.parse("0.5.6")
 VMAXPLUS1 = version.parse(get_above_max_version())
 VMAXPLUS1 = version.parse(get_above_max_version())
 
 
 
 
-class ExampleTestState(State):
+class ExampleTestState(BaseState):
     """Test state class."""
     """Test state class."""
 
 
     def test_event_handler(self):
     def test_event_handler(self):