Răsfoiți Sursa

Merge branch 'main' into move-tailwind-to-its-own-module

Khaleel Al-Adhami 1 săptămână în urmă
părinte
comite
6172db89ae
84 a modificat fișierele cu 3108 adăugiri și 2601 ștergeri
  1. 0 40
      .coveragerc
  2. 3 3
      .pre-commit-config.yaml
  3. 16 16
      pyi_hashes.json
  4. 41 3
      pyproject.toml
  5. 10 10
      reflex/.templates/jinja/web/pages/_app.js.jinja2
  6. 1 1
      reflex/.templates/jinja/web/pages/_document.js.jinja2
  7. 1 1
      reflex/.templates/jinja/web/pages/index.js.jinja2
  8. 1 1
      reflex/.templates/jinja/web/pages/stateful_component.js.jinja2
  9. 35 72
      reflex/.templates/jinja/web/pages/utils.js.jinja2
  10. 6 18
      reflex/.templates/jinja/web/utils/context.js.jinja2
  11. 7 7
      reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js
  12. 5 4
      reflex/.templates/web/components/shiki/code.js
  13. 12 4
      reflex/app.py
  14. 20 3
      reflex/compiler/compiler.py
  15. 8 5
      reflex/components/base/bare.py
  16. 2 4
      reflex/components/base/body.py
  17. 1 1
      reflex/components/base/error_boundary.py
  18. 3 3
      reflex/components/base/link.py
  19. 5 9
      reflex/components/base/meta.py
  20. 50 37
      reflex/components/component.py
  21. 35 4
      reflex/components/core/colors.py
  22. 2 2
      reflex/components/core/cond.py
  23. 50 18
      reflex/components/core/upload.py
  24. 2 2
      reflex/components/datadisplay/shiki_code_block.py
  25. 5 2
      reflex/components/dynamic.py
  26. 4 0
      reflex/components/el/element.py
  27. 3 1
      reflex/components/el/elements/forms.py
  28. 3 1
      reflex/components/el/elements/inline.py
  29. 1 1
      reflex/components/el/elements/metadata.py
  30. 3 1
      reflex/components/el/elements/typography.py
  31. 46 3
      reflex/components/lucide/icon.py
  32. 6 7
      reflex/components/markdown/markdown.py
  33. 1 1
      reflex/components/moment/moment.py
  34. 8 1
      reflex/components/next/video.py
  35. 1 1
      reflex/components/plotly/plotly.py
  36. 1 1
      reflex/components/radix/primitives/drawer.py
  37. 3 2
      reflex/components/radix/themes/layout/list.py
  38. 21 7
      reflex/components/react_player/react_player.py
  39. 1 1
      reflex/components/suneditor/editor.py
  40. 5 2
      reflex/config.py
  41. 1 1
      reflex/constants/base.py
  42. 23 6
      reflex/constants/colors.py
  43. 20 6
      reflex/constants/installer.py
  44. 24 0
      reflex/event.py
  45. 858 0
      reflex/istate/manager.py
  46. 726 2
      reflex/istate/proxy.py
  47. 12 6
      reflex/reflex.py
  48. 94 1624
      reflex/state.py
  49. 20 13
      reflex/testing.py
  50. 1 2
      reflex/utils/build.py
  51. 16 6
      reflex/utils/format.py
  52. 22 35
      reflex/utils/prerequisites.py
  53. 1 0
      reflex/utils/pyi_generator.py
  54. 7 0
      reflex/utils/redir.py
  55. 14 0
      reflex/utils/serializers.py
  56. 24 8
      reflex/utils/types.py
  57. 13 1
      reflex/vars/base.py
  58. 16 8
      reflex/vars/number.py
  59. 2 1
      reflex/vars/sequence.py
  60. 0 0
      tests/benchmarks/fixtures.py
  61. 1 2
      tests/integration/test_connection_banner.py
  62. 57 6
      tests/integration/test_lifespan.py
  63. 4 4
      tests/units/components/base/test_bare.py
  64. 4 6
      tests/units/components/base/test_link.py
  65. 5 5
      tests/units/components/base/test_script.py
  66. 19 12
      tests/units/components/core/test_colors.py
  67. 2 2
      tests/units/components/core/test_cond.py
  68. 1 1
      tests/units/components/core/test_foreach.py
  69. 5 3
      tests/units/components/core/test_html.py
  70. 10 9
      tests/units/components/core/test_match.py
  71. 2 2
      tests/units/components/datadisplay/test_dataeditor.py
  72. 2 2
      tests/units/components/datadisplay/test_datatable.py
  73. 12 12
      tests/units/components/el/test_svg.py
  74. 1 1
      tests/units/components/forms/test_form.py
  75. 0 0
      tests/units/components/markdown/test_markdown.py
  76. 38 20
      tests/units/components/test_component.py
  77. 4 4
      tests/units/components/test_tag.py
  78. 30 31
      tests/units/test_app.py
  79. 44 0
      tests/units/test_event.py
  80. 48 10
      tests/units/test_state.py
  81. 41 0
      tests/units/test_var.py
  82. 3 3
      tests/units/utils/test_format.py
  83. 6 0
      tests/units/utils/test_serializers.py
  84. 447 447
      uv.lock

+ 0 - 40
.coveragerc

@@ -1,40 +0,0 @@
-[run]
-source = reflex
-branch = true
-omit =
-    */pyi_generator.py
-    reflex/__main__.py
-    reflex/app_module_for_backend.py
-    reflex/components/chakra/*
-    reflex/experimental/*
-
-[report]
-show_missing = true
-# TODO bump back to 79
-fail_under = 70
-precision = 2
-
-# Regexes for lines to exclude from consideration
-exclude_also =
-    # Don't complain about missing debug-only code:
-    def __repr__
-    if self\.debug
-
-    # Don't complain if tests don't hit defensive assertion code:
-    raise AssertionError
-    raise NotImplementedError
-
-    # Don't complain if non-runnable code isn't run:
-    if 0:
-    if __name__ == .__main__.:
-
-    # Don't complain about abstract methods, they aren't run:
-    @(abc\.)?abstractmethod
-    
-    # Don't complain about overloaded methods:
-    @overload
-
-ignore_errors = True
-
-[html]
-directory = coverage_html_report

+ 3 - 3
.pre-commit-config.yaml

@@ -1,8 +1,8 @@
 fail_fast: true
 
 repos:
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.11.2
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.11.8
     hooks:
       - id: ruff-format
         args: [reflex, tests]
@@ -30,7 +30,7 @@ repos:
         entry: python3 scripts/make_pyi.py
 
   - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.398
+    rev: v1.1.400
     hooks:
       - id: pyright
         args: [reflex, tests]

+ 16 - 16
pyi_hashes.json

@@ -3,13 +3,13 @@
   "reflex/components/__init__.pyi": "76ba0a12cd3a7ba5ab6341a3ae81551f",
   "reflex/components/base/__init__.pyi": "e9aaf47be1e1977eacee97b880c8f7de",
   "reflex/components/base/app_wrap.pyi": "387fc7a0c2da8760d9449e2893e44eec",
-  "reflex/components/base/body.pyi": "2cc870cec4b1c28081dd40467752c2b7",
+  "reflex/components/base/body.pyi": "2d16002f24c8ee0007b46ff2bf1f2c78",
   "reflex/components/base/document.pyi": "30377cdfb02b564f8de29b0473d2346c",
   "reflex/components/base/error_boundary.pyi": "c56b591d14a92b99a1e97e04afe167d7",
   "reflex/components/base/fragment.pyi": "603ee8e03af88d4a8ff6bc1fbce4e022",
   "reflex/components/base/head.pyi": "893047aa32da553711db8f1345adb6b0",
-  "reflex/components/base/link.pyi": "396488afa3b7a5b0d0e6c5e89159f857",
-  "reflex/components/base/meta.pyi": "bc4b4fda6f022a517de339ffdd667e3b",
+  "reflex/components/base/link.pyi": "e96179dc7823f354fb73a6c03e31028c",
+  "reflex/components/base/meta.pyi": "da52c3212fac6b50560863146a7afcc3",
   "reflex/components/base/script.pyi": "530cf8f47eb90082bf65942e8b5d745f",
   "reflex/components/base/strict_mode.pyi": "d972e0ff2a6f961e7df90fc27b8bb51b",
   "reflex/components/core/__init__.pyi": "44bcee7bc4e27e2f4f4707b843acf291",
@@ -20,32 +20,32 @@
   "reflex/components/core/debounce.pyi": "affda049624c266c7d5620efa3b7041b",
   "reflex/components/core/html.pyi": "b12117b42ef79ee90b6b4dec50baeb86",
   "reflex/components/core/sticky.pyi": "c65131cf7c2312c68e1fddaa0cc27150",
-  "reflex/components/core/upload.pyi": "53e06193fa23a603737bc49b1c6c2565",
+  "reflex/components/core/upload.pyi": "4680da6f7b3df704a682cc6441b1ac18",
   "reflex/components/datadisplay/__init__.pyi": "cf087efa8b3960decc6b231cc986cfa9",
   "reflex/components/datadisplay/code.pyi": "3d8f0ab4c2f123d7f80d15c7ebc553d9",
   "reflex/components/datadisplay/dataeditor.pyi": "cb03d732e2fe771a8d46c7bcda671f92",
   "reflex/components/datadisplay/shiki_code_block.pyi": "87db7639bfa5cd53e1709e1363f93278",
   "reflex/components/el/__init__.pyi": "09042a2db5e0637e99b5173430600522",
-  "reflex/components/el/element.pyi": "06ac2213b062119323291fa66a1ac19e",
+  "reflex/components/el/element.pyi": "ea6b33a8545c2c845dc6c30ff1c872a4",
   "reflex/components/el/elements/__init__.pyi": "280ed457675f3720e34b560a3f617739",
   "reflex/components/el/elements/base.pyi": "6e533348b5e1a88cf62fbb5a38dbd795",
-  "reflex/components/el/elements/forms.pyi": "161f1ef847e5da8755528a7977fdcf53",
-  "reflex/components/el/elements/inline.pyi": "33d9d860e75dd8c4769825127ed363bb",
+  "reflex/components/el/elements/forms.pyi": "3ff8fd5d8a36418874e9fe4ff76bbfe8",
+  "reflex/components/el/elements/inline.pyi": "f881d229c9ecaa61d17ac6837ac9a839",
   "reflex/components/el/elements/media.pyi": "addd6872281d65d44a484358b895432f",
-  "reflex/components/el/elements/metadata.pyi": "974a86d9f0662f6fc15a5bb4b3a87862",
+  "reflex/components/el/elements/metadata.pyi": "a5b9b30c4649e88aa26a1a5609fc86ef",
   "reflex/components/el/elements/other.pyi": "995a4fbf10bfdb7f48808210dfe413bd",
   "reflex/components/el/elements/scripts.pyi": "cd5bd53c3a6b016fbb913aff36d63344",
   "reflex/components/el/elements/sectioning.pyi": "65aa53b1372598ec1785616cb7016032",
   "reflex/components/el/elements/tables.pyi": "e1282d8ddf4efa4c911ca104a907ee88",
-  "reflex/components/el/elements/typography.pyi": "00088c9c1b68a14e5a41d837e8fdf542",
+  "reflex/components/el/elements/typography.pyi": "928ff998c9bbb32ae7ccce5f6cb885a7",
   "reflex/components/gridjs/datatable.pyi": "3db3f994640c19be5c3fa2983f71de56",
-  "reflex/components/lucide/icon.pyi": "a5521a8baf8d2d7281e3fdfe6ce7073b",
-  "reflex/components/markdown/markdown.pyi": "6b268afa879e33abf651bda56be5065e",
+  "reflex/components/lucide/icon.pyi": "508c8844959925555a895df8dcac3751",
+  "reflex/components/markdown/markdown.pyi": "1fc31d2652d3ff015c6da2c7cbab716a",
   "reflex/components/moment/moment.pyi": "6dd0c7cee5f0f29bc11d830c697d7f92",
   "reflex/components/next/base.pyi": "14aafd5b018a4bc9748a3c9980fcfe3e",
   "reflex/components/next/image.pyi": "3a0d1970e69144e9c6806e68ab99f181",
   "reflex/components/next/link.pyi": "cd913e10205314afe67101d9640e05cb",
-  "reflex/components/next/video.pyi": "09698418db651917630a7fefeb573fc2",
+  "reflex/components/next/video.pyi": "aa8f814dec99f8712dc2351b15f922ac",
   "reflex/components/plotly/plotly.pyi": "b1f0bbcaf4706d0a373c99395ba50118",
   "reflex/components/radix/__init__.pyi": "8d586cbff1d7130d09476ac72ee73400",
   "reflex/components/radix/primitives/__init__.pyi": "fe8715decf3e9ae471b56bba14e42cb3",
@@ -101,7 +101,7 @@
   "reflex/components/radix/themes/layout/container.pyi": "4020c3dca660027b84d11cc4198393c4",
   "reflex/components/radix/themes/layout/flex.pyi": "f814281a5635ad43dd1df23f8e356c66",
   "reflex/components/radix/themes/layout/grid.pyi": "6062188367a2c253f014f916197c963d",
-  "reflex/components/radix/themes/layout/list.pyi": "0e91d3f1c82c9094f328e5b8ecd2f60a",
+  "reflex/components/radix/themes/layout/list.pyi": "930009f82662686841e9ce97bfd4a1ea",
   "reflex/components/radix/themes/layout/section.pyi": "41895910072e023ed0fef6a8ad956046",
   "reflex/components/radix/themes/layout/spacer.pyi": "029eb0eaa731bcdff7c496e0437e22b1",
   "reflex/components/radix/themes/layout/stack.pyi": "3b0da99b00c826d087ed89fc67c595c1",
@@ -111,9 +111,9 @@
   "reflex/components/radix/themes/typography/heading.pyi": "5a3b0b8e44bda0fce22c6b1a1f25e68e",
   "reflex/components/radix/themes/typography/link.pyi": "45965d95b9f9b76f8f4a3084a5430194",
   "reflex/components/radix/themes/typography/text.pyi": "e6aa0ca43ebbd42701a3c72c0312032e",
-  "reflex/components/react_player/audio.pyi": "972975ed0ba3e1dc4a867da20b11ae8e",
-  "reflex/components/react_player/react_player.pyi": "63ffffbc24907103f797dcfd85894107",
-  "reflex/components/react_player/video.pyi": "35ce5ad62e8bff17d9c09d27c362f8dc",
+  "reflex/components/react_player/audio.pyi": "18fb682ec86d1b44682e1903dff11794",
+  "reflex/components/react_player/react_player.pyi": "171d829b30c1c0c62e49e4a21cffe50f",
+  "reflex/components/react_player/video.pyi": "5c93cfe85ba4dcadfddae94a2e36bb4e",
   "reflex/components/recharts/__init__.pyi": "a52c9055e37c6ee25ded15688d45e8a5",
   "reflex/components/recharts/cartesian.pyi": "9dd16c08abe5205c6c414474e2de2f79",
   "reflex/components/recharts/charts.pyi": "3570af4627c601d10ee37033f1b2329c",

+ 41 - 3
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "reflex"
-version = "0.7.10dev1"
+version = "0.7.11dev1"
 description = "Web apps in pure Python."
 license = { text = "Apache-2.0" }
 authors = [
@@ -32,7 +32,7 @@ dependencies = [
   "python-socketio >=5.12.0,<6.0",
   "python-multipart >=0.0.20,<1.0",
   "redis >=5.2.1,<6.0",
-  "reflex-hosting-cli >=0.1.43",
+  "reflex-hosting-cli >=0.1.47",
   "rich >=13,<15",
   "sqlmodel >=0.0.24,<0.1",
   "click >=8",
@@ -150,7 +150,7 @@ dev = [
   "plotly >=6.0",
   "pre-commit >=4.2",
   "psycopg[binary] >=3.2",
-  "pyright >=1.1.399",
+  "pyright >=1.1.400",
   "pytest >=8.3",
   "pytest-asyncio >=0.26",
   "pytest-benchmark >=5.1",
@@ -165,5 +165,43 @@ dev = [
   "selenium >=4.31",
   "starlette-admin >=0.14",
   "uvicorn >=0.34.0",
+]
+
+
+[tool.coverage.run]
+source = ["reflex"]
+branch = true
+omit = [
+  "*/pyi_generator.py",
+  "reflex/__main__.py",
+  "reflex/app_module_for_backend.py",
+  "reflex/components/chakra/*",
+  "reflex/experimental/*",
+]
+
+[tool.coverage.report]
+show_missing = true
+# TODO bump back to 79
+fail_under = 70
+precision = 2
+ignore_errors = true
 
+exclude_also = [
+  "def __repr__",
+  # Don't complain about missing debug-only code:
+  "if self.debug",
+  # Don't complain if tests don't hit defensive assertion code:
+  "raise AssertionError",
+  "raise NotImplementedError",
+  # Regexes for lines to exclude from consideration
+  "if 0:",
+  # Don't complain if non-runnable code isn't run:
+  "if __name__ == .__main__.:",
+  # Don't complain about abstract methods, they aren't run:
+  "@(abc.)?abstractmethod",
+  # Don't complain about overloaded methods:
+  "@overload",
 ]
+
+[tool.coverage.html]
+directory = "coverage_html_report"

+ 10 - 10
reflex/.templates/jinja/web/pages/_app.js.jinja2

@@ -22,7 +22,7 @@ function AppWrap({children}) {
   {{ renderHooks(hooks) }}
 
   return (
-    {{utils.render(render, indent_width=0)}}
+    {{utils.render(render)}}
   )
 }
 
@@ -37,15 +37,15 @@ export default function MyApp({ Component, pageProps }) {
     window["__reflex"] = windowImports;
   }, []);
   return (
-    <ThemeProvider defaultTheme={ defaultColorMode } attribute="class">
-      <StateProvider>
-        <EventLoopProvider>
-            <AppWrap>
-              <Component {...pageProps} />
-            </AppWrap>
-        </EventLoopProvider>
-      </StateProvider>
-    </ThemeProvider>
+    jsx(ThemeProvider, {defaultTheme:defaultColorMode,attribute:"class"},
+      jsx(StateProvider, {},
+        jsx(EventLoopProvider, {}, 
+          jsx(AppWrap, {},
+            jsx(Component, pageProps)
+          )
+        )
+      )
+    )
   );
 }
 

+ 1 - 1
reflex/.templates/jinja/web/pages/_document.js.jinja2

@@ -3,7 +3,7 @@
 {% block export %}
 export default function Document() {
   return (
-    {{utils.render(document, indent_width=0)}}
+    {{utils.render(document)}}
   )
 }
 {% endblock %}

+ 1 - 1
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -12,7 +12,7 @@ export default function Component() {
     {{ renderHooks(hooks)}}
 
   return (
-    {{utils.render(render, indent_width=0)}}
+    {{utils.render(render)}}
   )
 }
 {% endblock %}

+ 1 - 1
reflex/.templates/jinja/web/pages/stateful_component.js.jinja2

@@ -6,6 +6,6 @@ export function {{tag_name}} () {
   {{ renderHooksWithMemo(all_hooks, memo_trigger_hooks) }}
   
   return (
-    {{utils.render(component.render(), indent_width=0)}}
+    {{utils.render(component.render())}}
   )
 }

+ 35 - 72
reflex/.templates/jinja/web/pages/utils.js.jinja2

@@ -1,46 +1,37 @@
 {# Rendering components recursively. #}
 {# Args: #}
 {#     component: component dictionary #}
-{#     indent_width: indent width #}
-{% macro render(component, indent_width=0) %}
-{% filter indent(width=indent_width) %}
-  {%- if component is not mapping %}
-    {{- component }}
-  {%- elif "iterable" in component %}
-    {{- render_iterable_tag(component) }}
-  {%- elif component.name == "match"%}
-    {{- render_match_tag(component) }}
-  {%- elif "cond" in component %}
-    {{- render_condition_tag(component) }}
-  {%- elif component.children|length %}
-    {{- render_tag(component) }}
-  {%- else %}
-    {{- render_self_close_tag(component) }}
-  {%- endif %}
-{% endfilter %}
+{% macro render(component) %}
+{%- if component is not mapping %}{{ component }}
+{%- elif "iterable" in component %}{{ render_iterable_tag(component) }}
+{%- elif component.name == "match"%}{{ render_match_tag(component) }}
+{%- elif "cond" in component %}{{ render_condition_tag(component) }}
+{%- elif component.children|length %}{{ render_tag(component) }}
+{%- else %}{{ render_self_close_tag(component) }}
+{%- endif %}
 {% endmacro %}
 
 {# Rendering self close tag. #}
 {# Args: #}
 {#     component: component dictionary #}
 {% macro render_self_close_tag(component) %}
-{%- if component.name|length %}
-<{{ component.name }} {{- render_props(component.props) }}{% if component.autofocus %} ref={focusRef} {% endif %}/>
-{%- else %}
-  {{- component.contents }}
-{%- endif %}
+{% if component.name|length %}
+jsx({{ component.name }},{{ render_props(component.props) }},{{ component.contents }})
+{% elif component.contents|length -%}{{ component.contents }}
+{% else %}""
+{% endif %}
 {% endmacro %}
 
 {# Rendering close tag with args and props. #}
 {# Args: #}
 {#     component: component dictionary #}
 {% macro render_tag(component) %}
-<{{component.name}} {{- render_props(component.props) }}>
-{{ component.contents }}
-{% for child in component.children %}
-{{ render(child) }}
-{% endfor %}
-</{{component.name}}>
+jsx(
+{% if component.name|length %}{{ component.name }}{% else %}Fragment{% endif %},
+{{ render_props(component.props) }},
+{% if component.contents|length %}{{ component.contents }},{% endif %}
+{% for child in component.children %}{% if child is mapping or child|length %}{{ render(child) }},{% endif %}{% endfor %}
+)
 {%- endmacro %}
 
 
@@ -48,11 +39,7 @@
 {# Args: #}
 {#     component: component dictionary #}
 {% macro render_condition_tag(component) %}
-{ {{- component.cond_state }} ? (
-  {{ render(component.true_value) }}
-) : (
-  {{ render(component.false_value) }}
-)}
+({{ component.cond_state }} ? ({{ render(component.true_value) }}) : ({{ render(component.false_value) }}))
 {%- endmacro %}
 
 
@@ -60,57 +47,33 @@
 {# Args: #}
 {#     component: component dictionary #}
 {% macro render_iterable_tag(component) %}
-<>{ {{ component.iterable_state }}.map(({{ component.arg_name }}, {{ component.arg_index }}) => (
-  {% for child in component.children %}
-  {{ render(child) }}
-  {% endfor %}
-))}</>
+{{ component.iterable_state }}.map(({{ component.arg_name }},{{ component.arg_index }})=>({% for child in component.children %}{{ render(child) }}{% endfor %}))
 {%- endmacro %}
 
 
 {# Rendering props of a component. #}
 {# Args: #}
 {#     component: component dictionary #}
-{% macro render_props(props) %}
-{% if props|length %} {{ props|join(" ") }}{% endif %}
-{% endmacro %}
+{% macro render_props(props) %}{{ "{" }}{% if props|length %}{{ props|join(",") }}{% endif %}{{ "}" }}{% endmacro %}
 
 {# Rendering Match component. #}
 {# Args: #}
 {#     component: component dictionary #}
 {% macro render_match_tag(component) %}
-{
-    (() => {
-        switch (JSON.stringify({{ component.cond._js_expr }})) {
-        {% for case in component.match_cases %}
-            {% for condition in case[:-1] %}
-                case JSON.stringify({{ condition._js_expr }}):
-            {% endfor %}
-                return {{ render(case[-1]) }};
-                break;
-        {% endfor %}
-            default:
-                return {{ render(component.default) }};
-                break;
-        }
-    })()
-  }
-{%- endmacro %}
-
-
-{# Rendering content with args. #}
-{# Args: #}
-{#     component: component dictionary #}
-{% macro render_arg_content(component) %}
-{% filter indent(width=2) %}
-{# no string below for a line break #}
-
-{({ {{component.args|join(", ")}} }) => (
-  {% for child in component.children %}
-  {{ render(child) }}
+(() => {
+  switch (JSON.stringify({{ component.cond._js_expr }})) {
+  {% for case in component.match_cases %}
+    {% for condition in case[:-1] %}
+      case JSON.stringify({{ condition._js_expr }}):
+    {% endfor %}
+      return {{ render(case[-1]) }};
+      break;
   {% endfor %}
-)}
-{% endfilter %}
+    default:
+      return {{ render(component.default) }};
+      break;
+  }
+})()
 {% endmacro %}
 
 

+ 6 - 18
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -1,4 +1,4 @@
-import { createContext, useContext, useMemo, useReducer, useState } from "react"
+import { createContext, useContext, useMemo, useReducer, useState, createElement } from "react"
 import { applyDelta, Event, hydrateClientStorage, useEventLoop, refs } from "$/utils/state.js"
 
 {% if initial_state %}
@@ -77,11 +77,7 @@ export function UploadFilesProvider({ children }) {
     delete newFilesById[id]
     return newFilesById
   })
-  return (
-    <UploadFilesContext value={[filesById, setFilesById]}>
-      {children}
-    </UploadFilesContext>
-  )
+  return createElement(UploadFilesContext, {value:[filesById, setFilesById]}, children);
 }
 
 export function EventLoopProvider({ children }) {
@@ -91,11 +87,7 @@ export function EventLoopProvider({ children }) {
     initialEvents,
     clientStorage,
   )
-  return (
-    <EventLoopContext value={[addEvents, connectErrors]}>
-      {children}
-    </EventLoopContext>
-  )
+  return createElement(EventLoopContext, {value:[addEvents, connectErrors]}, children);
 }
 
 export function StateProvider({ children }) {
@@ -112,13 +104,9 @@ export function StateProvider({ children }) {
 
   return (
     {% for state_name in initial_state %}
-    <StateContexts.{{state_name|var_name}} value={ {{state_name|var_name}} }>
-    {% endfor %}
-      <DispatchContext value={dispatchers}>
-        {children}
-      </DispatchContext>
-    {% for state_name in initial_state|reverse %}
-    </StateContexts.{{state_name|var_name}}>
+    createElement(StateContexts.{{state_name|var_name}},{value: {{state_name|var_name}}},
     {% endfor %}
+    createElement(DispatchContext.Provider, {value: dispatchers}, children),
+    {% for state_name in initial_state|reverse %}){% endfor %}
   )
 }

+ 7 - 7
reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js

@@ -1,5 +1,5 @@
 import { useTheme } from "next-themes";
-import { useRef, useEffect, useState } from "react";
+import { useRef, useEffect, useState, createElement } from "react";
 import {
   ColorModeContext,
   defaultColorMode,
@@ -50,11 +50,11 @@ export default function RadixThemesColorModeProvider({ children }) {
     }
     setTheme(mode);
   };
-  return (
-    <ColorModeContext
-      value={{ rawColorMode, resolvedColorMode, toggleColorMode, setColorMode }}
-    >
-      {children}
-    </ColorModeContext>
+  return createElement(
+    ColorModeContext,
+    {
+      value: { rawColorMode, resolvedColorMode, toggleColorMode, setColorMode },
+    },
+    children,
   );
 }

+ 5 - 4
reflex/.templates/web/components/shiki/code.js

@@ -1,4 +1,4 @@
-import { useEffect, useState } from "react";
+import { useEffect, useState, createElement } from "react";
 import { codeToHtml } from "shiki";
 
 /**
@@ -33,7 +33,8 @@ export function Code({
     }
     fetchCode();
   }, [code, language, theme, transformers, decorations]);
-  return (
-    <div dangerouslySetInnerHTML={{ __html: codeResult }} {...divProps}></div>
-  );
+  return createElement("div", {
+    dangerouslySetInnerHTML: { __html: codeResult },
+    ...divProps,
+  });
 }

+ 12 - 4
reflex/app.py

@@ -29,7 +29,7 @@ from starlette.datastructures import Headers
 from starlette.datastructures import UploadFile as StarletteUploadFile
 from starlette.exceptions import HTTPException
 from starlette.middleware import cors
-from starlette.requests import Request
+from starlette.requests import ClientDisconnect, Request
 from starlette.responses import JSONResponse, Response, StreamingResponse
 from starlette.staticfiles import StaticFiles
 from typing_extensions import deprecated
@@ -491,7 +491,7 @@ class App(MiddlewareMixin, LifespanMixin):
             set_breakpoints(self.style.pop("breakpoints"))
 
         # Set up the API.
-        self._api = Starlette(lifespan=self._run_lifespan_tasks)
+        self._api = Starlette()
         App._add_cors(self._api)
         self._add_default_endpoints()
 
@@ -634,6 +634,7 @@ class App(MiddlewareMixin, LifespanMixin):
 
         if not self._api:
             raise ValueError("The app has not been initialized.")
+
         if self._cached_fastapi_app is not None:
             asgi_app = self._cached_fastapi_app
             asgi_app.mount("", self._api)
@@ -658,7 +659,11 @@ class App(MiddlewareMixin, LifespanMixin):
                     # Transform the asgi app.
                     asgi_app = api_transformer(asgi_app)
 
-        return asgi_app
+        top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
+        top_asgi_app.mount("", asgi_app)
+        App._add_cors(top_asgi_app)
+
+        return top_asgi_app
 
     def _add_default_endpoints(self):
         """Add default api endpoints (ping)."""
@@ -1838,7 +1843,10 @@ def upload(app: App):
         from reflex.utils.exceptions import UploadTypeError, UploadValueError
 
         # Get the files from the request.
-        files = await request.form()
+        try:
+            files = await request.form()
+        except ClientDisconnect:
+            return Response()  # user cancelled
         files = files.getlist("files")
         if not files:
             raise UploadValueError("No files were uploaded.")

+ 20 - 3
reflex/compiler/compiler.py

@@ -30,6 +30,13 @@ from reflex.utils.prerequisites import get_web_dir
 from reflex.vars.base import LiteralVar, Var
 
 
+def _apply_common_imports(
+    imports: dict[str, list[ImportVar]],
+):
+    imports.setdefault("@emotion/react", []).append(ImportVar("jsx"))
+    imports.setdefault("react", []).append(ImportVar("Fragment"))
+
+
 def _compile_document_root(root: Component) -> str:
     """Compile the document root.
 
@@ -39,8 +46,10 @@ def _compile_document_root(root: Component) -> str:
     Returns:
         The compiled document root.
     """
+    document_root_imports = root._get_all_imports()
+    _apply_common_imports(document_root_imports)
     return templates.DOCUMENT_ROOT.render(
-        imports=utils.compile_imports(root._get_all_imports()),
+        imports=utils.compile_imports(document_root_imports),
         document=root.render(),
     )
 
@@ -74,8 +83,11 @@ def _compile_app(app_root: Component) -> str:
         (_normalize_library_name(name), name) for name in bundled_libraries
     ]
 
+    app_root_imports = app_root._get_all_imports()
+    _apply_common_imports(app_root_imports)
+
     return templates.APP_ROOT.render(
-        imports=utils.compile_imports(app_root._get_all_imports()),
+        imports=utils.compile_imports(app_root_imports),
         custom_codes=app_root._get_all_custom_code(),
         hooks=app_root._get_all_hooks(),
         window_libraries=window_libraries,
@@ -143,6 +155,7 @@ def _compile_page(
         The compiled component.
     """
     imports = component._get_all_imports()
+    _apply_common_imports(imports)
     imports = utils.compile_imports(imports)
 
     # Compile the code to render the component.
@@ -325,7 +338,7 @@ def _compile_components(
     """
     imports = {
         "react": [ImportVar(tag="memo")],
-        f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
+        f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
     }
     component_renders = []
 
@@ -335,6 +348,8 @@ def _compile_components(
         component_renders.append(component_render)
         imports = utils.merge_imports(imports, component_imports)
 
+    _apply_common_imports(imports)
+
     dynamic_imports = {
         comp_import: None
         for comp_render in component_renders
@@ -427,6 +442,8 @@ def _compile_stateful_components(
     all_imports.pop(
         f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
     )
+    if rendered_components:
+        _apply_common_imports(all_imports)
 
     return templates.STATEFUL_COMPONENTS.render(
         imports=utils.compile_imports(all_imports),

+ 8 - 5
reflex/components/base/bare.py

@@ -169,11 +169,14 @@ class Bare(Component):
         return refs
 
     def _render(self) -> Tag:
-        if isinstance(self.contents, Var):
-            if isinstance(self.contents, (BooleanVar, ObjectVar)):
-                return Tagless(contents=f"{{{self.contents.to_string()!s}}}")
-            return Tagless(contents=f"{{{self.contents!s}}}")
-        return Tagless(contents=str(self.contents))
+        contents = (
+            Var.create(self.contents)
+            if not isinstance(self.contents, Var)
+            else self.contents
+        )
+        if isinstance(contents, (BooleanVar, ObjectVar)):
+            return Tagless(contents=f"{contents.to_string()!s}")
+        return Tagless(contents=f"{contents!s}")
 
     def _add_style_recursive(
         self, style: ComponentStyle, theme: Component | None = None

+ 2 - 4
reflex/components/base/body.py

@@ -1,9 +1,7 @@
 """Display the page body."""
 
-from reflex.components.component import Component
+from reflex.components.el import elements
 
 
-class Body(Component):
+class Body(elements.Body):
     """A body component."""
-
-    tag = "body"

+ 1 - 1
reflex/components/base/error_boundary.py

@@ -33,7 +33,7 @@ def on_error_spec(
 class ErrorBoundary(Component):
     """A React Error Boundary component that catches unhandled frontend exceptions."""
 
-    library = "react-error-boundary"
+    library = "react-error-boundary@6.0.0"
     tag = "ErrorBoundary"
 
     # Fired when the boundary catches an error.

+ 3 - 3
reflex/components/base/link.py

@@ -1,10 +1,10 @@
 """Display the title of the current page."""
 
-from reflex.components.component import Component
+from reflex.components.el.elements.base import BaseHTML
 from reflex.vars.base import Var
 
 
-class RawLink(Component):
+class RawLink(BaseHTML):
     """A component that displays the title of the current page."""
 
     tag = "link"
@@ -16,7 +16,7 @@ class RawLink(Component):
     rel: Var[str]
 
 
-class ScriptTag(Component):
+class ScriptTag(BaseHTML):
     """A script tag with the specified type and source."""
 
     tag = "script"

+ 5 - 9
reflex/components/base/meta.py

@@ -3,14 +3,12 @@
 from __future__ import annotations
 
 from reflex.components.base.bare import Bare
-from reflex.components.component import Component
+from reflex.components.el import elements
 
 
-class Title(Component):
+class Title(elements.Title):
     """A component that displays the title of the current page."""
 
-    tag = "title"
-
     def render(self) -> dict:
         """Render the title component.
 
@@ -26,11 +24,9 @@ class Title(Component):
         return super().render()
 
 
-class Meta(Component):
+class Meta(elements.Meta):
     """A component that displays metadata for the current page."""
 
-    tag = "meta"
-
     # The description of character encoding.
     char_set: str | None = None
 
@@ -47,14 +43,14 @@ class Meta(Component):
     http_equiv: str | None = None
 
 
-class Description(Meta):
+class Description(elements.Meta):
     """A component that displays the title of the current page."""
 
     # The type of the description.
     name: str | None = "description"
 
 
-class Image(Meta):
+class Image(elements.Meta):
     """A component that displays the title of the current page."""
 
     # The type of the image.

+ 50 - 37
reflex/components/component.py

@@ -43,6 +43,7 @@ from reflex.event import (
     no_args_event_spec,
     parse_args_spec,
     run_script,
+    unwrap_var_annotation,
 )
 from reflex.style import Style, format_as_emotion
 from reflex.utils import console, format, imports, types
@@ -50,14 +51,15 @@ from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_
 from reflex.vars import VarData
 from reflex.vars.base import (
     CachedVarOperation,
+    LiteralNoneVar,
     LiteralVar,
     Var,
     cached_property_no_lock,
 )
 from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
 from reflex.vars.number import ternary_operation
-from reflex.vars.object import ObjectVar
-from reflex.vars.sequence import LiteralArrayVar
+from reflex.vars.object import LiteralObjectVar, ObjectVar
+from reflex.vars.sequence import LiteralArrayVar, LiteralStringVar, StringVar
 
 
 class BaseComponent(Base, ABC):
@@ -266,6 +268,9 @@ class Component(BaseComponent, ABC):
     # The alias for the tag.
     alias: str | None = pydantic.v1.Field(default_factory=lambda: None)
 
+    # Whether the component is a global scope tag. True for tags like `html`, `head`, `body`.
+    _is_tag_in_global_scope: ClassVar[bool] = False
+
     # Whether the import is default or named.
     is_default: bool | None = pydantic.v1.Field(default_factory=lambda: False)
 
@@ -598,13 +603,36 @@ class Component(BaseComponent, ABC):
         # Convert class_name to str if it's list
         class_name = kwargs.get("class_name", "")
         if isinstance(class_name, (list, tuple)):
-            if any(isinstance(c, Var) for c in class_name):
+            has_var = False
+            for c in class_name:
+                if isinstance(c, str):
+                    continue
+                if isinstance(c, Var):
+                    if not isinstance(c, StringVar) and not issubclass(
+                        c._var_type, str
+                    ):
+                        raise TypeError(
+                            f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c._js_expr} of type {c._var_type}."
+                        )
+                    has_var = True
+                else:
+                    raise TypeError(
+                        f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c} of type {type(c)}."
+                    )
+            if has_var:
                 kwargs["class_name"] = LiteralArrayVar.create(
                     class_name, _var_type=list[str]
                 ).join(" ")
             else:
                 kwargs["class_name"] = " ".join(class_name)
-
+        elif (
+            isinstance(class_name, Var)
+            and not isinstance(class_name, StringVar)
+            and not issubclass(class_name._var_type, str)
+        ):
+            raise TypeError(
+                f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {class_name._js_expr} of type {class_name._var_type}."
+            )
         # Construct the component.
         for key, value in kwargs.items():
             setattr(self, key, value)
@@ -664,9 +692,14 @@ class Component(BaseComponent, ABC):
         Returns:
             The tag to render.
         """
+        # Create the base tag.
+        name = (self.tag if not self.alias else self.alias) or ""
+        if self._is_tag_in_global_scope and self.library is None:
+            name = '"' + name + '"'
+
         # Create the base tag.
         tag = Tag(
-            name=(self.tag if not self.alias else self.alias) or "",
+            name=name,
             special_props=self.special_props,
         )
 
@@ -1146,7 +1179,7 @@ class Component(BaseComponent, ABC):
                 vars.append(comp_prop)
             elif isinstance(comp_prop, str):
                 # Collapse VarData encoded in f-strings.
-                var = LiteralVar.create(comp_prop)
+                var = LiteralStringVar.create(comp_prop)
                 if var._get_all_var_data() is not None:
                     vars.append(var)
 
@@ -1912,8 +1945,8 @@ def _register_custom_component(
         prop: (
             Var(
                 "",
-                _var_type=annotation,
-            )
+                _var_type=unwrap_var_annotation(annotation),
+            ).guess_type()
             if not types.safe_issubclass(annotation, EventHandler)
             else EventSpec(handler=EventHandler(fn=lambda: []))
         )
@@ -2494,7 +2527,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
         return Var.create(tag)
 
     if "iterable" in tag:
-        function_return = Var.create(
+        function_return = LiteralArrayVar.create(
             [
                 render_dict_to_var(child.render(), imported_names)
                 for child in tag["children"]
@@ -2537,7 +2570,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
             render_dict_to_var(tag["true_value"], imported_names),
             render_dict_to_var(tag["false_value"], imported_names)
             if tag["false_value"] is not None
-            else Var.create(None),
+            else LiteralNoneVar.create(),
         )
 
     props = {}
@@ -2545,32 +2578,26 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
     special_props = []
 
     for prop_str in tag["props"]:
-        if "=" not in prop_str:
+        if ":" not in prop_str:
             special_props.append(Var(prop_str).to(ObjectVar))
             continue
-        prop = prop_str.index("=")
+        prop = prop_str.index(":")
         key = prop_str[:prop]
-        value = prop_str[prop + 2 : -1]
+        value = prop_str[prop + 1 :]
         props[key] = value
 
-    props = Var.create({Var.create(k): Var(v) for k, v in props.items()})
+    props = LiteralObjectVar.create(
+        {LiteralStringVar.create(k): Var(v) for k, v in props.items()}
+    )
 
     for prop in special_props:
         props = props.merge(prop)
 
-    contents = tag["contents"][1:-1] if tag["contents"] else None
+    contents = tag["contents"] if tag["contents"] else None
 
     raw_tag_name = tag.get("name")
     tag_name = Var(raw_tag_name or "Fragment")
 
-    tag_name = (
-        Var.create(raw_tag_name)
-        if raw_tag_name
-        and raw_tag_name.split(".")[0] not in imported_names
-        and raw_tag_name.lower() == raw_tag_name
-        else tag_name
-    )
-
     return FunctionStringVar.create(
         "jsx",
     ).call(
@@ -2615,23 +2642,9 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
         """
         return VarData.merge(
             self._var_data,
-            VarData(
-                imports={
-                    "@emotion/react": [
-                        ImportVar(tag="jsx"),
-                    ],
-                }
-            ),
             VarData(
                 imports=self._var_value._get_all_imports(),
             ),
-            VarData(
-                imports={
-                    "react": [
-                        ImportVar(tag="Fragment"),
-                    ],
-                }
-            ),
         )
 
     def __hash__(self) -> int:

+ 35 - 4
reflex/components/core/colors.py

@@ -1,11 +1,22 @@
 """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors."""
 
-from reflex.constants.colors import Color, ColorType, ShadeType
-from reflex.utils.types import validate_parameter_literals
+from reflex.constants.base import REFLEX_VAR_OPENING_TAG
+from reflex.constants.colors import (
+    COLORS,
+    MAX_SHADE_VALUE,
+    MIN_SHADE_VALUE,
+    Color,
+    ColorType,
+    ShadeType,
+)
+from reflex.vars.base import Var
 
 
-@validate_parameter_literals
-def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
+def color(
+    color: ColorType | Var[str],
+    shade: ShadeType | Var[int] = 7,
+    alpha: bool | Var[bool] = False,
+) -> Color:
     """Create a color object.
 
     Args:
@@ -15,5 +26,25 @@ def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color:
 
     Returns:
         The color object.
+
+    Raises:
+        ValueError: If the color, shade, or alpha are not valid.
     """
+    if isinstance(color, str):
+        if color not in COLORS and REFLEX_VAR_OPENING_TAG not in color:
+            raise ValueError(f"Color must be one of {COLORS}, received {color}")
+    elif not isinstance(color, Var):
+        raise ValueError("Color must be a string or a Var")
+
+    if isinstance(shade, int):
+        if shade < MIN_SHADE_VALUE or shade > MAX_SHADE_VALUE:
+            raise ValueError(
+                f"Shade must be between {MIN_SHADE_VALUE} and {MAX_SHADE_VALUE}"
+            )
+    elif not isinstance(shade, Var):
+        raise ValueError("Shade must be an integer or a Var")
+
+    if not isinstance(alpha, (bool, Var)):
+        raise ValueError("Alpha must be a boolean or a Var")
+
     return Color(color, shade, alpha)

+ 2 - 2
reflex/components/core/cond.py

@@ -5,7 +5,7 @@ from __future__ import annotations
 from typing import Any, overload
 
 from reflex.components.base.fragment import Fragment
-from reflex.components.component import BaseComponent, Component, MemoizationLeaf
+from reflex.components.component import BaseComponent, Component
 from reflex.components.tags import CondTag, Tag
 from reflex.constants import Dirs
 from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
@@ -20,7 +20,7 @@ _IS_TRUE_IMPORT: ImportDict = {
 }
 
 
-class Cond(MemoizationLeaf):
+class Cond(Component):
     """Render one of two components based on a condition."""
 
     # The cond to determine which component to render.

+ 50 - 18
reflex/components/core/upload.py

@@ -13,6 +13,7 @@ from reflex.components.component import (
     MemoizationLeaf,
     StatefulComponent,
 )
+from reflex.components.core.cond import cond
 from reflex.components.el.elements.forms import Input
 from reflex.components.radix.themes.layout.box import Box
 from reflex.config import environment
@@ -28,6 +29,7 @@ from reflex.event import (
     parse_args_spec,
     run_script,
 )
+from reflex.style import Style
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
 from reflex.vars import VarData
@@ -231,6 +233,9 @@ class Upload(MemoizationLeaf):
     # Fired when files are dropped.
     on_drop: EventHandler[_on_drop_spec]
 
+    # Style rules to apply when actively dragging.
+    drag_active_style: Style | None = None
+
     @classmethod
     def create(cls, *children, **props) -> Component:
         """Create an upload component.
@@ -266,25 +271,46 @@ class Upload(MemoizationLeaf):
             # If on_drop is not provided, save files to be uploaded later.
             upload_props["on_drop"] = upload_file(upload_props["id"])
         else:
-            on_drop = upload_props["on_drop"]
-            if isinstance(on_drop, (EventHandler, EventSpec)):
-                # Call the lambda to get the event chain.
-                on_drop = call_event_handler(on_drop, _on_drop_spec)
-            elif isinstance(on_drop, Callable):
-                # Call the lambda to get the event chain.
-                on_drop = call_event_fn(on_drop, _on_drop_spec)
-            if isinstance(on_drop, EventSpec):
-                # Update the provided args for direct use with on_drop.
-                on_drop = on_drop.with_args(
-                    args=tuple(
-                        cls._update_arg_tuple_for_on_drop(arg_value)
-                        for arg_value in on_drop.args
-                    ),
-                )
+            on_drop = (
+                [on_drop_prop]
+                if not isinstance(on_drop_prop := upload_props["on_drop"], Sequence)
+                else list(on_drop_prop)
+            )
+            for ix, event in enumerate(on_drop):
+                if isinstance(event, (EventHandler, EventSpec)):
+                    # Call the lambda to get the event chain.
+                    event = call_event_handler(event, _on_drop_spec)
+                elif isinstance(event, Callable):
+                    # Call the lambda to get the event chain.
+                    event = call_event_fn(event, _on_drop_spec)
+                if isinstance(event, EventSpec):
+                    # Update the provided args for direct use with on_drop.
+                    event = event.with_args(
+                        args=tuple(
+                            cls._update_arg_tuple_for_on_drop(arg_value)
+                            for arg_value in event.args
+                        ),
+                    )
+                on_drop[ix] = event
             upload_props["on_drop"] = on_drop
 
         input_props_unique_name = get_unique_variable_name()
         root_props_unique_name = get_unique_variable_name()
+        is_drag_active_unique_name = get_unique_variable_name()
+        drag_active_css_class_unique_name = get_unique_variable_name() + "-drag-active"
+
+        # Handle special style when dragging over the drop zone.
+        if "drag_active_style" in props:
+            props.setdefault("style", Style())[
+                f"&:where(.{drag_active_css_class_unique_name})"
+            ] = props.pop("drag_active_style")
+            props["class_name"].append(
+                cond(
+                    Var(is_drag_active_unique_name),
+                    drag_active_css_class_unique_name,
+                    "",
+                ),
+            )
 
         event_var, callback_str = StatefulComponent._get_memoized_event_triggers(
             GhostUpload.create(on_drop=upload_props["on_drop"])
@@ -303,7 +329,13 @@ class Upload(MemoizationLeaf):
             }
         )
 
-        left_side = f"const {{getRootProps: {root_props_unique_name}, getInputProps: {input_props_unique_name}}} "
+        left_side = (
+            "const { "
+            f"getRootProps: {root_props_unique_name}, "
+            f"getInputProps: {input_props_unique_name}, "
+            f"isDragActive: {is_drag_active_unique_name}"
+            "}"
+        )
         right_side = f"useDropzone({use_dropzone_arguments!s})"
 
         var_data = VarData.merge(
@@ -329,7 +361,7 @@ class Upload(MemoizationLeaf):
         upload = Input.create(type="file")
         upload.special_props = [
             Var(
-                _js_expr=f"{{...{input_props_unique_name}()}}",
+                _js_expr=f"{input_props_unique_name}()",
                 _var_type=None,
                 _var_data=var_data,
             )
@@ -343,7 +375,7 @@ class Upload(MemoizationLeaf):
         )
         zone.special_props = [
             Var(
-                _js_expr=f"{{...{root_props_unique_name}()}}",
+                _js_expr=f"{root_props_unique_name}()",
                 _var_type=None,
                 _var_data=var_data,
             )

+ 2 - 2
reflex/components/datadisplay/shiki_code_block.py

@@ -421,7 +421,7 @@ class ShikiBaseTransformers(Base):
 class ShikiJsTransformer(ShikiBaseTransformers):
     """A Wrapped shikijs transformer."""
 
-    library: str = "@shikijs/transformers"
+    library: str = "@shikijs/transformers@3.3.0"
     fns: list[FunctionStringVar] = [
         FunctionStringVar.create(fn) for fn in SHIKIJS_TRANSFORMER_FNS
     ]
@@ -538,7 +538,7 @@ class ShikiCodeBlock(Component, MarkdownComponentMap):
 
     alias = "ShikiCode"
 
-    lib_dependencies: list[str] = ["shiki"]
+    lib_dependencies: list[str] = ["shiki@3.3.0"]
 
     # The language to use.
     language: Var[LiteralCodeLanguage] = Var.create("python")

+ 5 - 2
reflex/components/dynamic.py

@@ -72,7 +72,7 @@ def load_dynamic_serializer():
             The generated code
         """
         # Causes a circular import, so we import here.
-        from reflex.compiler import templates, utils
+        from reflex.compiler import compiler, templates, utils
         from reflex.components.base.bare import Bare
 
         component = Bare.create(Var.create(component))
@@ -97,8 +97,11 @@ def load_dynamic_serializer():
 
         libs_in_window = bundled_libraries
 
+        component_imports = component._get_all_imports()
+        compiler._apply_common_imports(component_imports)
+
         imports = {}
-        for lib, names in component._get_all_imports().items():
+        for lib, names in component_imports.items():
             formatted_lib_name = format_library_name(lib)
             if (
                 not lib.startswith((".", "/", "$/"))

+ 4 - 0
reflex/components/el/element.py

@@ -1,11 +1,15 @@
 """Base class definition for raw HTML elements."""
 
+from typing import ClassVar
+
 from reflex.components.component import Component
 
 
 class Element(Component):
     """The base class for all raw HTML elements."""
 
+    _is_tag_in_global_scope: ClassVar[bool] = True
+
     def __eq__(self, other: object):
         """Two elements are equal if they have the same tag.
 

+ 3 - 1
reflex/components/el/elements/forms.py

@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from collections.abc import Iterator
 from hashlib import md5
-from typing import Any, Literal
+from typing import Any, ClassVar, Literal
 
 from jinja2 import Environment
 
@@ -86,6 +86,8 @@ class Button(BaseHTML):
     # Value of the button, used when sending form data
     value: Var[str | int | float]
 
+    _invalid_children: ClassVar[list[str]] = ["Button"]
+
 
 class Datalist(BaseHTML):
     """Display the datalist element."""

+ 3 - 1
reflex/components/el/elements/inline.py

@@ -1,6 +1,6 @@
 """Inline classes."""
 
-from typing import Literal
+from typing import ClassVar, Literal
 
 from reflex.vars.base import Var
 
@@ -48,6 +48,8 @@ class A(BaseHTML):  # Inherits common attributes from BaseMeta
     # Specifies where to open the linked document
     target: Var[str | Literal["_self", "_blank", "_parent", "_top"]]
 
+    _invalid_children: ClassVar[list[str]] = ["A"]
+
 
 class Abbr(BaseHTML):
     """Display the abbr element."""

+ 1 - 1
reflex/components/el/elements/metadata.py

@@ -89,7 +89,7 @@ class StyleEl(Element):
 
     media: Var[str]
 
-    special_props: list[Var] = [Var(_js_expr="suppressHydrationWarning")]
+    suppress_hydration_warning: Var[bool] = Var.create(True)
 
 
 base = Base.create

+ 3 - 1
reflex/components/el/elements/typography.py

@@ -1,6 +1,6 @@
 """Typography classes."""
 
-from typing import Literal
+from typing import ClassVar, Literal
 
 from reflex.vars.base import Var
 
@@ -87,6 +87,8 @@ class P(BaseHTML):
 
     tag = "p"
 
+    _invalid_children: ClassVar[list] = ["P", "Ol", "Ul", "Div"]
+
 
 class Pre(BaseHTML):
     """Display the pre element."""

+ 46 - 3
reflex/components/lucide/icon.py

@@ -10,7 +10,7 @@ from reflex.vars.sequence import LiteralStringVar, StringVar
 class LucideIconComponent(Component):
     """Lucide Icon Component."""
 
-    library = "lucide-react@0.471.1"
+    library = "lucide-react@0.507.0"
 
 
 class Icon(LucideIconComponent):
@@ -75,13 +75,12 @@ class Icon(LucideIconComponent):
             )
             console.warn(
                 f"Invalid icon tag: {tag}. Please use one of the following: {', '.join(icons_sorted[0:10])}, ..."
-                "\nSee full list at https://reflex.dev/docs/library/data-display/icon/#icons-list. Using 'circle-help' icon instead."
+                "\nSee full list at https://reflex.dev/docs/library/data-display/icon/#icons-list. Using 'circle_help' icon instead."
             )
             tag = "circle_help"
 
         props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE.get(tag, format.to_title_case(tag))
         props["alias"] = f"Lucide{props['tag']}"
-        props.setdefault("color", "var(--current-color)")
         return super().create(**props)
 
 
@@ -234,6 +233,8 @@ LUCIDE_ICON_LIST = [
     "banana",
     "bandage",
     "banknote",
+    "banknote_arrow_down",
+    "banknote_x",
     "bar_chart",
     "bar_chart_2",
     "bar_chart_3",
@@ -249,6 +250,7 @@ LUCIDE_ICON_LIST = [
     "battery_full",
     "battery_low",
     "battery_medium",
+    "battery_plus",
     "battery_warning",
     "beaker",
     "bean",
@@ -321,6 +323,7 @@ LUCIDE_ICON_LIST = [
     "bot",
     "bot_message_square",
     "bot_off",
+    "bow_arrow",
     "box",
     "box_select",
     "boxes",
@@ -330,12 +333,15 @@ LUCIDE_ICON_LIST = [
     "brain_circuit",
     "brain_cog",
     "brick_wall",
+    "brick_wall_fire",
     "briefcase",
     "briefcase_business",
     "briefcase_conveyor_belt",
     "briefcase_medical",
     "bring_to_front",
     "brush",
+    "brush_cleaning",
+    "bubbles",
     "bug",
     "bug_off",
     "bug_play",
@@ -475,6 +481,7 @@ LUCIDE_ICON_LIST = [
     "circle_power",
     "circle_slash",
     "circle_slash_2",
+    "circle_small",
     "circle_stop",
     "circle_user",
     "circle_user_round",
@@ -509,6 +516,8 @@ LUCIDE_ICON_LIST = [
     "clock_alert",
     "clock_arrow_down",
     "clock_arrow_up",
+    "clock_fading",
+    "clock_plus",
     "cloud",
     "cloud_alert",
     "cloud_cog",
@@ -538,6 +547,7 @@ LUCIDE_ICON_LIST = [
     "coins",
     "columns_2",
     "columns_3",
+    "columns_3_cog",
     "columns_4",
     "combine",
     "command",
@@ -585,6 +595,8 @@ LUCIDE_ICON_LIST = [
     "database",
     "database_backup",
     "database_zap",
+    "decimals_arrow_left",
+    "decimals_arrow_right",
     "delete",
     "dessert",
     "diameter",
@@ -612,6 +624,7 @@ LUCIDE_ICON_LIST = [
     "dollar_sign",
     "donut",
     "door_closed",
+    "door_closed_locked",
     "door_open",
     "dot",
     "download",
@@ -785,6 +798,9 @@ LUCIDE_ICON_LIST = [
     "frown",
     "fuel",
     "fullscreen",
+    "funnel",
+    "funnel_plus",
+    "funnel_x",
     "gallery_horizontal",
     "gallery_horizontal_end",
     "gallery_thumbnails",
@@ -836,6 +852,7 @@ LUCIDE_ICON_LIST = [
     "group",
     "guitar",
     "ham",
+    "hamburger",
     "hammer",
     "hand",
     "hand_coins",
@@ -864,7 +881,9 @@ LUCIDE_ICON_LIST = [
     "heart",
     "heart_crack",
     "heart_handshake",
+    "heart_minus",
     "heart_off",
+    "heart_plus",
     "heart_pulse",
     "heater",
     "hexagon",
@@ -975,6 +994,7 @@ LUCIDE_ICON_LIST = [
     "locate",
     "locate_fixed",
     "locate_off",
+    "location_edit",
     "lock",
     "lock_keyhole",
     "lock_keyhole_open",
@@ -1009,6 +1029,8 @@ LUCIDE_ICON_LIST = [
     "map_pin_x",
     "map_pin_x_inside",
     "map_pinned",
+    "map_plus",
+    "mars_stroke",
     "martini",
     "maximize",
     "maximize_2",
@@ -1107,6 +1129,7 @@ LUCIDE_ICON_LIST = [
     "network",
     "newspaper",
     "nfc",
+    "non_binary",
     "notebook",
     "notebook_pen",
     "notebook_tabs",
@@ -1249,6 +1272,7 @@ LUCIDE_ICON_LIST = [
     "receipt_swiss_franc",
     "receipt_text",
     "rectangle_ellipsis",
+    "rectangle_goggles",
     "rectangle_horizontal",
     "rectangle_vertical",
     "recycle",
@@ -1276,6 +1300,7 @@ LUCIDE_ICON_LIST = [
     "roller_coaster",
     "rotate_3d",
     "rotate_ccw",
+    "rotate_ccw_key",
     "rotate_ccw_square",
     "rotate_cw",
     "rotate_cw_square",
@@ -1287,12 +1312,14 @@ LUCIDE_ICON_LIST = [
     "rows_4",
     "rss",
     "ruler",
+    "ruler_dimension_line",
     "russian_ruble",
     "sailboat",
     "salad",
     "sandwich",
     "satellite",
     "satellite_dish",
+    "saudi_riyal",
     "save",
     "save_all",
     "save_off",
@@ -1348,6 +1375,7 @@ LUCIDE_ICON_LIST = [
     "shield_off",
     "shield_plus",
     "shield_question",
+    "shield_user",
     "shield_x",
     "ship",
     "ship_wheel",
@@ -1357,6 +1385,8 @@ LUCIDE_ICON_LIST = [
     "shopping_cart",
     "shovel",
     "shower_head",
+    "shredder",
+    "shrimp",
     "shrink",
     "shrub",
     "shuffle",
@@ -1385,6 +1415,7 @@ LUCIDE_ICON_LIST = [
     "smile_plus",
     "snail",
     "snowflake",
+    "soap_dispenser_droplet",
     "sofa",
     "soup",
     "space",
@@ -1396,6 +1427,7 @@ LUCIDE_ICON_LIST = [
     "spell_check",
     "spell_check_2",
     "spline",
+    "spline_pointer",
     "split",
     "spray_can",
     "sprout",
@@ -1449,6 +1481,7 @@ LUCIDE_ICON_LIST = [
     "square_plus",
     "square_power",
     "square_radical",
+    "square_round_corner",
     "square_scissors",
     "square_sigma",
     "square_slash",
@@ -1460,6 +1493,10 @@ LUCIDE_ICON_LIST = [
     "square_user",
     "square_user_round",
     "square_x",
+    "squares_exclude",
+    "squares_intersect",
+    "squares_subtract",
+    "squares_unite",
     "squircle",
     "squirrel",
     "stamp",
@@ -1556,6 +1593,7 @@ LUCIDE_ICON_LIST = [
     "train_front_tunnel",
     "train_track",
     "tram_front",
+    "transgender",
     "trash",
     "trash_2",
     "tree_deciduous",
@@ -1572,6 +1610,7 @@ LUCIDE_ICON_LIST = [
     "triangle_right",
     "trophy",
     "truck",
+    "truck_electric",
     "turtle",
     "tv",
     "tv_2",
@@ -1599,6 +1638,7 @@ LUCIDE_ICON_LIST = [
     "user",
     "user_check",
     "user_cog",
+    "user_lock",
     "user_minus",
     "user_pen",
     "user_plus",
@@ -1621,6 +1661,8 @@ LUCIDE_ICON_LIST = [
     "vault",
     "vegan",
     "venetian_mask",
+    "venus",
+    "venus_and_mars",
     "vibrate",
     "vibrate_off",
     "video",
@@ -1658,6 +1700,7 @@ LUCIDE_ICON_LIST = [
     "wifi_high",
     "wifi_low",
     "wifi_off",
+    "wifi_pen",
     "wifi_zero",
     "wind",
     "wind_arrow_down",

+ 6 - 7
reflex/components/markdown/markdown.py

@@ -18,8 +18,8 @@ from reflex.vars.number import ternary_operation
 
 # Special vars used in the component map.
 _CHILDREN = Var(_js_expr="children", _var_type=str)
-_PROPS = Var(_js_expr="...props")
-_PROPS_IN_TAG = Var(_js_expr="{...props}")
+_PROPS = Var(_js_expr="props")
+_PROPS_SPREAD = Var(_js_expr="...props")
 _MOCK_ARG = Var(_js_expr="", _var_type=str)
 _LANGUAGE = Var(_js_expr="_language", _var_type=str)
 
@@ -128,7 +128,7 @@ class MarkdownComponentMap:
         Returns:
             The function arguments as a list of strings.
         """
-        return ["node", _CHILDREN._js_expr, _PROPS._js_expr]
+        return ["node", _CHILDREN._js_expr, _PROPS_SPREAD._js_expr]
 
     @classmethod
     def get_fn_body(cls) -> Var:
@@ -297,7 +297,7 @@ let {_LANGUAGE!s} = match ? match[1] : '';
                 "inline",
                 "className",
                 _CHILDREN._js_expr,
-                _PROPS._js_expr,
+                _PROPS_SPREAD._js_expr,
             ),
             fn_body=Var(_js_expr=formatted_code),
             explicit_return=True,
@@ -321,7 +321,7 @@ let {_LANGUAGE!s} = match ? match[1] : '';
         if tag not in self.component_map:
             raise ValueError(f"No markdown component found for tag: {tag}.")
 
-        special_props = [_PROPS_IN_TAG]
+        special_props = [_PROPS]
         children = [
             _CHILDREN
             if tag != "codeblock"
@@ -338,9 +338,8 @@ let {_LANGUAGE!s} = match ? match[1] : '';
             special_props = []
 
         # If the children are set as a prop, don't pass them as children.
-        children_prop = props.pop("children", None)
+        children_prop = props.get("children")
         if children_prop is not None:
-            special_props.append(Var(_js_expr=f"children={{{children_prop!s}}}"))
             children = []
         # Get the component.
         component = self.component_map[tag](*children, **props).set(

+ 1 - 1
reflex/components/moment/moment.py

@@ -29,7 +29,7 @@ class Moment(NoSSRComponent):
 
     tag: str | None = "Moment"
     is_default = True
-    library: str | None = "react-moment"
+    library: str | None = "react-moment@1.1.3"
     lib_dependencies: list[str] = ["moment"]
 
     # How often the date update (how often time update / 0 to disable).

+ 8 - 1
reflex/components/next/video.py

@@ -1,6 +1,7 @@
 """Wrapping of the next-video component."""
 
 from reflex.components.component import Component
+from reflex.utils import console
 from reflex.vars.base import Var
 
 from .base import NextComponent
@@ -10,7 +11,7 @@ class Video(NextComponent):
     """A video component from NextJS."""
 
     tag = "Video"
-    library = "next-video"
+    library = "next-video@2.2.0"
     is_default = True
     # the URL
     src: Var[str]
@@ -28,4 +29,10 @@ class Video(NextComponent):
         Returns:
             The Video component.
         """
+        console.deprecate(
+            "next-video",
+            "The next-video component is deprecated. Use `rx.video` instead.",
+            deprecation_version="0.7.11",
+            removal_version="0.8.0",
+        )
         return super().create(*children, **props)

+ 1 - 1
reflex/components/plotly/plotly.py

@@ -252,7 +252,7 @@ const extractPoints = (points) => {
             )
         else:
             # Spread the figure dict over props, nothing to merge.
-            tag.special_props.append(Var(_js_expr=f"{{...{figure!s}}}"))
+            tag.special_props.append(Var(_js_expr=f"{figure!s}"))
         return tag
 
 

+ 1 - 1
reflex/components/radix/primitives/drawer.py

@@ -19,7 +19,7 @@ from reflex.vars.base import Var
 class DrawerComponent(RadixPrimitiveComponent):
     """A Drawer component."""
 
-    library = "vaul"
+    library = "vaul@1.1.2"
 
     lib_dependencies: list[str] = ["@radix-ui/react-dialog@^1.1.6"]
 

+ 3 - 2
reflex/components/radix/themes/layout/list.py

@@ -5,8 +5,9 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import Any, Literal
 
-from reflex.components.component import Component, ComponentNamespace
+from reflex.components.component import ComponentNamespace
 from reflex.components.core.foreach import Foreach
+from reflex.components.el.elements.base import BaseHTML
 from reflex.components.el.elements.typography import Li, Ol, Ul
 from reflex.components.lucide.icon import Icon
 from reflex.components.markdown.markdown import MarkdownComponentMap
@@ -38,7 +39,7 @@ LiteralListStyleTypeOrdered = Literal[
 ]
 
 
-class BaseList(Component, MarkdownComponentMap):
+class BaseList(BaseHTML, MarkdownComponentMap):
     """Base class for ordered and unordered lists."""
 
     tag = "ul"

+ 21 - 7
reflex/components/react_player/react_player.py

@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TypedDict
+from typing import Any, TypedDict
 
 from reflex.components.component import NoSSRComponent
 from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec
@@ -50,12 +50,6 @@ class ReactPlayer(NoSSRComponent):
     # Mutes the player
     muted: Var[bool]
 
-    # Set the width of the player: ex:640px
-    width: Var[str]
-
-    # Set the height of the player: ex:640px
-    height: Var[str]
-
     # Called when media is loaded and ready to play. If playing is set to true, media will play immediately.
     on_ready: EventHandler[no_args_event_spec]
 
@@ -103,3 +97,23 @@ class ReactPlayer(NoSSRComponent):
 
     # Called when picture-in-picture mode is disabled.
     on_disable_pip: EventHandler[no_args_event_spec]
+
+    def _render(self, props: dict[str, Any] | None = None):
+        """Render the component. Adds width and height set to None because
+        react-player will set them to some random value that overrides the
+        css width and height.
+
+        Args:
+            props: The props to pass to the component.
+
+        Returns:
+            The rendered component.
+        """
+        return (
+            super()
+            ._render(props)
+            .add_props(
+                width=Var.create(None),
+                height=Var.create(None),
+            )
+        )

+ 1 - 1
reflex/components/suneditor/editor.py

@@ -103,7 +103,7 @@ class Editor(NoSSRComponent):
     refer to the library docs for a complete list.
     """
 
-    library = "suneditor-react"
+    library = "suneditor-react@3.6.1"
 
     tag = "SunEditor"
 

+ 5 - 2
reflex/config.py

@@ -56,7 +56,7 @@ def _load_dotenv_from_str(env_files: str) -> None:
 
     if load_dotenv is None:
         console.error(
-            """The `python-dotenv` package is required to load environment variables from a file. Run `pip install "python-dotenv>=1.0.1"`."""
+            """The `python-dotenv` package is required to load environment variables from a file. Run `pip install "python-dotenv>=1.1.0"`."""
         )
         return
 
@@ -69,7 +69,7 @@ def _load_dotenv_from_str(env_files: str) -> None:
 
 
 # Load the env files at import time if they are set in the ENV_FILE environment variable.
-if load_dotenv is not None and (env_files := os.getenv("ENV_FILE")):
+if env_files := os.getenv("ENV_FILE"):
     _load_dotenv_from_str(env_files)
 
 
@@ -880,6 +880,9 @@ class Config(Base):
     # Path to file containing key-values pairs to override in the environment; Dotenv format.
     env_file: str | None = None
 
+    # Whether to automatically create setters for state base vars
+    state_auto_setters: bool = True
+
     # Whether to display the sticky "Built with Reflex" badge on all pages.
     show_built_with_reflex: bool | None = None
 

+ 1 - 1
reflex/constants/base.py

@@ -134,7 +134,7 @@ class Templates(SimpleNamespace):
     DEFAULT_TEMPLATE_URL = "https://blank-template.reflex.run"
 
     # The reflex.build frontend host
-    REFLEX_BUILD_FRONTEND = "https://flexgen.reflex.run"
+    REFLEX_BUILD_FRONTEND = "https://reflex.build"
 
     # The reflex.build backend host
     REFLEX_BUILD_BACKEND = "https://flexgen-prod-flexgen.fly.dev"

+ 23 - 6
reflex/constants/colors.py

@@ -1,7 +1,12 @@
 """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors."""
 
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Literal
+from typing import TYPE_CHECKING, Literal, get_args
+
+if TYPE_CHECKING:
+    from reflex.vars import Var
 
 ColorType = Literal[
     "gray",
@@ -40,10 +45,16 @@ ColorType = Literal[
     "white",
 ]
 
+COLORS = frozenset(get_args(ColorType))
+
 ShadeType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+MIN_SHADE_VALUE = 1
+MAX_SHADE_VALUE = 12
 
 
-def format_color(color: ColorType, shade: ShadeType, alpha: bool) -> str:
+def format_color(
+    color: ColorType | Var[str], shade: ShadeType | Var[int], alpha: bool | Var[bool]
+) -> str:
     """Format a color as a CSS color string.
 
     Args:
@@ -54,7 +65,13 @@ def format_color(color: ColorType, shade: ShadeType, alpha: bool) -> str:
     Returns:
         The formatted color.
     """
-    return f"var(--{color}-{'a' if alpha else ''}{shade})"
+    if isinstance(alpha, bool):
+        return f"var(--{color}-{'a' if alpha else ''}{shade})"
+
+    from reflex.components.core import cond
+
+    alpha_var = cond(alpha, "a", "")
+    return f"var(--{color}-{alpha_var}{shade})"
 
 
 @dataclass
@@ -62,13 +79,13 @@ class Color:
     """A color in the Reflex color palette."""
 
     # The color palette to use
-    color: ColorType
+    color: ColorType | Var[str]
 
     # The shade of the color to use
-    shade: ShadeType = 7
+    shade: ShadeType | Var[int] = 7
 
     # Whether to use the alpha variant of the color
-    alpha: bool = False
+    alpha: bool | Var[bool] = False
 
     def __format__(self, format_spec: str) -> str:
         """Format the color as a CSS color string.

+ 20 - 6
reflex/constants/installer.py

@@ -14,7 +14,7 @@ class Bun(SimpleNamespace):
     """Bun constants."""
 
     # The Bun version.
-    VERSION = "1.2.10"
+    VERSION = "1.2.12"
 
     # Min Bun Version
     MIN_VERSION = "1.2.8"
@@ -75,7 +75,7 @@ fetch-retries=0
 
 
 def _determine_nextjs_version() -> str:
-    default_version = "15.3.1"
+    default_version = "15.3.2"
     if (version := os.getenv("NEXTJS_VERSION")) and version != default_version:
         from reflex.utils import console
 
@@ -86,6 +86,18 @@ def _determine_nextjs_version() -> str:
     return default_version
 
 
+def _determine_react_version() -> str:
+    default_version = "19.1.0"
+    if (version := os.getenv("REACT_VERSION")) and version != default_version:
+        from reflex.utils import console
+
+        console.warn(
+            f"You have requested react@{version} but the supported version is {default_version}, abandon all hope ye who enter here."
+        )
+        return version
+    return default_version
+
+
 class PackageJson(SimpleNamespace):
     """Constants used to build the package.json file."""
 
@@ -99,15 +111,17 @@ class PackageJson(SimpleNamespace):
 
     PATH = "package.json"
 
+    _react_version = _determine_react_version()
+
     DEPENDENCIES = {
         "@emotion/react": "11.14.0",
-        "axios": "1.8.4",
+        "axios": "1.9.0",
         "json5": "2.2.3",
         "next": _determine_nextjs_version(),
         "next-sitemap": "4.2.3",
         "next-themes": "0.4.6",
-        "react": "19.1.0",
-        "react-dom": "19.1.0",
+        "react": _react_version,
+        "react-dom": _react_version,
         "react-focus-lock": "2.13.6",
         "socket.io-client": "4.8.1",
         "universal-cookie": "7.2.2",
@@ -119,5 +133,5 @@ class PackageJson(SimpleNamespace):
     }
     OVERRIDES = {
         # This should always match the `react` version in DEPENDENCIES for recharts compatibility.
-        "react-is": "19.1.0"
+        "react-is": _react_version
     }

+ 24 - 0
reflex/event.py

@@ -2066,6 +2066,30 @@ class EventNamespace:
                 setattr(func, BACKGROUND_TASK_MARKER, True)
             if getattr(func, "__name__", "").startswith("_"):
                 raise ValueError("Event handlers cannot be private.")
+
+            qualname: str | None = getattr(func, "__qualname__", None)
+
+            if qualname and (
+                len(func_path := qualname.split(".")) == 1
+                or func_path[-2] == "<locals>"
+            ):
+                from reflex.state import BaseState
+
+                types = get_type_hints(func)
+                state_arg_name = next(iter(inspect.signature(func).parameters), None)
+                state_cls = state_arg_name and types.get(state_arg_name)
+                if state_cls and issubclass(state_cls, BaseState):
+                    name = (
+                        (func.__module__ + "." + qualname)
+                        .replace(".", "_")
+                        .replace("<locals>", "_")
+                        .removeprefix("_")
+                    )
+                    object.__setattr__(func, "__name__", name)
+                    object.__setattr__(func, "__qualname__", name)
+                    state_cls._add_event_handler(name, func)
+                    return getattr(state_cls, name)
+
             return func  # pyright: ignore [reportReturnType]
 
         if func is not None:

+ 858 - 0
reflex/istate/manager.py

@@ -0,0 +1,858 @@
+"""State manager for managing client states."""
+
+import asyncio
+import contextlib
+import dataclasses
+import functools
+import time
+import uuid
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator
+from hashlib import md5
+from pathlib import Path
+
+from redis import ResponseError
+from redis.asyncio import Redis
+from redis.asyncio.client import PubSub
+from typing_extensions import override
+
+from reflex import constants
+from reflex.config import environment, get_config
+from reflex.state import BaseState, _split_substate_key, _substate_key
+from reflex.utils import console, path_ops, prerequisites
+from reflex.utils.exceptions import (
+    InvalidLockWarningThresholdError,
+    InvalidStateManagerModeError,
+    LockExpiredError,
+    StateSchemaMismatchError,
+)
+
+
+@dataclasses.dataclass
+class StateManager(ABC):
+    """A class to manage many client states."""
+
+    # The state class to use.
+    state: type[BaseState]
+
+    @classmethod
+    def create(cls, state: type[BaseState]):
+        """Create a new state manager.
+
+        Args:
+            state: The state class to use.
+
+        Raises:
+            InvalidStateManagerModeError: If the state manager mode is invalid.
+
+        Returns:
+            The state manager (either disk, memory or redis).
+        """
+        config = get_config()
+        if prerequisites.parse_redis_url() is not None:
+            config.state_manager_mode = constants.StateManagerMode.REDIS
+        if config.state_manager_mode == constants.StateManagerMode.MEMORY:
+            return StateManagerMemory(state=state)
+        if config.state_manager_mode == constants.StateManagerMode.DISK:
+            return StateManagerDisk(state=state)
+        if config.state_manager_mode == constants.StateManagerMode.REDIS:
+            redis = prerequisites.get_redis()
+            if redis is not None:
+                # make sure expiration values are obtained only from the config object on creation
+                return StateManagerRedis(
+                    state=state,
+                    redis=redis,
+                    token_expiration=config.redis_token_expiration,
+                    lock_expiration=config.redis_lock_expiration,
+                    lock_warning_threshold=config.redis_lock_warning_threshold,
+                )
+        raise InvalidStateManagerModeError(
+            f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
+        )
+
+    @abstractmethod
+    async def get_state(self, token: str) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        pass
+
+    @abstractmethod
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        pass
+
+    @abstractmethod
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        yield self.state()
+
+
+@dataclasses.dataclass
+class StateManagerMemory(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
+
+    # The dict of mutexes for each client
+    _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
+        default_factory=dict, init=False
+    )
+
+    @override
+    async def get_state(self, token: str) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        token = _split_substate_key(token)[0]
+        if token not in self.states:
+            self.states[token] = self.state(_reflex_internal_init=True)
+        return self.states[token]
+
+    @override
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        pass
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        token = _split_substate_key(token)[0]
+        if token not in self._states_locks:
+            async with self._state_manager_lock:
+                if token not in self._states_locks:
+                    self._states_locks[token] = asyncio.Lock()
+
+        async with self._states_locks[token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
+def _default_token_expiration() -> int:
+    """Get the default token expiration time.
+
+    Returns:
+        The default token expiration time.
+    """
+    return get_config().redis_token_expiration
+
+
+def reset_disk_state_manager():
+    """Reset the disk state manager."""
+    states_directory = prerequisites.get_states_dir()
+    if states_directory.exists():
+        for path in states_directory.iterdir():
+            path.unlink()
+
+
+@dataclasses.dataclass
+class StateManagerDisk(StateManager):
+    """A state manager that stores states in memory."""
+
+    # The mapping of client ids to states.
+    states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
+
+    # The mutex ensures the dict of mutexes is updated exclusively
+    _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
+
+    # The dict of mutexes for each client
+    _states_locks: dict[str, asyncio.Lock] = dataclasses.field(
+        default_factory=dict,
+        init=False,
+    )
+
+    # The token expiration time (s).
+    token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
+
+    def __post_init_(self):
+        """Create a new state manager."""
+        path_ops.mkdir(self.states_directory)
+
+        self._purge_expired_states()
+
+    @functools.cached_property
+    def states_directory(self) -> Path:
+        """Get the states directory.
+
+        Returns:
+            The states directory.
+        """
+        return prerequisites.get_states_dir()
+
+    def _purge_expired_states(self):
+        """Purge expired states from the disk."""
+        import time
+
+        for path in path_ops.ls(self.states_directory):
+            # check path is a pickle file
+            if path.suffix != ".pkl":
+                continue
+
+            # load last edited field from file
+            last_edited = path.stat().st_mtime
+
+            # check if the file is older than the token expiration time
+            if time.time() - last_edited > self.token_expiration:
+                # remove the file
+                path.unlink()
+
+    def token_path(self, token: str) -> Path:
+        """Get the path for a token.
+
+        Args:
+            token: The token to get the path for.
+
+        Returns:
+            The path for the token.
+        """
+        return (
+            self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
+        ).absolute()
+
+    async def load_state(self, token: str) -> BaseState | None:
+        """Load a state object based on the provided token.
+
+        Args:
+            token: The token used to identify the state object.
+
+        Returns:
+            The loaded state object or None.
+        """
+        token_path = self.token_path(token)
+
+        if token_path.exists():
+            try:
+                with token_path.open(mode="rb") as file:
+                    return BaseState._deserialize(fp=file)
+            except Exception:
+                pass
+
+    async def populate_substates(
+        self, client_token: str, state: BaseState, root_state: BaseState
+    ):
+        """Populate the substates of a state object.
+
+        Args:
+            client_token: The client token.
+            state: The state object to populate.
+            root_state: The root state object.
+        """
+        for substate in state.get_substates():
+            substate_token = _substate_key(client_token, substate)
+
+            fresh_instance = await root_state.get_state(substate)
+            instance = await self.load_state(substate_token)
+            if instance is not None:
+                # Ensure all substates exist, even if they weren't serialized previously.
+                instance.substates = fresh_instance.substates
+            else:
+                instance = fresh_instance
+            state.substates[substate.get_name()] = instance
+            instance.parent_state = state
+
+            await self.populate_substates(client_token, instance, root_state)
+
+    @override
+    async def get_state(
+        self,
+        token: str,
+    ) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+
+        Returns:
+            The state for the token.
+        """
+        client_token = _split_substate_key(token)[0]
+        root_state = self.states.get(client_token)
+        if root_state is not None:
+            # Retrieved state from memory.
+            return root_state
+
+        # Deserialize root state from disk.
+        root_state = await self.load_state(_substate_key(client_token, self.state))
+        # Create a new root state tree with all substates instantiated.
+        fresh_root_state = self.state(_reflex_internal_init=True)
+        if root_state is None:
+            root_state = fresh_root_state
+        else:
+            # Ensure all substates exist, even if they were not serialized previously.
+            root_state.substates = fresh_root_state.substates
+        self.states[client_token] = root_state
+        await self.populate_substates(client_token, root_state, root_state)
+        return root_state
+
+    async def set_state_for_substate(self, client_token: str, substate: BaseState):
+        """Set the state for a substate.
+
+        Args:
+            client_token: The client token.
+            substate: The substate to set.
+        """
+        substate_token = _substate_key(client_token, substate)
+
+        if substate._get_was_touched():
+            substate._was_touched = False  # Reset the touched flag after serializing.
+            pickle_state = substate._serialize()
+            if pickle_state:
+                if not self.states_directory.exists():
+                    self.states_directory.mkdir(parents=True, exist_ok=True)
+                self.token_path(substate_token).write_bytes(pickle_state)
+
+        for substate_substate in substate.substates.values():
+            await self.set_state_for_substate(client_token, substate_substate)
+
+    @override
+    async def set_state(self, token: str, state: BaseState):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+        """
+        client_token, substate = _split_substate_key(token)
+        await self.set_state_for_substate(client_token, state)
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        # Memory state manager ignores the substate suffix and always returns the top-level state.
+        client_token, substate = _split_substate_key(token)
+        if client_token not in self._states_locks:
+            async with self._state_manager_lock:
+                if client_token not in self._states_locks:
+                    self._states_locks[client_token] = asyncio.Lock()
+
+        async with self._states_locks[client_token]:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state)
+
+
+def _default_lock_expiration() -> int:
+    """Get the default lock expiration time.
+
+    Returns:
+        The default lock expiration time.
+    """
+    return get_config().redis_lock_expiration
+
+
+def _default_lock_warning_threshold() -> int:
+    """Get the default lock warning threshold.
+
+    Returns:
+        The default lock warning threshold.
+    """
+    return get_config().redis_lock_warning_threshold
+
+
+@dataclasses.dataclass
+class StateManagerRedis(StateManager):
+    """A state manager that stores states in redis."""
+
+    # The redis client to use.
+    redis: Redis
+
+    # The token expiration time (s).
+    token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
+
+    # The maximum time to hold a lock (ms).
+    lock_expiration: int = dataclasses.field(default_factory=_default_lock_expiration)
+
+    # The maximum time to hold a lock (ms) before warning.
+    lock_warning_threshold: int = dataclasses.field(
+        default_factory=_default_lock_warning_threshold
+    )
+
+    # The keyspace subscription string when redis is waiting for lock to be released.
+    _redis_notify_keyspace_events: str = dataclasses.field(
+        default="K"  # Enable keyspace notifications (target a particular key)
+        "g"  # For generic commands (DEL, EXPIRE, etc)
+        "x"  # For expired events
+        "e"  # For evicted events (i.e. maxmemory exceeded)
+    )
+
+    # These events indicate that a lock is no longer held.
+    _redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
+        default_factory=lambda: {
+            b"del",
+            b"expire",
+            b"expired",
+            b"evicted",
+        }
+    )
+
+    # Whether keyspace notifications have been enabled.
+    _redis_notify_keyspace_events_enabled: bool = dataclasses.field(default=False)
+
+    # The logical database number used by the redis client.
+    _redis_db: int = dataclasses.field(default=0)
+
+    def __post_init__(self):
+        """Validate the lock warning threshold.
+
+        Raises:
+            InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
+        """
+        if self.lock_warning_threshold >= (lock_expiration := self.lock_expiration):
+            raise InvalidLockWarningThresholdError(
+                f"The lock warning threshold({self.lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
+            )
+
+    def _get_required_state_classes(
+        self,
+        target_state_cls: type[BaseState],
+        subclasses: bool = False,
+        required_state_classes: set[type[BaseState]] | None = None,
+    ) -> set[type[BaseState]]:
+        """Recursively determine which states are required to fetch the target state.
+
+        This will always include potentially dirty substates that depend on vars
+        in the target_state_cls.
+
+        Args:
+            target_state_cls: The target state class being fetched.
+            subclasses: Whether to include subclasses of the target state.
+            required_state_classes: Recursive argument tracking state classes that have already been seen.
+
+        Returns:
+            The set of state classes required to fetch the target state.
+        """
+        if required_state_classes is None:
+            required_state_classes = set()
+        # Get the substates if requested.
+        if subclasses:
+            for substate in target_state_cls.get_substates():
+                self._get_required_state_classes(
+                    substate,
+                    subclasses=True,
+                    required_state_classes=required_state_classes,
+                )
+        if target_state_cls in required_state_classes:
+            return required_state_classes
+        required_state_classes.add(target_state_cls)
+
+        # Get dependent substates.
+        for pd_substates in target_state_cls._get_potentially_dirty_states():
+            self._get_required_state_classes(
+                pd_substates,
+                subclasses=False,
+                required_state_classes=required_state_classes,
+            )
+
+        # Get the parent state if it exists.
+        if parent_state := target_state_cls.get_parent_state():
+            self._get_required_state_classes(
+                parent_state,
+                subclasses=False,
+                required_state_classes=required_state_classes,
+            )
+        return required_state_classes
+
+    def _get_populated_states(
+        self,
+        target_state: BaseState,
+        populated_states: dict[str, BaseState] | None = None,
+    ) -> dict[str, BaseState]:
+        """Recursively determine which states from target_state are already fetched.
+
+        Args:
+            target_state: The state to check for populated states.
+            populated_states: Recursive argument tracking states seen in previous calls.
+
+        Returns:
+            A dictionary of state full name to state instance.
+        """
+        if populated_states is None:
+            populated_states = {}
+        if target_state.get_full_name() in populated_states:
+            return populated_states
+        populated_states[target_state.get_full_name()] = target_state
+        for substate in target_state.substates.values():
+            self._get_populated_states(substate, populated_states=populated_states)
+        if target_state.parent_state is not None:
+            self._get_populated_states(
+                target_state.parent_state, populated_states=populated_states
+            )
+        return populated_states
+
+    @override
+    async def get_state(
+        self,
+        token: str,
+        top_level: bool = True,
+        for_state_instance: BaseState | None = None,
+    ) -> BaseState:
+        """Get the state for a token.
+
+        Args:
+            token: The token to get the state for.
+            top_level: If true, return an instance of the top-level state (self.state).
+            for_state_instance: If provided, attach the requested states to this existing state tree.
+
+        Returns:
+            The state for the token.
+
+        Raises:
+            RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
+                requested state was not fetched.
+        """
+        # Split the actual token from the fully qualified substate name.
+        token, state_path = _split_substate_key(token)
+        if state_path:
+            # Get the State class associated with the given path.
+            state_cls = self.state.get_class_substate(state_path)
+        else:
+            raise RuntimeError(
+                f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
+            )
+
+        # Determine which states we already have.
+        flat_state_tree: dict[str, BaseState] = (
+            self._get_populated_states(for_state_instance) if for_state_instance else {}
+        )
+
+        # Determine which states from the tree need to be fetched.
+        required_state_classes = sorted(
+            self._get_required_state_classes(state_cls, subclasses=True)
+            - {type(s) for s in flat_state_tree.values()},
+            key=lambda x: x.get_full_name(),
+        )
+
+        redis_pipeline = self.redis.pipeline()
+        for state_cls in required_state_classes:
+            redis_pipeline.get(_substate_key(token, state_cls))
+
+        for state_cls, redis_state in zip(
+            required_state_classes,
+            await redis_pipeline.execute(),
+            strict=False,
+        ):
+            state = None
+
+            if redis_state is not None:
+                # Deserialize the substate.
+                with contextlib.suppress(StateSchemaMismatchError):
+                    state = BaseState._deserialize(data=redis_state)
+            if state is None:
+                # Key didn't exist or schema mismatch so create a new instance for this token.
+                state = state_cls(
+                    init_substates=False,
+                    _reflex_internal_init=True,
+                )
+            flat_state_tree[state.get_full_name()] = state
+            if state.get_parent_state() is not None:
+                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
+                    "."
+                )
+                parent_state = flat_state_tree.get(parent_state_name)
+                if parent_state is None:
+                    raise RuntimeError(
+                        f"Parent state for {state.get_full_name()} was not found "
+                        "in the state tree, but should have already been fetched. "
+                        "This is a bug",
+                    )
+                parent_state.substates[state_name] = state
+                state.parent_state = parent_state
+
+        # To retain compatibility with previous implementation, by default, we return
+        # the top-level state which should always be fetched or already cached.
+        if top_level:
+            return flat_state_tree[self.state.get_full_name()]
+        return flat_state_tree[state_cls.get_full_name()]
+
+    @override
+    async def set_state(
+        self,
+        token: str,
+        state: BaseState,
+        lock_id: bytes | None = None,
+    ):
+        """Set the state for a token.
+
+        Args:
+            token: The token to set the state for.
+            state: The state to set.
+            lock_id: If provided, the lock_key must be set to this value to set the state.
+
+        Raises:
+            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
+            RuntimeError: If the state instance doesn't match the state name in the token.
+        """
+        # Check that we're holding the lock.
+        if (
+            lock_id is not None
+            and await self.redis.get(self._lock_key(token)) != lock_id
+        ):
+            raise LockExpiredError(
+                f"Lock expired for token {token} while processing. Consider increasing "
+                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
+                "or use `@rx.event(background=True)` decorator for long-running tasks."
+            )
+        elif lock_id is not None:
+            time_taken = self.lock_expiration / 1000 - (
+                await self.redis.ttl(self._lock_key(token))
+            )
+            if time_taken > self.lock_warning_threshold / 1000:
+                console.warn(
+                    f"Lock for token {token} was held too long {time_taken=}s, "
+                    f"use `@rx.event(background=True)` decorator for long-running tasks.",
+                    dedupe=True,
+                )
+
+        client_token, substate_name = _split_substate_key(token)
+        # If the substate name on the token doesn't match the instance name, it cannot have a parent.
+        if state.parent_state is not None and state.get_full_name() != substate_name:
+            raise RuntimeError(
+                f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
+            )
+
+        # Recursively set_state on all known substates.
+        tasks = [
+            asyncio.create_task(
+                self.set_state(
+                    _substate_key(client_token, substate),
+                    substate,
+                    lock_id,
+                )
+            )
+            for substate in state.substates.values()
+        ]
+        # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
+        if state._get_was_touched():
+            pickle_state = state._serialize()
+            if pickle_state:
+                await self.redis.set(
+                    _substate_key(client_token, state),
+                    pickle_state,
+                    ex=self.token_expiration,
+                )
+
+        # Wait for substates to be persisted.
+        for t in tasks:
+            await t
+
+    @override
+    @contextlib.asynccontextmanager
+    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
+        """Modify the state for a token while holding exclusive lock.
+
+        Args:
+            token: The token to modify the state for.
+
+        Yields:
+            The state for the token.
+        """
+        async with self._lock(token) as lock_id:
+            state = await self.get_state(token)
+            yield state
+            await self.set_state(token, state, lock_id)
+
+    @staticmethod
+    def _lock_key(token: str) -> bytes:
+        """Get the redis key for a token's lock.
+
+        Args:
+            token: The token to get the lock key for.
+
+        Returns:
+            The redis lock key for the token.
+        """
+        # All substates share the same lock domain, so ignore any substate path suffix.
+        client_token = _split_substate_key(token)[0]
+        return f"{client_token}_lock".encode()
+
+    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
+        """Try to get a redis lock for a token.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+
+        Returns:
+            True if the lock was obtained.
+        """
+        return await self.redis.set(
+            lock_key,
+            lock_id,
+            px=self.lock_expiration,
+            nx=True,  # only set if it doesn't exist
+        )
+
+    async def _get_pubsub_message(
+        self, pubsub: PubSub, timeout: float | None = None
+    ) -> None:
+        """Get lock release events from the pubsub.
+
+        Args:
+            pubsub: The pubsub to get a message from.
+            timeout: Remaining time to wait for a message.
+
+        Returns:
+            The message.
+        """
+        if timeout is None:
+            timeout = self.lock_expiration / 1000.0
+
+        started = time.time()
+        message = await pubsub.get_message(
+            ignore_subscribe_messages=True,
+            timeout=timeout,
+        )
+        if (
+            message is None
+            or message["data"] not in self._redis_keyspace_lock_release_events
+        ):
+            remaining = timeout - (time.time() - started)
+            if remaining <= 0:
+                return
+            await self._get_pubsub_message(pubsub, timeout=remaining)
+
+    async def _enable_keyspace_notifications(self):
+        """Enable keyspace notifications for the redis server.
+
+        Raises:
+            ResponseError: when the keyspace config cannot be set.
+        """
+        if self._redis_notify_keyspace_events_enabled:
+            return
+        # Find out which logical database index is being used.
+        self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
+
+        try:
+            await self.redis.config_set(
+                "notify-keyspace-events",
+                self._redis_notify_keyspace_events,
+            )
+        except ResponseError:
+            # Some redis servers only allow out-of-band configuration, so ignore errors here.
+            if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
+                raise
+        self._redis_notify_keyspace_events_enabled = True
+
+    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
+        """Wait for a redis lock to be released via pubsub.
+
+        Coroutine will not return until the lock is obtained.
+
+        Args:
+            lock_key: The redis key for the lock.
+            lock_id: The ID of the lock.
+        """
+        # Enable keyspace notifications for the lock key, so we know when it is available.
+        await self._enable_keyspace_notifications()
+        lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
+        async with self.redis.pubsub() as pubsub:
+            await pubsub.psubscribe(lock_key_channel)
+            # wait for the lock to be released
+            while True:
+                # fast path
+                if await self._try_get_lock(lock_key, lock_id):
+                    return
+                # wait for lock events
+                await self._get_pubsub_message(pubsub)
+
+    @contextlib.asynccontextmanager
+    async def _lock(self, token: str):
+        """Obtain a redis lock for a token.
+
+        Args:
+            token: The token to obtain a lock for.
+
+        Yields:
+            The ID of the lock (to be passed to set_state).
+
+        Raises:
+            LockExpiredError: If the lock has expired while processing the event.
+        """
+        lock_key = self._lock_key(token)
+        lock_id = uuid.uuid4().hex.encode()
+
+        if not await self._try_get_lock(lock_key, lock_id):
+            # Missed the fast-path to get lock, subscribe for lock delete/expire events
+            await self._wait_lock(lock_key, lock_id)
+        state_is_locked = True
+
+        try:
+            yield lock_id
+        except LockExpiredError:
+            state_is_locked = False
+            raise
+        finally:
+            if state_is_locked:
+                # only delete our lock
+                await self.redis.delete(lock_key)
+
+    async def close(self):
+        """Explicitly close the redis connection and connection_pool.
+
+        It is necessary in testing scenarios to close between asyncio test cases
+        to avoid having lingering redis connections associated with event loops
+        that will be closed (each test case uses its own event loop).
+
+        Note: Connections will be automatically reopened when needed.
+        """
+        await self.redis.aclose(close_connection_pool=True)
+
+
+def get_state_manager() -> StateManager:
+    """Get the state manager for the app that is currently running.
+
+    Returns:
+        The state manager.
+    """
+    return prerequisites.get_and_validate_app().app.state_manager

+ 726 - 2
reflex/istate/proxy.py

@@ -1,8 +1,309 @@
 """A module to hold state proxy classes."""
 
-from typing import Any
+from __future__ import annotations
 
-from reflex.state import StateProxy
+import asyncio
+import copy
+import dataclasses
+import functools
+import inspect
+import json
+from collections.abc import Callable, Sequence
+from types import MethodType
+from typing import TYPE_CHECKING, Any, SupportsIndex
+
+import pydantic
+import wrapt
+from pydantic import BaseModel as BaseModelV2
+from pydantic.v1 import BaseModel as BaseModelV1
+from sqlalchemy.orm import DeclarativeBase
+
+from reflex.base import Base
+from reflex.utils import prerequisites
+from reflex.utils.exceptions import ImmutableStateError
+from reflex.utils.serializers import serializer
+from reflex.vars.base import Var
+
+if TYPE_CHECKING:
+    from reflex.state import BaseState, StateUpdate
+
+
+class StateProxy(wrapt.ObjectProxy):
+    """Proxy of a state instance to control mutability of vars for a background task.
+
+    Since a background task runs against a state instance without holding the
+    state_manager lock for the token, the reference may become stale if the same
+    state is modified by another event handler.
+
+    The proxy object ensures that writes to the state are blocked unless
+    explicitly entering a context which refreshes the state from state_manager
+    and holds the lock for the token until exiting the context. After exiting
+    the context, a StateUpdate may be emitted to the frontend to notify the
+    client of the state change.
+
+    A background task will be passed the `StateProxy` as `self`, so mutability
+    can be safely performed inside an `async with self` block.
+
+        class State(rx.State):
+            counter: int = 0
+
+            @rx.event(background=True)
+            async def bg_increment(self):
+                await asyncio.sleep(1)
+                async with self:
+                    self.counter += 1
+    """
+
+    def __init__(
+        self,
+        state_instance: BaseState,
+        parent_state_proxy: StateProxy | None = None,
+    ):
+        """Create a proxy for a state instance.
+
+        If `get_state` is used on a StateProxy, the resulting state will be
+        linked to the given state via parent_state_proxy. The first state in the
+        chain is the state that initiated the background task.
+
+        Args:
+            state_instance: The state instance to proxy.
+            parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
+        """
+        super().__init__(state_instance)
+        # compile is not relevant to backend logic
+        self._self_app = prerequisites.get_and_validate_app().app
+        self._self_substate_path = tuple(state_instance.get_full_name().split("."))
+        self._self_actx = None
+        self._self_mutable = False
+        self._self_actx_lock = asyncio.Lock()
+        self._self_actx_lock_holder = None
+        self._self_parent_state_proxy = parent_state_proxy
+
+    def _is_mutable(self) -> bool:
+        """Check if the state is mutable.
+
+        Returns:
+            Whether the state is mutable.
+        """
+        if self._self_parent_state_proxy is not None:
+            return self._self_parent_state_proxy._is_mutable() or self._self_mutable
+        return self._self_mutable
+
+    async def __aenter__(self) -> StateProxy:
+        """Enter the async context manager protocol.
+
+        Sets mutability to True and enters the `App.modify_state` async context,
+        which refreshes the state from state_manager and holds the lock for the
+        given state token until exiting the context.
+
+        Background tasks should avoid blocking calls while inside the context.
+
+        Returns:
+            This StateProxy instance in mutable mode.
+
+        Raises:
+            ImmutableStateError: If the state is already mutable.
+        """
+        if self._self_parent_state_proxy is not None:
+            from reflex.state import State
+
+            parent_state = (
+                await self._self_parent_state_proxy.__aenter__()
+            ).__wrapped__
+            super().__setattr__(
+                "__wrapped__",
+                await parent_state.get_state(
+                    State.get_class_substate(self._self_substate_path)
+                ),
+            )
+            return self
+        current_task = asyncio.current_task()
+        if (
+            self._self_actx_lock.locked()
+            and current_task == self._self_actx_lock_holder
+        ):
+            raise ImmutableStateError(
+                "The state is already mutable. Do not nest `async with self` blocks."
+            )
+
+        from reflex.state import _substate_key
+
+        await self._self_actx_lock.acquire()
+        self._self_actx_lock_holder = current_task
+        self._self_actx = self._self_app.modify_state(
+            token=_substate_key(
+                self.__wrapped__.router.session.client_token,
+                self._self_substate_path,
+            )
+        )
+        mutable_state = await self._self_actx.__aenter__()
+        super().__setattr__(
+            "__wrapped__", mutable_state.get_substate(self._self_substate_path)
+        )
+        self._self_mutable = True
+        return self
+
+    async def __aexit__(self, *exc_info: Any) -> None:
+        """Exit the async context manager protocol.
+
+        Sets proxy mutability to False and persists any state changes.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        if self._self_parent_state_proxy is not None:
+            await self._self_parent_state_proxy.__aexit__(*exc_info)
+            return
+        if self._self_actx is None:
+            return
+        self._self_mutable = False
+        try:
+            await self._self_actx.__aexit__(*exc_info)
+        finally:
+            self._self_actx_lock_holder = None
+            self._self_actx_lock.release()
+        self._self_actx = None
+
+    def __enter__(self):
+        """Enter the regular context manager protocol.
+
+        This is not supported for background tasks, and exists only to raise a more useful exception
+        when the StateProxy is used incorrectly.
+
+        Raises:
+            TypeError: always, because only async contextmanager protocol is supported.
+        """
+        raise TypeError("Background task must use `async with self` to modify state.")
+
+    def __exit__(self, *exc_info: Any) -> None:
+        """Exit the regular context manager protocol.
+
+        Args:
+            exc_info: The exception info tuple.
+        """
+        pass
+
+    def __getattr__(self, name: str) -> Any:
+        """Get the attribute from the underlying state instance.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if name in ["substates", "parent_state"] and not self._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+
+        value = super().__getattr__(name)
+        if not name.startswith("_self_") and isinstance(value, MutableProxy):
+            # ensure mutations to these containers are blocked unless proxy is _mutable
+            return ImmutableMutableProxy(
+                wrapped=value.__wrapped__,
+                state=self,
+                field_name=value._self_field_name,
+            )
+        if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
+            # Rebind event handler to the proxy instance
+            value = functools.partial(
+                value.func,
+                self,
+                *value.args[1:],
+                **value.keywords,
+            )
+        if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
+            # Rebind methods to the proxy instance
+            value = type(value)(value.__func__, self)
+        return value
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set the attribute on the underlying state instance.
+
+        If the attribute is internal, set it on the proxy instance instead.
+
+        Args:
+            name: The name of the attribute.
+            value: The value of the attribute.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if (
+            name.startswith("_self_")  # wrapper attribute
+            or self._is_mutable()  # lock held
+            # non-persisted state attribute
+            or name in self.__wrapped__.get_skip_vars()
+        ):
+            super().__setattr__(name, value)
+            return
+
+        raise ImmutableStateError(
+            "Background task StateProxy is immutable outside of a context "
+            "manager. Use `async with self` to modify state."
+        )
+
+    def get_substate(self, path: Sequence[str]) -> BaseState:
+        """Only allow substate access with lock held.
+
+        Args:
+            path: The path to the substate.
+
+        Returns:
+            The substate.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not self._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return self.__wrapped__.get_substate(path)
+
+    async def get_state(self, state_cls: type[BaseState]) -> BaseState:
+        """Get an instance of the state associated with this token.
+
+        Args:
+            state_cls: The class of the state.
+
+        Returns:
+            The state.
+
+        Raises:
+            ImmutableStateError: If the state is not in mutable mode.
+        """
+        if not self._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return type(self)(
+            await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
+        )
+
+    async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+        """Temporarily allow mutability to access parent_state.
+
+        Args:
+            *args: The args to pass to the underlying state instance.
+            **kwargs: The kwargs to pass to the underlying state instance.
+
+        Returns:
+            The state update.
+        """
+        original_mutable = self._self_mutable
+        self._self_mutable = True
+        try:
+            return await self.__wrapped__._as_state_update(*args, **kwargs)
+        finally:
+            self._self_mutable = original_mutable
 
 
 class ReadOnlyStateProxy(StateProxy):
@@ -31,3 +332,426 @@ class ReadOnlyStateProxy(StateProxy):
             NotImplementedError: Always raised when trying to mark the proxied state as dirty.
         """
         raise NotImplementedError("This is a read-only state proxy.")
+
+
+class MutableProxy(wrapt.ObjectProxy):
+    """A proxy for a mutable object that tracks changes."""
+
+    # Hint for finding the base class of the proxy.
+    __base_proxy__ = "MutableProxy"
+
+    # Methods on wrapped objects which should mark the state as dirty.
+    __mark_dirty_attrs__ = {
+        "add",
+        "append",
+        "clear",
+        "difference_update",
+        "discard",
+        "extend",
+        "insert",
+        "intersection_update",
+        "pop",
+        "popitem",
+        "remove",
+        "reverse",
+        "setdefault",
+        "sort",
+        "symmetric_difference_update",
+        "update",
+    }
+
+    # Methods on wrapped objects might return mutable objects that should be tracked.
+    __wrap_mutable_attrs__ = {
+        "get",
+        "setdefault",
+    }
+
+    # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
+    __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
+        pydantic.BaseModel.__dict__
+    )
+
+    # These types will be wrapped in MutableProxy
+    __mutable_types__ = (
+        list,
+        dict,
+        set,
+        Base,
+        DeclarativeBase,
+        BaseModelV2,
+        BaseModelV1,
+    )
+
+    # Dynamically generated classes for tracking dataclass mutations.
+    __dataclass_proxies__: dict[type, type] = {}
+
+    def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
+        """Create a proxy instance for a mutable object that tracks changes.
+
+        Args:
+            wrapped: The object to proxy.
+            *args: Other args passed to MutableProxy (ignored).
+            **kwargs: Other kwargs passed to MutableProxy (ignored).
+
+        Returns:
+            The proxy instance.
+        """
+        if dataclasses.is_dataclass(wrapped):
+            wrapped_cls = type(wrapped)
+            wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
+            # Find the associated class
+            if wrapper_cls_name not in cls.__dataclass_proxies__:
+                # Create a new class that has the __dataclass_fields__ defined
+                cls.__dataclass_proxies__[wrapper_cls_name] = type(
+                    wrapper_cls_name,
+                    (cls,),
+                    {
+                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportAttributeAccessIssue]
+                            wrapped_cls,
+                            dataclasses._FIELDS,  # pyright: ignore [reportAttributeAccessIssue]
+                        ),
+                    },
+                )
+            cls = cls.__dataclass_proxies__[wrapper_cls_name]
+        return super().__new__(cls)
+
+    def __init__(self, wrapped: Any, state: BaseState, field_name: str):
+        """Create a proxy for a mutable object that tracks changes.
+
+        Args:
+            wrapped: The object to proxy.
+            state: The state to mark dirty when the object is changed.
+            field_name: The name of the field on the state associated with the
+                wrapped object.
+        """
+        super().__init__(wrapped)
+        self._self_state = state
+        self._self_field_name = field_name
+
+    def __repr__(self) -> str:
+        """Get the representation of the wrapped object.
+
+        Returns:
+            The representation of the wrapped object.
+        """
+        return f"{type(self).__name__}({self.__wrapped__})"
+
+    def _mark_dirty(
+        self,
+        wrapped: Callable | None = None,
+        instance: BaseState | None = None,
+        args: tuple = (),
+        kwargs: dict | None = None,
+    ) -> Any:
+        """Mark the state as dirty, then call a wrapped function.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function.
+        """
+        self._self_state.dirty_vars.add(self._self_field_name)
+        self._self_state._mark_dirty()
+        if wrapped is not None:
+            return wrapped(*args, **(kwargs or {}))
+
+    @classmethod
+    def _is_mutable_type(cls, value: Any) -> bool:
+        """Check if a value is of a mutable type and should be wrapped.
+
+        Args:
+            value: The value to check.
+
+        Returns:
+            Whether the value is of a mutable type.
+        """
+        return isinstance(value, cls.__mutable_types__) or (
+            dataclasses.is_dataclass(value) and not isinstance(value, Var)
+        )
+
+    @staticmethod
+    def _is_called_from_dataclasses_internal() -> bool:
+        """Check if the current function is called from dataclasses helper.
+
+        Returns:
+            Whether the current function is called from dataclasses internal code.
+        """
+        # Walk up the stack a bit to see if we are called from dataclasses
+        # internal code, for example `asdict` or `astuple`.
+        frame = inspect.currentframe()
+        for _ in range(5):
+            # Why not `inspect.stack()` -- this is much faster!
+            if not (frame := frame and frame.f_back):
+                break
+            if inspect.getfile(frame) == dataclasses.__file__:
+                return True
+        return False
+
+    def _wrap_recursive(self, value: Any) -> Any:
+        """Wrap a value recursively if it is mutable.
+
+        Args:
+            value: The value to wrap.
+
+        Returns:
+            The wrapped value.
+        """
+        # When called from dataclasses internal code, return the unwrapped value
+        if self._is_called_from_dataclasses_internal():
+            return value
+        # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
+        if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
+            base_cls = globals()[self.__base_proxy__]
+            return base_cls(
+                wrapped=value,
+                state=self._self_state,
+                field_name=self._self_field_name,
+            )
+        return value
+
+    def _wrap_recursive_decorator(
+        self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
+    ) -> Any:
+        """Wrap a function that returns a possibly mutable value.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function (possibly wrapped in a MutableProxy).
+        """
+        return self._wrap_recursive(wrapped(*args, **kwargs))
+
+    def __getattr__(self, __name: str) -> Any:
+        """Get the attribute on the proxied object and return a proxy if mutable.
+
+        Args:
+            __name: The name of the attribute.
+
+        Returns:
+            The attribute value.
+        """
+        value = super().__getattr__(__name)
+
+        if callable(value):
+            if __name in self.__mark_dirty_attrs__:
+                # Wrap special callables, like "append", which should mark state dirty.
+                value = wrapt.FunctionWrapper(value, self._mark_dirty)
+
+            if __name in self.__wrap_mutable_attrs__:
+                # Wrap methods that may return mutable objects tied to the state.
+                value = wrapt.FunctionWrapper(
+                    value,
+                    self._wrap_recursive_decorator,
+                )
+
+            if (
+                isinstance(self.__wrapped__, Base)
+                and __name not in self.__never_wrap_base_attrs__
+                and hasattr(value, "__func__")
+            ):
+                # Wrap methods called on Base subclasses, which might do _anything_
+                return wrapt.FunctionWrapper(
+                    functools.partial(value.__func__, self),  # pyright: ignore [reportFunctionMemberAccess]
+                    self._wrap_recursive_decorator,
+                )
+
+        if self._is_mutable_type(value) and __name not in (
+            "__wrapped__",
+            "_self_state",
+            "__dict__",
+        ):
+            # Recursively wrap mutable attribute values retrieved through this proxy.
+            return self._wrap_recursive(value)
+
+        return value
+
+    def __getitem__(self, key: Any) -> Any:
+        """Get the item on the proxied object and return a proxy if mutable.
+
+        Args:
+            key: The key of the item.
+
+        Returns:
+            The item value.
+        """
+        value = super().__getitem__(key)
+        if isinstance(key, slice) and isinstance(value, list):
+            return [self._wrap_recursive(item) for item in value]
+        # Recursively wrap mutable items retrieved through this proxy.
+        return self._wrap_recursive(value)
+
+    def __iter__(self) -> Any:
+        """Iterate over the proxied object and return a proxy if mutable.
+
+        Yields:
+            Each item value (possibly wrapped in MutableProxy).
+        """
+        for value in super().__iter__():
+            # Recursively wrap mutable items retrieved through this proxy.
+            yield self._wrap_recursive(value)
+
+    def __delattr__(self, name: str):
+        """Delete the attribute on the proxied object and mark state dirty.
+
+        Args:
+            name: The name of the attribute.
+        """
+        self._mark_dirty(super().__delattr__, args=(name,))
+
+    def __delitem__(self, key: str):
+        """Delete the item on the proxied object and mark state dirty.
+
+        Args:
+            key: The key of the item.
+        """
+        self._mark_dirty(super().__delitem__, args=(key,))
+
+    def __setitem__(self, key: str, value: Any):
+        """Set the item on the proxied object and mark state dirty.
+
+        Args:
+            key: The key of the item.
+            value: The value of the item.
+        """
+        self._mark_dirty(super().__setitem__, args=(key, value))
+
+    def __setattr__(self, name: str, value: Any):
+        """Set the attribute on the proxied object and mark state dirty.
+
+        If the attribute starts with "_self_", then the state is NOT marked
+        dirty as these are internal proxy attributes.
+
+        Args:
+            name: The name of the attribute.
+            value: The value of the attribute.
+        """
+        if name.startswith("_self_"):
+            # Special case attributes of the proxy itself, not applied to the wrapped object.
+            super().__setattr__(name, value)
+            return
+        self._mark_dirty(super().__setattr__, args=(name, value))
+
+    def __copy__(self) -> Any:
+        """Return a copy of the proxy.
+
+        Returns:
+            A copy of the wrapped object, unconnected to the proxy.
+        """
+        return copy.copy(self.__wrapped__)
+
+    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
+        """Return a deepcopy of the proxy.
+
+        Args:
+            memo: The memo dict to use for the deepcopy.
+
+        Returns:
+            A deepcopy of the wrapped object, unconnected to the proxy.
+        """
+        return copy.deepcopy(self.__wrapped__, memo=memo)
+
+    def __reduce_ex__(self, protocol_version: SupportsIndex):
+        """Get the state for redis serialization.
+
+        This method is called by cloudpickle to serialize the object.
+
+        It explicitly serializes the wrapped object, stripping off the mutable proxy.
+
+        Args:
+            protocol_version: The protocol version.
+
+        Returns:
+            Tuple of (wrapped class, empty args, class __getstate__)
+        """
+        return self.__wrapped__.__reduce_ex__(protocol_version)
+
+
+@serializer
+def serialize_mutable_proxy(mp: MutableProxy):
+    """Return the wrapped value of a MutableProxy.
+
+    Args:
+        mp: The MutableProxy to serialize.
+
+    Returns:
+        The wrapped object.
+    """
+    return mp.__wrapped__
+
+
+_orig_json_encoder_default = json.JSONEncoder.default
+
+
+def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
+    """Wrap JSONEncoder.default to handle MutableProxy objects.
+
+    Args:
+        self: the JSONEncoder instance.
+        o: the object to serialize.
+
+    Returns:
+        A JSON-able object.
+    """
+    try:
+        return o.__wrapped__
+    except AttributeError:
+        pass
+    return _orig_json_encoder_default(self, o)
+
+
+json.JSONEncoder.default = _json_encoder_default_wrapper
+
+
+class ImmutableMutableProxy(MutableProxy):
+    """A proxy for a mutable object that tracks changes.
+
+    This wrapper comes from StateProxy, and will raise an exception if an attempt is made
+    to modify the wrapped object when the StateProxy is immutable.
+    """
+
+    # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
+    __base_proxy__ = "ImmutableMutableProxy"
+
+    def _mark_dirty(
+        self,
+        wrapped: Callable | None = None,
+        instance: BaseState | None = None,
+        args: tuple = (),
+        kwargs: dict | None = None,
+    ) -> Any:
+        """Raise an exception when an attempt is made to modify the object.
+
+        Intended for use with `FunctionWrapper` from the `wrapt` library.
+
+        Args:
+            wrapped: The wrapped function.
+            instance: The instance of the wrapped function.
+            args: The args for the wrapped function.
+            kwargs: The kwargs for the wrapped function.
+
+        Returns:
+            The result of the wrapped function.
+
+        Raises:
+            ImmutableStateError: if the StateProxy is not mutable.
+        """
+        if not self._self_state._is_mutable():
+            raise ImmutableStateError(
+                "Background task StateProxy is immutable outside of a context "
+                "manager. Use `async with self` to modify state."
+            )
+        return super()._mark_dirty(
+            wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
+        )

+ 12 - 6
reflex/reflex.py

@@ -151,8 +151,10 @@ def _run(
     if not frontend and backend:
         _skip_compile()
 
+    prerequisites.assert_in_reflex_dir()
+
     # Check that the app is initialized.
-    if prerequisites.needs_reinit(frontend=frontend):
+    if frontend and prerequisites.needs_reinit():
         _init(name=config.app_name)
 
     # Delete the states folder if it exists.
@@ -403,19 +405,21 @@ def export(
 
     environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.EXPORT)
 
-    frontend_only, backend_only = prerequisites.check_running_mode(
+    should_frontend_run, should_backend_run = prerequisites.check_running_mode(
         frontend_only, backend_only
     )
 
     config = get_config()
 
-    if prerequisites.needs_reinit(frontend=frontend_only or not backend_only):
+    prerequisites.assert_in_reflex_dir()
+
+    if should_frontend_run and prerequisites.needs_reinit():
         _init(name=config.app_name)
 
     export_utils.export(
         zipping=zip,
-        frontend=frontend_only,
-        backend=backend_only,
+        frontend=should_frontend_run,
+        backend=should_backend_run,
         zip_dest_dir=zip_dest_dir,
         upload_db_file=upload_db_file,
         env=constants.Env.DEV if env == constants.Env.DEV else constants.Env.PROD,
@@ -631,8 +635,10 @@ def deploy(
     if interactive:
         dependency.check_requirements()
 
+    prerequisites.assert_in_reflex_dir()
+
     # Check if we are set up.
-    if prerequisites.needs_reinit(frontend=True):
+    if prerequisites.needs_reinit():
         _init(name=config.app_name)
     prerequisites.check_latest_package_version(constants.ReflexHostingCLI.MODULE_NAME)
 

+ 94 - 1624
reflex/state.py

@@ -9,24 +9,19 @@ import copy
 import dataclasses
 import functools
 import inspect
-import json
 import pickle
 import sys
-import time
 import typing
-import uuid
 import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
 from collections.abc import AsyncIterator, Callable, Sequence
 from hashlib import md5
-from pathlib import Path
-from types import FunctionType, MethodType
+from types import FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
     BinaryIO,
     ClassVar,
-    SupportsIndex,
     TypeVar,
     cast,
     get_args,
@@ -34,22 +29,16 @@ from typing import (
 )
 
 import pydantic.v1 as pydantic
-import wrapt
 from pydantic import BaseModel as BaseModelV2
 from pydantic.v1 import BaseModel as BaseModelV1
-from pydantic.v1 import validator
 from pydantic.v1.fields import ModelField
-from redis.asyncio import Redis
-from redis.asyncio.client import PubSub
-from redis.exceptions import ResponseError
 from rich.markup import escape
-from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import Self
 
 import reflex.istate.dynamic
 from reflex import constants, event
 from reflex.base import Base
-from reflex.config import PerformanceMode, environment, get_config
+from reflex.config import PerformanceMode, environment
 from reflex.event import (
     BACKGROUND_TASK_MARKER,
     Event,
@@ -58,19 +47,17 @@ from reflex.event import (
     fix_events,
 )
 from reflex.istate.data import RouterData
+from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy
+from reflex.istate.proxy import MutableProxy, StateProxy
 from reflex.istate.storage import ClientStorageBase
 from reflex.model import Model
-from reflex.utils import console, format, path_ops, prerequisites, types
+from reflex.utils import console, format, prerequisites, types
 from reflex.utils.exceptions import (
     ComputedVarShadowsBaseVarsError,
     ComputedVarShadowsStateVarError,
     DynamicComponentInvalidSignatureError,
     DynamicRouteArgShadowsStateVarError,
     EventHandlerShadowsBuiltInStateMethodError,
-    ImmutableStateError,
-    InvalidLockWarningThresholdError,
-    InvalidStateManagerModeError,
-    LockExpiredError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
     StateMismatchError,
@@ -79,13 +66,12 @@ from reflex.utils.exceptions import (
     StateTooLargeError,
     UnretrievableVarValueError,
 )
+from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError
 from reflex.utils.exec import is_testing_env
-from reflex.utils.serializers import serializer
 from reflex.utils.types import (
     _isinstance,
     get_origin,
     is_union,
-    override,
     true_type_for_pydantic_field,
     value_inside_optional,
 )
@@ -627,6 +613,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         all_base_state_classes[cls.get_full_name()] = None
 
+    @classmethod
+    def _add_event_handler(
+        cls,
+        name: str,
+        fn: Callable,
+    ):
+        """Add an event handler dynamically to the state.
+
+        Args:
+            name: The name of the event handler.
+            fn: The function to call when the event is triggered.
+        """
+        handler = cls._create_event_handler(fn)
+        cls.event_handlers[name] = handler
+        setattr(cls, name, handler)
+
     @staticmethod
     def _copy_fn(fn: Callable) -> Callable:
         """Copy a function. Used to copy ComputedVars and EventHandlers from mixins.
@@ -1011,6 +1013,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Raises:
             VarTypeError: if the variable has an incorrect type
         """
+        from reflex.config import get_config
         from reflex.utils.exceptions import VarTypeError
 
         if not types.is_valid_var_type(prop._var_type):
@@ -1021,7 +1024,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 f'Found var "{prop._js_expr}" with type {prop._var_type}.'
             )
         cls._set_var(prop)
-        cls._create_setter(prop)
+        if get_config().state_auto_setters:
+            cls._create_setter(prop)
         cls._set_default_value(prop)
 
     @classmethod
@@ -2268,6 +2272,35 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
 
 
+def _serialize_type(type_: Any) -> str:
+    """Serialize a type.
+
+    Args:
+        type_: The type to serialize.
+
+    Returns:
+        The serialized type.
+    """
+    if not inspect.isclass(type_):
+        return f"{type_}"
+    return f"{type_.__module__}.{type_.__qualname__}"
+
+
+def is_serializable(value: Any) -> bool:
+    """Check if a value is serializable.
+
+    Args:
+        value: The value to check.
+
+    Returns:
+        Whether the value is serializable.
+    """
+    try:
+        return bool(pickle.dumps(value))
+    except Exception:
+        return False
+
+
 T_STATE = TypeVar("T_STATE", bound=BaseState)
 
 
@@ -2507,278 +2540,6 @@ class ComponentState(State, mixin=True):
         return component
 
 
-class StateProxy(wrapt.ObjectProxy):
-    """Proxy of a state instance to control mutability of vars for a background task.
-
-    Since a background task runs against a state instance without holding the
-    state_manager lock for the token, the reference may become stale if the same
-    state is modified by another event handler.
-
-    The proxy object ensures that writes to the state are blocked unless
-    explicitly entering a context which refreshes the state from state_manager
-    and holds the lock for the token until exiting the context. After exiting
-    the context, a StateUpdate may be emitted to the frontend to notify the
-    client of the state change.
-
-    A background task will be passed the `StateProxy` as `self`, so mutability
-    can be safely performed inside an `async with self` block.
-
-        class State(rx.State):
-            counter: int = 0
-
-            @rx.event(background=True)
-            async def bg_increment(self):
-                await asyncio.sleep(1)
-                async with self:
-                    self.counter += 1
-    """
-
-    def __init__(
-        self,
-        state_instance: BaseState,
-        parent_state_proxy: StateProxy | None = None,
-    ):
-        """Create a proxy for a state instance.
-
-        If `get_state` is used on a StateProxy, the resulting state will be
-        linked to the given state via parent_state_proxy. The first state in the
-        chain is the state that initiated the background task.
-
-        Args:
-            state_instance: The state instance to proxy.
-            parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
-        """
-        super().__init__(state_instance)
-        # compile is not relevant to backend logic
-        self._self_app = prerequisites.get_and_validate_app().app
-        self._self_substate_path = tuple(state_instance.get_full_name().split("."))
-        self._self_actx = None
-        self._self_mutable = False
-        self._self_actx_lock = asyncio.Lock()
-        self._self_actx_lock_holder = None
-        self._self_parent_state_proxy = parent_state_proxy
-
-    def _is_mutable(self) -> bool:
-        """Check if the state is mutable.
-
-        Returns:
-            Whether the state is mutable.
-        """
-        if self._self_parent_state_proxy is not None:
-            return self._self_parent_state_proxy._is_mutable() or self._self_mutable
-        return self._self_mutable
-
-    async def __aenter__(self) -> StateProxy:
-        """Enter the async context manager protocol.
-
-        Sets mutability to True and enters the `App.modify_state` async context,
-        which refreshes the state from state_manager and holds the lock for the
-        given state token until exiting the context.
-
-        Background tasks should avoid blocking calls while inside the context.
-
-        Returns:
-            This StateProxy instance in mutable mode.
-
-        Raises:
-            ImmutableStateError: If the state is already mutable.
-        """
-        if self._self_parent_state_proxy is not None:
-            parent_state = (
-                await self._self_parent_state_proxy.__aenter__()
-            ).__wrapped__
-            super().__setattr__(
-                "__wrapped__",
-                await parent_state.get_state(
-                    State.get_class_substate(self._self_substate_path)
-                ),
-            )
-            return self
-        current_task = asyncio.current_task()
-        if (
-            self._self_actx_lock.locked()
-            and current_task == self._self_actx_lock_holder
-        ):
-            raise ImmutableStateError(
-                "The state is already mutable. Do not nest `async with self` blocks."
-            )
-        await self._self_actx_lock.acquire()
-        self._self_actx_lock_holder = current_task
-        self._self_actx = self._self_app.modify_state(
-            token=_substate_key(
-                self.__wrapped__.router.session.client_token,
-                self._self_substate_path,
-            )
-        )
-        mutable_state = await self._self_actx.__aenter__()
-        super().__setattr__(
-            "__wrapped__", mutable_state.get_substate(self._self_substate_path)
-        )
-        self._self_mutable = True
-        return self
-
-    async def __aexit__(self, *exc_info: Any) -> None:
-        """Exit the async context manager protocol.
-
-        Sets proxy mutability to False and persists any state changes.
-
-        Args:
-            exc_info: The exception info tuple.
-        """
-        if self._self_parent_state_proxy is not None:
-            await self._self_parent_state_proxy.__aexit__(*exc_info)
-            return
-        if self._self_actx is None:
-            return
-        self._self_mutable = False
-        try:
-            await self._self_actx.__aexit__(*exc_info)
-        finally:
-            self._self_actx_lock_holder = None
-            self._self_actx_lock.release()
-        self._self_actx = None
-
-    def __enter__(self):
-        """Enter the regular context manager protocol.
-
-        This is not supported for background tasks, and exists only to raise a more useful exception
-        when the StateProxy is used incorrectly.
-
-        Raises:
-            TypeError: always, because only async contextmanager protocol is supported.
-        """
-        raise TypeError("Background task must use `async with self` to modify state.")
-
-    def __exit__(self, *exc_info: Any) -> None:
-        """Exit the regular context manager protocol.
-
-        Args:
-            exc_info: The exception info tuple.
-        """
-        pass
-
-    def __getattr__(self, name: str) -> Any:
-        """Get the attribute from the underlying state instance.
-
-        Args:
-            name: The name of the attribute.
-
-        Returns:
-            The value of the attribute.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if name in ["substates", "parent_state"] and not self._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        value = super().__getattr__(name)
-        if not name.startswith("_self_") and isinstance(value, MutableProxy):
-            # ensure mutations to these containers are blocked unless proxy is _mutable
-            return ImmutableMutableProxy(
-                wrapped=value.__wrapped__,
-                state=self,
-                field_name=value._self_field_name,
-            )
-        if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
-            # Rebind event handler to the proxy instance
-            value = functools.partial(
-                value.func,
-                self,
-                *value.args[1:],
-                **value.keywords,
-            )
-        if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
-            # Rebind methods to the proxy instance
-            value = type(value)(value.__func__, self)
-        return value
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        """Set the attribute on the underlying state instance.
-
-        If the attribute is internal, set it on the proxy instance instead.
-
-        Args:
-            name: The name of the attribute.
-            value: The value of the attribute.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if (
-            name.startswith("_self_")  # wrapper attribute
-            or self._is_mutable()  # lock held
-            # non-persisted state attribute
-            or name in self.__wrapped__.get_skip_vars()
-        ):
-            super().__setattr__(name, value)
-            return
-
-        raise ImmutableStateError(
-            "Background task StateProxy is immutable outside of a context "
-            "manager. Use `async with self` to modify state."
-        )
-
-    def get_substate(self, path: Sequence[str]) -> BaseState:
-        """Only allow substate access with lock held.
-
-        Args:
-            path: The path to the substate.
-
-        Returns:
-            The substate.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if not self._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        return self.__wrapped__.get_substate(path)
-
-    async def get_state(self, state_cls: type[BaseState]) -> BaseState:
-        """Get an instance of the state associated with this token.
-
-        Args:
-            state_cls: The class of the state.
-
-        Returns:
-            The state.
-
-        Raises:
-            ImmutableStateError: If the state is not in mutable mode.
-        """
-        if not self._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        return type(self)(
-            await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
-        )
-
-    async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
-        """Temporarily allow mutability to access parent_state.
-
-        Args:
-            *args: The args to pass to the underlying state instance.
-            **kwargs: The kwargs to pass to the underlying state instance.
-
-        Returns:
-            The state update.
-        """
-        original_mutable = self._self_mutable
-        self._self_mutable = True
-        try:
-            return await self.__wrapped__._as_state_update(*args, **kwargs)
-        finally:
-            self._self_mutable = original_mutable
-
-
 @dataclasses.dataclass(
     frozen=True,
 )
@@ -2803,1345 +2564,54 @@ class StateUpdate:
         return format.json_dumps(self)
 
 
-class StateManager(Base, ABC):
-    """A class to manage many client states."""
-
-    # The state class to use.
-    state: type[BaseState]
-
-    @classmethod
-    def create(cls, state: type[BaseState]):
-        """Create a new state manager.
-
-        Args:
-            state: The state class to use.
-
-        Raises:
-            InvalidStateManagerModeError: If the state manager mode is invalid.
-
-        Returns:
-            The state manager (either disk, memory or redis).
-        """
-        config = get_config()
-        if prerequisites.parse_redis_url() is not None:
-            config.state_manager_mode = constants.StateManagerMode.REDIS
-        if config.state_manager_mode == constants.StateManagerMode.MEMORY:
-            return StateManagerMemory(state=state)
-        if config.state_manager_mode == constants.StateManagerMode.DISK:
-            return StateManagerDisk(state=state)
-        if config.state_manager_mode == constants.StateManagerMode.REDIS:
-            redis = prerequisites.get_redis()
-            if redis is not None:
-                # make sure expiration values are obtained only from the config object on creation
-                return StateManagerRedis(
-                    state=state,
-                    redis=redis,
-                    token_expiration=config.redis_token_expiration,
-                    lock_expiration=config.redis_lock_expiration,
-                    lock_warning_threshold=config.redis_lock_warning_threshold,
-                )
-        raise InvalidStateManagerModeError(
-            f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
-        )
-
-    @abstractmethod
-    async def get_state(self, token: str) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-
-        Returns:
-            The state for the token.
-        """
-        pass
-
-    @abstractmethod
-    async def set_state(self, token: str, state: BaseState):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-        """
-        pass
-
-    @abstractmethod
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        yield self.state()
-
-
-class StateManagerMemory(StateManager):
-    """A state manager that stores states in memory."""
-
-    # The mapping of client ids to states.
-    states: dict[str, BaseState] = {}
-
-    # The mutex ensures the dict of mutexes is updated exclusively
-    _state_manager_lock = asyncio.Lock()
-
-    # The dict of mutexes for each client
-    _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
-
-    class Config:  # pyright: ignore [reportIncompatibleVariableOverride]
-        """The Pydantic config."""
-
-        fields = {
-            "_states_locks": {"exclude": True},
-        }
-
-    @override
-    async def get_state(self, token: str) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-
-        Returns:
-            The state for the token.
-        """
-        # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = _split_substate_key(token)[0]
-        if token not in self.states:
-            self.states[token] = self.state(_reflex_internal_init=True)
-        return self.states[token]
-
-    @override
-    async def set_state(self, token: str, state: BaseState):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-        """
-        pass
-
-    @override
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        # Memory state manager ignores the substate suffix and always returns the top-level state.
-        token = _split_substate_key(token)[0]
-        if token not in self._states_locks:
-            async with self._state_manager_lock:
-                if token not in self._states_locks:
-                    self._states_locks[token] = asyncio.Lock()
-
-        async with self._states_locks[token]:
-            state = await self.get_state(token)
-            yield state
-            await self.set_state(token, state)
-
+def code_uses_state_contexts(javascript_code: str) -> bool:
+    """Check if the rendered Javascript uses state contexts.
 
-def _default_token_expiration() -> int:
-    """Get the default token expiration time.
+    Args:
+        javascript_code: The Javascript code to check.
 
     Returns:
-        The default token expiration time.
+        True if the code attempts to access a member of StateContexts.
     """
-    return get_config().redis_token_expiration
+    return bool("useContext(StateContexts" in javascript_code)
 
 
-def _serialize_type(type_: Any) -> str:
-    """Serialize a type.
+def reload_state_module(
+    module: str,
+    state: type[BaseState] = State,
+) -> None:
+    """Reset rx.State subclasses to avoid conflict when reloading.
 
     Args:
-        type_: The type to serialize.
+        module: The module to reload.
+        state: Recursive argument for the state class to reload.
 
-    Returns:
-        The serialized type.
     """
-    if not inspect.isclass(type_):
-        return f"{type_}"
-    return f"{type_.__module__}.{type_.__qualname__}"
-
-
-def is_serializable(value: Any) -> bool:
-    """Check if a value is serializable.
+    # Clean out all potentially dirty states of reloaded modules.
+    for pd_state in tuple(state._potentially_dirty_states):
+        with contextlib.suppress(ValueError):
+            if (
+                state.get_root_state().get_class_substate(pd_state).__module__ == module
+                and module is not None
+            ):
+                state._potentially_dirty_states.remove(pd_state)
+    for subclass in tuple(state.class_subclasses):
+        reload_state_module(module=module, state=subclass)
+        if subclass.__module__ == module and module is not None:
+            all_base_state_classes.pop(subclass.get_full_name(), None)
+            state.class_subclasses.remove(subclass)
+            state._always_dirty_substates.discard(subclass.get_name())
+            state._var_dependencies = {}
+            state._init_var_dependency_dicts()
+    state.get_class_substate.cache_clear()
 
-    Args:
-        value: The value to check.
 
-    Returns:
-        Whether the value is serializable.
-    """
-    try:
-        return bool(pickle.dumps(value))
-    except Exception:
-        return False
-
-
-def reset_disk_state_manager():
-    """Reset the disk state manager."""
-    states_directory = prerequisites.get_states_dir()
-    if states_directory.exists():
-        for path in states_directory.iterdir():
-            path.unlink()
-
-
-class StateManagerDisk(StateManager):
-    """A state manager that stores states in memory."""
-
-    # The mapping of client ids to states.
-    states: dict[str, BaseState] = {}
-
-    # The mutex ensures the dict of mutexes is updated exclusively
-    _state_manager_lock = asyncio.Lock()
-
-    # The dict of mutexes for each client
-    _states_locks: dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
-
-    # The token expiration time (s).
-    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
-
-    class Config:  # pyright: ignore [reportIncompatibleVariableOverride]
-        """The Pydantic config."""
-
-        fields = {
-            "_states_locks": {"exclude": True},
-        }
-        keep_untouched = (functools.cached_property,)
-
-    def __init__(self, state: type[BaseState]):
-        """Create a new state manager.
-
-        Args:
-            state: The state class to use.
-        """
-        super().__init__(state=state)
-
-        path_ops.mkdir(self.states_directory)
-
-        self._purge_expired_states()
-
-    @functools.cached_property
-    def states_directory(self) -> Path:
-        """Get the states directory.
-
-        Returns:
-            The states directory.
-        """
-        return prerequisites.get_states_dir()
-
-    def _purge_expired_states(self):
-        """Purge expired states from the disk."""
-        import time
-
-        for path in path_ops.ls(self.states_directory):
-            # check path is a pickle file
-            if path.suffix != ".pkl":
-                continue
-
-            # load last edited field from file
-            last_edited = path.stat().st_mtime
-
-            # check if the file is older than the token expiration time
-            if time.time() - last_edited > self.token_expiration:
-                # remove the file
-                path.unlink()
-
-    def token_path(self, token: str) -> Path:
-        """Get the path for a token.
-
-        Args:
-            token: The token to get the path for.
-
-        Returns:
-            The path for the token.
-        """
-        return (
-            self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
-        ).absolute()
-
-    async def load_state(self, token: str) -> BaseState | None:
-        """Load a state object based on the provided token.
-
-        Args:
-            token: The token used to identify the state object.
-
-        Returns:
-            The loaded state object or None.
-        """
-        token_path = self.token_path(token)
-
-        if token_path.exists():
-            try:
-                with token_path.open(mode="rb") as file:
-                    return BaseState._deserialize(fp=file)
-            except Exception:
-                pass
-
-    async def populate_substates(
-        self, client_token: str, state: BaseState, root_state: BaseState
-    ):
-        """Populate the substates of a state object.
-
-        Args:
-            client_token: The client token.
-            state: The state object to populate.
-            root_state: The root state object.
-        """
-        for substate in state.get_substates():
-            substate_token = _substate_key(client_token, substate)
-
-            fresh_instance = await root_state.get_state(substate)
-            instance = await self.load_state(substate_token)
-            if instance is not None:
-                # Ensure all substates exist, even if they weren't serialized previously.
-                instance.substates = fresh_instance.substates
-            else:
-                instance = fresh_instance
-            state.substates[substate.get_name()] = instance
-            instance.parent_state = state
-
-            await self.populate_substates(client_token, instance, root_state)
-
-    @override
-    async def get_state(
-        self,
-        token: str,
-    ) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-
-        Returns:
-            The state for the token.
-        """
-        client_token = _split_substate_key(token)[0]
-        root_state = self.states.get(client_token)
-        if root_state is not None:
-            # Retrieved state from memory.
-            return root_state
-
-        # Deserialize root state from disk.
-        root_state = await self.load_state(_substate_key(client_token, self.state))
-        # Create a new root state tree with all substates instantiated.
-        fresh_root_state = self.state(_reflex_internal_init=True)
-        if root_state is None:
-            root_state = fresh_root_state
-        else:
-            # Ensure all substates exist, even if they were not serialized previously.
-            root_state.substates = fresh_root_state.substates
-        self.states[client_token] = root_state
-        await self.populate_substates(client_token, root_state, root_state)
-        return root_state
-
-    async def set_state_for_substate(self, client_token: str, substate: BaseState):
-        """Set the state for a substate.
-
-        Args:
-            client_token: The client token.
-            substate: The substate to set.
-        """
-        substate_token = _substate_key(client_token, substate)
-
-        if substate._get_was_touched():
-            substate._was_touched = False  # Reset the touched flag after serializing.
-            pickle_state = substate._serialize()
-            if pickle_state:
-                if not self.states_directory.exists():
-                    self.states_directory.mkdir(parents=True, exist_ok=True)
-                self.token_path(substate_token).write_bytes(pickle_state)
-
-        for substate_substate in substate.substates.values():
-            await self.set_state_for_substate(client_token, substate_substate)
-
-    @override
-    async def set_state(self, token: str, state: BaseState):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-        """
-        client_token, substate = _split_substate_key(token)
-        await self.set_state_for_substate(client_token, state)
-
-    @override
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        # Memory state manager ignores the substate suffix and always returns the top-level state.
-        client_token, substate = _split_substate_key(token)
-        if client_token not in self._states_locks:
-            async with self._state_manager_lock:
-                if client_token not in self._states_locks:
-                    self._states_locks[client_token] = asyncio.Lock()
-
-        async with self._states_locks[client_token]:
-            state = await self.get_state(token)
-            yield state
-            await self.set_state(token, state)
-
-
-def _default_lock_expiration() -> int:
-    """Get the default lock expiration time.
-
-    Returns:
-        The default lock expiration time.
-    """
-    return get_config().redis_lock_expiration
-
-
-def _default_lock_warning_threshold() -> int:
-    """Get the default lock warning threshold.
-
-    Returns:
-        The default lock warning threshold.
-    """
-    return get_config().redis_lock_warning_threshold
-
-
-class StateManagerRedis(StateManager):
-    """A state manager that stores states in redis."""
-
-    # The redis client to use.
-    redis: Redis
-
-    # The token expiration time (s).
-    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)
-
-    # The maximum time to hold a lock (ms).
-    lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
-
-    # The maximum time to hold a lock (ms) before warning.
-    lock_warning_threshold: int = pydantic.Field(
-        default_factory=_default_lock_warning_threshold
-    )
-
-    # The keyspace subscription string when redis is waiting for lock to be released.
-    _redis_notify_keyspace_events: str = (
-        "K"  # Enable keyspace notifications (target a particular key)
-        "g"  # For generic commands (DEL, EXPIRE, etc)
-        "x"  # For expired events
-        "e"  # For evicted events (i.e. maxmemory exceeded)
-    )
-
-    # These events indicate that a lock is no longer held.
-    _redis_keyspace_lock_release_events: set[bytes] = {
-        b"del",
-        b"expire",
-        b"expired",
-        b"evicted",
-    }
-
-    # Whether keyspace notifications have been enabled.
-    _redis_notify_keyspace_events_enabled: bool = False
-
-    # The logical database number used by the redis client.
-    _redis_db: int = 0
-
-    def _get_required_state_classes(
-        self,
-        target_state_cls: type[BaseState],
-        subclasses: bool = False,
-        required_state_classes: set[type[BaseState]] | None = None,
-    ) -> set[type[BaseState]]:
-        """Recursively determine which states are required to fetch the target state.
-
-        This will always include potentially dirty substates that depend on vars
-        in the target_state_cls.
-
-        Args:
-            target_state_cls: The target state class being fetched.
-            subclasses: Whether to include subclasses of the target state.
-            required_state_classes: Recursive argument tracking state classes that have already been seen.
-
-        Returns:
-            The set of state classes required to fetch the target state.
-        """
-        if required_state_classes is None:
-            required_state_classes = set()
-        # Get the substates if requested.
-        if subclasses:
-            for substate in target_state_cls.get_substates():
-                self._get_required_state_classes(
-                    substate,
-                    subclasses=True,
-                    required_state_classes=required_state_classes,
-                )
-        if target_state_cls in required_state_classes:
-            return required_state_classes
-        required_state_classes.add(target_state_cls)
-
-        # Get dependent substates.
-        for pd_substates in target_state_cls._get_potentially_dirty_states():
-            self._get_required_state_classes(
-                pd_substates,
-                subclasses=False,
-                required_state_classes=required_state_classes,
-            )
-
-        # Get the parent state if it exists.
-        if parent_state := target_state_cls.get_parent_state():
-            self._get_required_state_classes(
-                parent_state,
-                subclasses=False,
-                required_state_classes=required_state_classes,
-            )
-        return required_state_classes
-
-    def _get_populated_states(
-        self,
-        target_state: BaseState,
-        populated_states: dict[str, BaseState] | None = None,
-    ) -> dict[str, BaseState]:
-        """Recursively determine which states from target_state are already fetched.
-
-        Args:
-            target_state: The state to check for populated states.
-            populated_states: Recursive argument tracking states seen in previous calls.
-
-        Returns:
-            A dictionary of state full name to state instance.
-        """
-        if populated_states is None:
-            populated_states = {}
-        if target_state.get_full_name() in populated_states:
-            return populated_states
-        populated_states[target_state.get_full_name()] = target_state
-        for substate in target_state.substates.values():
-            self._get_populated_states(substate, populated_states=populated_states)
-        if target_state.parent_state is not None:
-            self._get_populated_states(
-                target_state.parent_state, populated_states=populated_states
-            )
-        return populated_states
-
-    @override
-    async def get_state(
-        self,
-        token: str,
-        top_level: bool = True,
-        for_state_instance: BaseState | None = None,
-    ) -> BaseState:
-        """Get the state for a token.
-
-        Args:
-            token: The token to get the state for.
-            top_level: If true, return an instance of the top-level state (self.state).
-            for_state_instance: If provided, attach the requested states to this existing state tree.
-
-        Returns:
-            The state for the token.
-
-        Raises:
-            RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
-                requested state was not fetched.
-        """
-        # Split the actual token from the fully qualified substate name.
-        token, state_path = _split_substate_key(token)
-        if state_path:
-            # Get the State class associated with the given path.
-            state_cls = self.state.get_class_substate(state_path)
-        else:
-            raise RuntimeError(
-                f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
-            )
-
-        # Determine which states we already have.
-        flat_state_tree: dict[str, BaseState] = (
-            self._get_populated_states(for_state_instance) if for_state_instance else {}
-        )
-
-        # Determine which states from the tree need to be fetched.
-        required_state_classes = sorted(
-            self._get_required_state_classes(state_cls, subclasses=True)
-            - {type(s) for s in flat_state_tree.values()},
-            key=lambda x: x.get_full_name(),
-        )
-
-        redis_pipeline = self.redis.pipeline()
-        for state_cls in required_state_classes:
-            redis_pipeline.get(_substate_key(token, state_cls))
-
-        for state_cls, redis_state in zip(
-            required_state_classes,
-            await redis_pipeline.execute(),
-            strict=False,
-        ):
-            state = None
-
-            if redis_state is not None:
-                # Deserialize the substate.
-                with contextlib.suppress(StateSchemaMismatchError):
-                    state = BaseState._deserialize(data=redis_state)
-            if state is None:
-                # Key didn't exist or schema mismatch so create a new instance for this token.
-                state = state_cls(
-                    init_substates=False,
-                    _reflex_internal_init=True,
-                )
-            flat_state_tree[state.get_full_name()] = state
-            if state.get_parent_state() is not None:
-                parent_state_name, _dot, state_name = state.get_full_name().rpartition(
-                    "."
-                )
-                parent_state = flat_state_tree.get(parent_state_name)
-                if parent_state is None:
-                    raise RuntimeError(
-                        f"Parent state for {state.get_full_name()} was not found "
-                        "in the state tree, but should have already been fetched. "
-                        "This is a bug",
-                    )
-                parent_state.substates[state_name] = state
-                state.parent_state = parent_state
-
-        # To retain compatibility with previous implementation, by default, we return
-        # the top-level state which should always be fetched or already cached.
-        if top_level:
-            return flat_state_tree[self.state.get_full_name()]
-        return flat_state_tree[state_cls.get_full_name()]
-
-    @override
-    async def set_state(
-        self,
-        token: str,
-        state: BaseState,
-        lock_id: bytes | None = None,
-    ):
-        """Set the state for a token.
-
-        Args:
-            token: The token to set the state for.
-            state: The state to set.
-            lock_id: If provided, the lock_key must be set to this value to set the state.
-
-        Raises:
-            LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
-            RuntimeError: If the state instance doesn't match the state name in the token.
-        """
-        # Check that we're holding the lock.
-        if (
-            lock_id is not None
-            and await self.redis.get(self._lock_key(token)) != lock_id
-        ):
-            raise LockExpiredError(
-                f"Lock expired for token {token} while processing. Consider increasing "
-                f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
-                "or use `@rx.event(background=True)` decorator for long-running tasks."
-            )
-        elif lock_id is not None:
-            time_taken = self.lock_expiration / 1000 - (
-                await self.redis.ttl(self._lock_key(token))
-            )
-            if time_taken > self.lock_warning_threshold / 1000:
-                console.warn(
-                    f"Lock for token {token} was held too long {time_taken=}s, "
-                    f"use `@rx.event(background=True)` decorator for long-running tasks.",
-                    dedupe=True,
-                )
-
-        client_token, substate_name = _split_substate_key(token)
-        # If the substate name on the token doesn't match the instance name, it cannot have a parent.
-        if state.parent_state is not None and state.get_full_name() != substate_name:
-            raise RuntimeError(
-                f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
-            )
-
-        # Recursively set_state on all known substates.
-        tasks = [
-            asyncio.create_task(
-                self.set_state(
-                    _substate_key(client_token, substate),
-                    substate,
-                    lock_id,
-                )
-            )
-            for substate in state.substates.values()
-        ]
-        # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
-        if state._get_was_touched():
-            pickle_state = state._serialize()
-            if pickle_state:
-                await self.redis.set(
-                    _substate_key(client_token, state),
-                    pickle_state,
-                    ex=self.token_expiration,
-                )
-
-        # Wait for substates to be persisted.
-        for t in tasks:
-            await t
-
-    @override
-    @contextlib.asynccontextmanager
-    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
-        """Modify the state for a token while holding exclusive lock.
-
-        Args:
-            token: The token to modify the state for.
-
-        Yields:
-            The state for the token.
-        """
-        async with self._lock(token) as lock_id:
-            state = await self.get_state(token)
-            yield state
-            await self.set_state(token, state, lock_id)
-
-    @validator("lock_warning_threshold")
-    @classmethod
-    def validate_lock_warning_threshold(
-        cls, lock_warning_threshold: int, values: dict[str, int]
-    ):
-        """Validate the lock warning threshold.
-
-        Args:
-            lock_warning_threshold: The lock warning threshold.
-            values: The validated attributes.
-
-        Returns:
-            The lock warning threshold.
-
-        Raises:
-            InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
-        """
-        if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
-            raise InvalidLockWarningThresholdError(
-                f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
-            )
-        return lock_warning_threshold
-
-    @staticmethod
-    def _lock_key(token: str) -> bytes:
-        """Get the redis key for a token's lock.
-
-        Args:
-            token: The token to get the lock key for.
-
-        Returns:
-            The redis lock key for the token.
-        """
-        # All substates share the same lock domain, so ignore any substate path suffix.
-        client_token = _split_substate_key(token)[0]
-        return f"{client_token}_lock".encode()
-
-    async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
-        """Try to get a redis lock for a token.
-
-        Args:
-            lock_key: The redis key for the lock.
-            lock_id: The ID of the lock.
-
-        Returns:
-            True if the lock was obtained.
-        """
-        return await self.redis.set(
-            lock_key,
-            lock_id,
-            px=self.lock_expiration,
-            nx=True,  # only set if it doesn't exist
-        )
-
-    async def _get_pubsub_message(
-        self, pubsub: PubSub, timeout: float | None = None
-    ) -> None:
-        """Get lock release events from the pubsub.
-
-        Args:
-            pubsub: The pubsub to get a message from.
-            timeout: Remaining time to wait for a message.
-
-        Returns:
-            The message.
-        """
-        if timeout is None:
-            timeout = self.lock_expiration / 1000.0
-
-        started = time.time()
-        message = await pubsub.get_message(
-            ignore_subscribe_messages=True,
-            timeout=timeout,
-        )
-        if (
-            message is None
-            or message["data"] not in self._redis_keyspace_lock_release_events
-        ):
-            remaining = timeout - (time.time() - started)
-            if remaining <= 0:
-                return
-            await self._get_pubsub_message(pubsub, timeout=remaining)
-
-    async def _enable_keyspace_notifications(self):
-        """Enable keyspace notifications for the redis server.
-
-        Raises:
-            ResponseError: when the keyspace config cannot be set.
-        """
-        if self._redis_notify_keyspace_events_enabled:
-            return
-        # Find out which logical database index is being used.
-        self._redis_db = self.redis.get_connection_kwargs().get("db", self._redis_db)
-
-        try:
-            await self.redis.config_set(
-                "notify-keyspace-events",
-                self._redis_notify_keyspace_events,
-            )
-        except ResponseError:
-            # Some redis servers only allow out-of-band configuration, so ignore errors here.
-            if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
-                raise
-        self._redis_notify_keyspace_events_enabled = True
-
-    async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
-        """Wait for a redis lock to be released via pubsub.
-
-        Coroutine will not return until the lock is obtained.
-
-        Args:
-            lock_key: The redis key for the lock.
-            lock_id: The ID of the lock.
-        """
-        # Enable keyspace notifications for the lock key, so we know when it is available.
-        await self._enable_keyspace_notifications()
-        lock_key_channel = f"__keyspace@{self._redis_db}__:{lock_key.decode()}"
-        async with self.redis.pubsub() as pubsub:
-            await pubsub.psubscribe(lock_key_channel)
-            # wait for the lock to be released
-            while True:
-                # fast path
-                if await self._try_get_lock(lock_key, lock_id):
-                    return
-                # wait for lock events
-                await self._get_pubsub_message(pubsub)
-
-    @contextlib.asynccontextmanager
-    async def _lock(self, token: str):
-        """Obtain a redis lock for a token.
-
-        Args:
-            token: The token to obtain a lock for.
-
-        Yields:
-            The ID of the lock (to be passed to set_state).
-
-        Raises:
-            LockExpiredError: If the lock has expired while processing the event.
-        """
-        lock_key = self._lock_key(token)
-        lock_id = uuid.uuid4().hex.encode()
-
-        if not await self._try_get_lock(lock_key, lock_id):
-            # Missed the fast-path to get lock, subscribe for lock delete/expire events
-            await self._wait_lock(lock_key, lock_id)
-        state_is_locked = True
-
-        try:
-            yield lock_id
-        except LockExpiredError:
-            state_is_locked = False
-            raise
-        finally:
-            if state_is_locked:
-                # only delete our lock
-                await self.redis.delete(lock_key)
-
-    async def close(self):
-        """Explicitly close the redis connection and connection_pool.
-
-        It is necessary in testing scenarios to close between asyncio test cases
-        to avoid having lingering redis connections associated with event loops
-        that will be closed (each test case uses its own event loop).
-
-        Note: Connections will be automatically reopened when needed.
-        """
-        await self.redis.aclose(close_connection_pool=True)
-
-
-def get_state_manager() -> StateManager:
-    """Get the state manager for the app that is currently running.
-
-    Returns:
-        The state manager.
-    """
-    return prerequisites.get_and_validate_app().app.state_manager
-
-
-class MutableProxy(wrapt.ObjectProxy):
-    """A proxy for a mutable object that tracks changes."""
-
-    # Hint for finding the base class of the proxy.
-    __base_proxy__ = "MutableProxy"
-
-    # Methods on wrapped objects which should mark the state as dirty.
-    __mark_dirty_attrs__ = {
-        "add",
-        "append",
-        "clear",
-        "difference_update",
-        "discard",
-        "extend",
-        "insert",
-        "intersection_update",
-        "pop",
-        "popitem",
-        "remove",
-        "reverse",
-        "setdefault",
-        "sort",
-        "symmetric_difference_update",
-        "update",
-    }
-
-    # Methods on wrapped objects might return mutable objects that should be tracked.
-    __wrap_mutable_attrs__ = {
-        "get",
-        "setdefault",
-    }
-
-    # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
-    __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
-        pydantic.BaseModel.__dict__
-    )
-
-    # These types will be wrapped in MutableProxy
-    __mutable_types__ = (
-        list,
-        dict,
-        set,
-        Base,
-        DeclarativeBase,
-        BaseModelV2,
-        BaseModelV1,
-    )
-
-    # Dynamically generated classes for tracking dataclass mutations.
-    __dataclass_proxies__: dict[type, type] = {}
-
-    def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
-        """Create a proxy instance for a mutable object that tracks changes.
-
-        Args:
-            wrapped: The object to proxy.
-            *args: Other args passed to MutableProxy (ignored).
-            **kwargs: Other kwargs passed to MutableProxy (ignored).
-
-        Returns:
-            The proxy instance.
-        """
-        if dataclasses.is_dataclass(wrapped):
-            wrapped_cls = type(wrapped)
-            wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
-            # Find the associated class
-            if wrapper_cls_name not in cls.__dataclass_proxies__:
-                # Create a new class that has the __dataclass_fields__ defined
-                cls.__dataclass_proxies__[wrapper_cls_name] = type(
-                    wrapper_cls_name,
-                    (cls,),
-                    {
-                        dataclasses._FIELDS: getattr(  # pyright: ignore [reportAttributeAccessIssue]
-                            wrapped_cls,
-                            dataclasses._FIELDS,  # pyright: ignore [reportAttributeAccessIssue]
-                        ),
-                    },
-                )
-            cls = cls.__dataclass_proxies__[wrapper_cls_name]
-        return super().__new__(cls)
-
-    def __init__(self, wrapped: Any, state: BaseState, field_name: str):
-        """Create a proxy for a mutable object that tracks changes.
-
-        Args:
-            wrapped: The object to proxy.
-            state: The state to mark dirty when the object is changed.
-            field_name: The name of the field on the state associated with the
-                wrapped object.
-        """
-        super().__init__(wrapped)
-        self._self_state = state
-        self._self_field_name = field_name
-
-    def __repr__(self) -> str:
-        """Get the representation of the wrapped object.
-
-        Returns:
-            The representation of the wrapped object.
-        """
-        return f"{type(self).__name__}({self.__wrapped__})"
-
-    def _mark_dirty(
-        self,
-        wrapped: Callable | None = None,
-        instance: BaseState | None = None,
-        args: tuple = (),
-        kwargs: dict | None = None,
-    ) -> Any:
-        """Mark the state as dirty, then call a wrapped function.
-
-        Intended for use with `FunctionWrapper` from the `wrapt` library.
-
-        Args:
-            wrapped: The wrapped function.
-            instance: The instance of the wrapped function.
-            args: The args for the wrapped function.
-            kwargs: The kwargs for the wrapped function.
-
-        Returns:
-            The result of the wrapped function.
-        """
-        self._self_state.dirty_vars.add(self._self_field_name)
-        self._self_state._mark_dirty()
-        if wrapped is not None:
-            return wrapped(*args, **(kwargs or {}))
-
-    @classmethod
-    def _is_mutable_type(cls, value: Any) -> bool:
-        """Check if a value is of a mutable type and should be wrapped.
-
-        Args:
-            value: The value to check.
-
-        Returns:
-            Whether the value is of a mutable type.
-        """
-        return isinstance(value, cls.__mutable_types__) or (
-            dataclasses.is_dataclass(value) and not isinstance(value, Var)
-        )
-
-    @staticmethod
-    def _is_called_from_dataclasses_internal() -> bool:
-        """Check if the current function is called from dataclasses helper.
-
-        Returns:
-            Whether the current function is called from dataclasses internal code.
-        """
-        # Walk up the stack a bit to see if we are called from dataclasses
-        # internal code, for example `asdict` or `astuple`.
-        frame = inspect.currentframe()
-        for _ in range(5):
-            # Why not `inspect.stack()` -- this is much faster!
-            if not (frame := frame and frame.f_back):
-                break
-            if inspect.getfile(frame) == dataclasses.__file__:
-                return True
-        return False
-
-    def _wrap_recursive(self, value: Any) -> Any:
-        """Wrap a value recursively if it is mutable.
-
-        Args:
-            value: The value to wrap.
-
-        Returns:
-            The wrapped value.
-        """
-        # When called from dataclasses internal code, return the unwrapped value
-        if self._is_called_from_dataclasses_internal():
-            return value
-        # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
-        if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
-            base_cls = globals()[self.__base_proxy__]
-            return base_cls(
-                wrapped=value,
-                state=self._self_state,
-                field_name=self._self_field_name,
-            )
-        return value
-
-    def _wrap_recursive_decorator(
-        self, wrapped: Callable, instance: BaseState, args: list, kwargs: dict
-    ) -> Any:
-        """Wrap a function that returns a possibly mutable value.
-
-        Intended for use with `FunctionWrapper` from the `wrapt` library.
-
-        Args:
-            wrapped: The wrapped function.
-            instance: The instance of the wrapped function.
-            args: The args for the wrapped function.
-            kwargs: The kwargs for the wrapped function.
-
-        Returns:
-            The result of the wrapped function (possibly wrapped in a MutableProxy).
-        """
-        return self._wrap_recursive(wrapped(*args, **kwargs))
-
-    def __getattr__(self, __name: str) -> Any:
-        """Get the attribute on the proxied object and return a proxy if mutable.
-
-        Args:
-            __name: The name of the attribute.
-
-        Returns:
-            The attribute value.
-        """
-        value = super().__getattr__(__name)
-
-        if callable(value):
-            if __name in self.__mark_dirty_attrs__:
-                # Wrap special callables, like "append", which should mark state dirty.
-                value = wrapt.FunctionWrapper(value, self._mark_dirty)
-
-            if __name in self.__wrap_mutable_attrs__:
-                # Wrap methods that may return mutable objects tied to the state.
-                value = wrapt.FunctionWrapper(
-                    value,
-                    self._wrap_recursive_decorator,
-                )
-
-            if (
-                isinstance(self.__wrapped__, Base)
-                and __name not in self.__never_wrap_base_attrs__
-                and hasattr(value, "__func__")
-            ):
-                # Wrap methods called on Base subclasses, which might do _anything_
-                return wrapt.FunctionWrapper(
-                    functools.partial(value.__func__, self),  # pyright: ignore [reportFunctionMemberAccess]
-                    self._wrap_recursive_decorator,
-                )
-
-        if self._is_mutable_type(value) and __name not in (
-            "__wrapped__",
-            "_self_state",
-            "__dict__",
-        ):
-            # Recursively wrap mutable attribute values retrieved through this proxy.
-            return self._wrap_recursive(value)
-
-        return value
-
-    def __getitem__(self, key: Any) -> Any:
-        """Get the item on the proxied object and return a proxy if mutable.
-
-        Args:
-            key: The key of the item.
-
-        Returns:
-            The item value.
-        """
-        value = super().__getitem__(key)
-        # Recursively wrap mutable items retrieved through this proxy.
-        return self._wrap_recursive(value)
-
-    def __iter__(self) -> Any:
-        """Iterate over the proxied object and return a proxy if mutable.
-
-        Yields:
-            Each item value (possibly wrapped in MutableProxy).
-        """
-        for value in super().__iter__():
-            # Recursively wrap mutable items retrieved through this proxy.
-            yield self._wrap_recursive(value)
-
-    def __delattr__(self, name: str):
-        """Delete the attribute on the proxied object and mark state dirty.
-
-        Args:
-            name: The name of the attribute.
-        """
-        self._mark_dirty(super().__delattr__, args=(name,))
-
-    def __delitem__(self, key: str):
-        """Delete the item on the proxied object and mark state dirty.
-
-        Args:
-            key: The key of the item.
-        """
-        self._mark_dirty(super().__delitem__, args=(key,))
-
-    def __setitem__(self, key: str, value: Any):
-        """Set the item on the proxied object and mark state dirty.
-
-        Args:
-            key: The key of the item.
-            value: The value of the item.
-        """
-        self._mark_dirty(super().__setitem__, args=(key, value))
-
-    def __setattr__(self, name: str, value: Any):
-        """Set the attribute on the proxied object and mark state dirty.
-
-        If the attribute starts with "_self_", then the state is NOT marked
-        dirty as these are internal proxy attributes.
-
-        Args:
-            name: The name of the attribute.
-            value: The value of the attribute.
-        """
-        if name.startswith("_self_"):
-            # Special case attributes of the proxy itself, not applied to the wrapped object.
-            super().__setattr__(name, value)
-            return
-        self._mark_dirty(super().__setattr__, args=(name, value))
-
-    def __copy__(self) -> Any:
-        """Return a copy of the proxy.
-
-        Returns:
-            A copy of the wrapped object, unconnected to the proxy.
-        """
-        return copy.copy(self.__wrapped__)
-
-    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Any:
-        """Return a deepcopy of the proxy.
-
-        Args:
-            memo: The memo dict to use for the deepcopy.
-
-        Returns:
-            A deepcopy of the wrapped object, unconnected to the proxy.
-        """
-        return copy.deepcopy(self.__wrapped__, memo=memo)
-
-    def __reduce_ex__(self, protocol_version: SupportsIndex):
-        """Get the state for redis serialization.
-
-        This method is called by cloudpickle to serialize the object.
-
-        It explicitly serializes the wrapped object, stripping off the mutable proxy.
-
-        Args:
-            protocol_version: The protocol version.
-
-        Returns:
-            Tuple of (wrapped class, empty args, class __getstate__)
-        """
-        return self.__wrapped__.__reduce_ex__(protocol_version)
-
-
-@serializer
-def serialize_mutable_proxy(mp: MutableProxy):
-    """Return the wrapped value of a MutableProxy.
-
-    Args:
-        mp: The MutableProxy to serialize.
-
-    Returns:
-        The wrapped object.
-    """
-    return mp.__wrapped__
-
-
-_orig_json_encoder_default = json.JSONEncoder.default
-
-
-def _json_encoder_default_wrapper(self: json.JSONEncoder, o: Any) -> Any:
-    """Wrap JSONEncoder.default to handle MutableProxy objects.
-
-    Args:
-        self: the JSONEncoder instance.
-        o: the object to serialize.
-
-    Returns:
-        A JSON-able object.
-    """
-    try:
-        return o.__wrapped__
-    except AttributeError:
-        pass
-    return _orig_json_encoder_default(self, o)
-
-
-json.JSONEncoder.default = _json_encoder_default_wrapper
-
-
-class ImmutableMutableProxy(MutableProxy):
-    """A proxy for a mutable object that tracks changes.
-
-    This wrapper comes from StateProxy, and will raise an exception if an attempt is made
-    to modify the wrapped object when the StateProxy is immutable.
-    """
-
-    # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
-    __base_proxy__ = "ImmutableMutableProxy"
-
-    def _mark_dirty(
-        self,
-        wrapped: Callable | None = None,
-        instance: BaseState | None = None,
-        args: tuple = (),
-        kwargs: dict | None = None,
-    ) -> Any:
-        """Raise an exception when an attempt is made to modify the object.
-
-        Intended for use with `FunctionWrapper` from the `wrapt` library.
-
-        Args:
-            wrapped: The wrapped function.
-            instance: The instance of the wrapped function.
-            args: The args for the wrapped function.
-            kwargs: The kwargs for the wrapped function.
-
-        Returns:
-            The result of the wrapped function.
-
-        Raises:
-            ImmutableStateError: if the StateProxy is not mutable.
-        """
-        if not self._self_state._is_mutable():
-            raise ImmutableStateError(
-                "Background task StateProxy is immutable outside of a context "
-                "manager. Use `async with self` to modify state."
-            )
-        return super()._mark_dirty(
-            wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
-        )
-
-
-def code_uses_state_contexts(javascript_code: str) -> bool:
-    """Check if the rendered Javascript uses state contexts.
-
-    Args:
-        javascript_code: The Javascript code to check.
-
-    Returns:
-        True if the code attempts to access a member of StateContexts.
-    """
-    return bool("useContext(StateContexts" in javascript_code)
-
-
-def reload_state_module(
-    module: str,
-    state: type[BaseState] = State,
-) -> None:
-    """Reset rx.State subclasses to avoid conflict when reloading.
-
-    Args:
-        module: The module to reload.
-        state: Recursive argument for the state class to reload.
-
-    """
-    # Clean out all potentially dirty states of reloaded modules.
-    for pd_state in tuple(state._potentially_dirty_states):
-        with contextlib.suppress(ValueError):
-            if (
-                state.get_root_state().get_class_substate(pd_state).__module__ == module
-                and module is not None
-            ):
-                state._potentially_dirty_states.remove(pd_state)
-    for subclass in tuple(state.class_subclasses):
-        reload_state_module(module=module, state=subclass)
-        if subclass.__module__ == module and module is not None:
-            all_base_state_classes.pop(subclass.get_full_name(), None)
-            state.class_subclasses.remove(subclass)
-            state._always_dirty_substates.discard(subclass.get_name())
-            state._var_dependencies = {}
-            state._init_var_dependency_dicts()
-    state.get_class_substate.cache_clear()
+from reflex.istate.manager import LockExpiredError as LockExpiredError  # noqa: E402
+from reflex.istate.manager import StateManager as StateManager  # noqa: E402
+from reflex.istate.manager import StateManagerDisk as StateManagerDisk  # noqa: E402
+from reflex.istate.manager import StateManagerMemory as StateManagerMemory  # noqa: E402
+from reflex.istate.manager import StateManagerRedis as StateManagerRedis  # noqa: E402
+from reflex.istate.manager import get_state_manager as get_state_manager  # noqa: E402
+from reflex.istate.manager import (  # noqa: E402
+    reset_disk_state_manager as reset_disk_state_manager,
+)

+ 20 - 13
reflex/testing.py

@@ -45,6 +45,7 @@ from reflex.state import (
 )
 from reflex.utils import console
 from reflex.utils.export import export
+from reflex.utils.types import ASGIApp
 
 try:
     from selenium import webdriver
@@ -110,6 +111,7 @@ class AppHarness:
     app_module_path: Path
     app_module: types.ModuleType | None = None
     app_instance: reflex.App | None = None
+    app_asgi: ASGIApp | None = None
     frontend_process: subprocess.Popen | None = None
     frontend_url: str | None = None
     frontend_output_thread: threading.Thread | None = None
@@ -270,11 +272,14 @@ class AppHarness:
             # Ensure the AppHarness test does not skip State assignment due to running via pytest
             os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
             os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
-            self.app_module = reflex.utils.prerequisites.get_compiled_app(
-                # Do not reload the module for pre-existing apps (only apps generated from source)
-                reload=self.app_source is not None
+            # Ensure we actually compile the app during first initialization.
+            self.app_instance, self.app_module = (
+                reflex.utils.prerequisites.get_and_validate_app(
+                    # Do not reload the module for pre-existing apps (only apps generated from source)
+                    reload=self.app_source is not None
+                )
             )
-        self.app_instance = self.app_module.app
+            self.app_asgi = self.app_instance()
         if self.app_instance and isinstance(
             self.app_instance._state_manager, StateManagerRedis
         ):
@@ -300,10 +305,10 @@ class AppHarness:
         async def _shutdown(*args, **kwargs) -> None:
             # ensure redis is closed before event loop
             if self.app_instance is not None and isinstance(
-                self.app_instance.state_manager, StateManagerRedis
+                self.app_instance._state_manager, StateManagerRedis
             ):
                 with contextlib.suppress(ValueError):
-                    await self.app_instance.state_manager.close()
+                    await self.app_instance._state_manager.close()
 
             # socketio shutdown handler
             if self.app_instance is not None and self.app_instance.sio is not None:
@@ -323,11 +328,11 @@ class AppHarness:
         return _shutdown
 
     def _start_backend(self, port: int = 0):
-        if self.app_instance is None or self.app_instance._api is None:
+        if self.app_asgi is None:
             raise RuntimeError("App was not initialized.")
         self.backend = uvicorn.Server(
             uvicorn.Config(
-                app=self.app_instance._api,
+                app=self.app_asgi,
                 host="127.0.0.1",
                 port=port,
             )
@@ -349,13 +354,13 @@ class AppHarness:
         if (
             self.app_instance is not None
             and isinstance(
-                self.app_instance.state_manager,
+                self.app_instance._state_manager,
                 StateManagerRedis,
             )
             and self.app_instance._state is not None
         ):
             with contextlib.suppress(RuntimeError):
-                await self.app_instance.state_manager.close()
+                await self.app_instance._state_manager.close()
             self.app_instance._state_manager = StateManagerRedis.create(
                 state=self.app_instance._state,
             )
@@ -937,7 +942,9 @@ class AppHarnessProd(AppHarness):
 
             get_config().loglevel = reflex.constants.LogLevel.INFO
 
-            if reflex.utils.prerequisites.needs_reinit(frontend=True):
+            reflex.utils.prerequisites.assert_in_reflex_dir()
+
+            if reflex.utils.prerequisites.needs_reinit():
                 reflex.reflex._init(name=get_config().app_name)
 
             export(
@@ -957,12 +964,12 @@ class AppHarnessProd(AppHarness):
             raise RuntimeError("Frontend did not start")
 
     def _start_backend(self):
-        if self.app_instance is None:
+        if self.app_asgi is None:
             raise RuntimeError("App was not initialized.")
         environment.REFLEX_SKIP_COMPILE.set(True)
         self.backend = uvicorn.Server(
             uvicorn.Config(
-                app=self.app_instance,
+                app=self.app_asgi,
                 host="127.0.0.1",
                 port=0,
                 workers=reflex.utils.processes.get_num_workers(),

+ 1 - 2
reflex/utils/build.py

@@ -83,7 +83,7 @@ def _zip(
     files_to_zip: list[str] = []
     # Traverse the root directory in a top-down manner. In this traversal order,
     # we can modify the dirs list in-place to remove directories we don't want to include.
-    for root, dirs, files in os.walk(root_dir, topdown=True):
+    for root, dirs, files in os.walk(root_dir, topdown=True, followlinks=True):
         root = Path(root)
         # Modify the dirs in-place so excluded and hidden directories are skipped in next traversal.
         dirs[:] = [
@@ -112,7 +112,6 @@ def _zip(
                 for file in root_dir.glob(glob)
                 if file.name not in files_to_exclude
             ]
-
     # Create a progress bar for zipping the component.
     progress = Progress(
         *Progress.get_default_columns()[:-1],

+ 16 - 6
reflex/utils/format.py

@@ -181,10 +181,13 @@ def to_camel_case(text: str, treat_hyphens_as_underscores: bool = True) -> str:
     Returns:
         The camel case string.
     """
-    char = "_" if not treat_hyphens_as_underscores else "-_"
-    words = re.split(f"[{char}]", text)
+    if treat_hyphens_as_underscores:
+        text = text.replace("-", "_")
+    words = text.split("_")
     # Capitalize the first letter of each word except the first one
-    converted_word = words[0] + "".join(x.capitalize() for x in words[1:])
+    if len(words) == 1:
+        return words[0]
+    converted_word = words[0] + "".join([w.capitalize() for w in words[1:]])
     return converted_word
 
 
@@ -436,12 +439,19 @@ def format_props(*single_props, **key_value_props) -> list[str]:
     from reflex.vars.base import LiteralVar, Var
 
     return [
-        (
-            f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}"
+        ":".join(
+            [
+                str(name if "-" not in name else LiteralVar.create(name)),
+                str(
+                    format_prop(
+                        prop if isinstance(prop, Var) else LiteralVar.create(prop)
+                    )
+                ),
+            ]
         )
         for name, prop in sorted(key_value_props.items())
         if prop is not None
-    ] + [(f"{LiteralVar.create(prop)!s}") for prop in single_props]
+    ] + [(f"...{LiteralVar.create(prop)!s}") for prop in single_props]
 
 
 def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:

+ 22 - 35
reflex/utils/prerequisites.py

@@ -37,7 +37,7 @@ from redis.exceptions import RedisError
 from reflex import constants, model
 from reflex.compiler import templates
 from reflex.config import Config, environment, get_config
-from reflex.utils import console, net, path_ops, processes
+from reflex.utils import console, net, path_ops, processes, redir
 from reflex.utils.decorator import once
 from reflex.utils.exceptions import SystemPackageMissingError
 from reflex.utils.format import format_library_name
@@ -61,7 +61,6 @@ class Template:
     name: str
     description: str
     code_url: str
-    demo_url: str | None = None
 
 
 @dataclasses.dataclass(frozen=True)
@@ -1362,17 +1361,11 @@ def check_running_mode(frontend: bool, backend: bool) -> tuple[bool, bool]:
     return frontend, backend
 
 
-def needs_reinit(frontend: bool = True) -> bool:
-    """Check if an app needs to be reinitialized.
-
-    Args:
-        frontend: Whether to check if the frontend is initialized.
-
-    Returns:
-        Whether the app needs to be reinitialized.
+def assert_in_reflex_dir():
+    """Assert that the current working directory is the reflex directory.
 
     Raises:
-        Exit: If the app is not initialized.
+        Exit: If the current working directory is not the reflex directory.
     """
     if not constants.Config.FILE.exists():
         console.error(
@@ -1380,10 +1373,13 @@ def needs_reinit(frontend: bool = True) -> bool:
         )
         raise click.exceptions.Exit(1)
 
-    # Don't need to reinit if not running in frontend mode.
-    if not frontend:
-        return False
 
+def needs_reinit() -> bool:
+    """Check if an app needs to be reinitialized.
+
+    Returns:
+        Whether the app needs to be reinitialized.
+    """
     # Make sure the .reflex directory exists.
     if not environment.REFLEX_DIR.get().exists():
         return True
@@ -1597,22 +1593,13 @@ def prompt_for_template_options(templates: list[Template]) -> str:
     # Show the user the URLs of each template to preview.
     console.print("\nGet started with a template:")
 
-    def format_demo_url_str(url: str | None) -> str:
-        return f" ({url})" if url else ""
-
     # Prompt the user to select a template.
-    id_to_name = {
-        str(
-            idx
-        ): f"{template.name.replace('_', ' ').replace('-', ' ')}{format_demo_url_str(template.demo_url)} - {template.description}"
-        for idx, template in enumerate(templates)
-    }
-    for id in range(len(id_to_name)):
-        console.print(f"({id}) {id_to_name[str(id)]}")
+    for index, template in enumerate(templates):
+        console.print(f"({index}) {template.description}")
 
     template = console.ask(
         "Which template would you like to use?",
-        choices=[str(i) for i in range(len(id_to_name))],
+        choices=[str(i) for i in range(len(templates))],
         show_choices=False,
         default="0",
     )
@@ -1881,14 +1868,17 @@ def initialize_app(app_name: str, template: str | None = None) -> str | None:
 
     if template is None:
         template = prompt_for_template_options(get_init_cli_prompt_options())
+
         if template == constants.Templates.CHOOSE_TEMPLATES:
-            console.print(
-                f"Go to the templates page ({constants.Templates.REFLEX_TEMPLATES_URL}) and copy the command to init with a template."
-            )
+            redir.reflex_templates()
             raise click.exceptions.Exit(0)
 
+    if template == constants.Templates.AI:
+        redir.reflex_build_redirect()
+        raise click.exceptions.Exit(0)
+
     # If the blank template is selected, create a blank app.
-    if template in (constants.Templates.DEFAULT,):
+    if template == constants.Templates.DEFAULT:
         # Default app creation behavior: a blank app.
         initialize_default_app(app_name)
     else:
@@ -1911,19 +1901,16 @@ def get_init_cli_prompt_options() -> list[Template]:
         Template(
             name=constants.Templates.DEFAULT,
             description="A blank Reflex app.",
-            demo_url=constants.Templates.DEFAULT_TEMPLATE_URL,
             code_url="",
         ),
         Template(
             name=constants.Templates.AI,
-            description="Generate a template using AI [Experimental]",
-            demo_url="",
+            description="[bold]Try our free AI builder.",
             code_url="",
         ),
         Template(
             name=constants.Templates.CHOOSE_TEMPLATES,
-            description="Choose an existing template.",
-            demo_url="",
+            description="Premade templates built by the Reflex team.",
             code_url="",
         ),
     ]

+ 1 - 0
reflex/utils/pyi_generator.py

@@ -50,6 +50,7 @@ EXCLUDED_PROPS = [
     "tag",
     "is_default",
     "special_props",
+    "_is_tag_in_global_scope",
     "_invalid_children",
     "_memoization_mode",
     "_rename_props",

+ 7 - 0
reflex/utils/redir.py

@@ -21,6 +21,8 @@ def open_browser(target_url: str) -> None:
         console.warn(
             f"Unable to automatically open the browser. Please navigate to {target_url} in your browser."
         )
+    else:
+        console.info(f"Opening browser to {target_url}.")
 
 
 def open_browser_and_wait(
@@ -52,3 +54,8 @@ def open_browser_and_wait(
 def reflex_build_redirect() -> None:
     """Open the browser window to reflex.build."""
     open_browser(constants.Templates.REFLEX_BUILD_FRONTEND)
+
+
+def reflex_templates():
+    """Open the browser window to reflex.build/templates."""
+    open_browser(constants.Templates.REFLEX_TEMPLATES_URL)

+ 14 - 0
reflex/utils/serializers.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import contextlib
 import dataclasses
+import decimal
 import functools
 import inspect
 import json
@@ -386,6 +387,19 @@ def serialize_uuid(uuid: UUID) -> str:
     return str(uuid)
 
 
+@serializer(to=float)
+def serialize_decimal(value: decimal.Decimal) -> float:
+    """Serialize a Decimal to a float.
+
+    Args:
+        value: The Decimal to serialize.
+
+    Returns:
+        The serialized Decimal as a float.
+    """
+    return float(value)
+
+
 @serializer(to=str)
 def serialize_color(color: Color) -> str:
     """Serialize a color.

+ 24 - 8
reflex/utils/types.py

@@ -120,6 +120,10 @@ class Unset:
 
 
 @lru_cache
+def _get_origin_cached(tp: Any):
+    return get_origin_og(tp)
+
+
 def get_origin(tp: Any):
     """Get the origin of a class.
 
@@ -129,7 +133,11 @@ def get_origin(tp: Any):
     Returns:
         The origin of the class.
     """
-    return get_origin_og(tp)
+    return (
+        origin
+        if (origin := getattr(tp, "__origin__", None)) is not None
+        else _get_origin_cached(tp)
+    )
 
 
 @lru_cache
@@ -190,7 +198,6 @@ def is_none(cls: GenericType) -> bool:
     return cls is type(None) or cls is None
 
 
-@lru_cache
 def is_union(cls: GenericType) -> bool:
     """Check if a class is a Union.
 
@@ -200,10 +207,12 @@ def is_union(cls: GenericType) -> bool:
     Returns:
         Whether the class is a Union.
     """
-    return get_origin(cls) in UnionTypes
+    origin = getattr(cls, "__origin__", None)
+    if origin is Union:
+        return True
+    return origin is None and isinstance(cls, types.UnionType)
 
 
-@lru_cache
 def is_literal(cls: GenericType) -> bool:
     """Check if a class is a Literal.
 
@@ -213,7 +222,7 @@ def is_literal(cls: GenericType) -> bool:
     Returns:
         Whether the class is a literal.
     """
-    return get_origin(cls) is Literal
+    return getattr(cls, "__origin__", None) is Literal
 
 
 def has_args(cls: type) -> bool:
@@ -900,12 +909,19 @@ def validate_parameter_literals(func: Callable):
     Returns:
         The wrapper function.
     """
+    console.deprecate(
+        "validate_parameter_literals",
+        reason="Use manual validation instead.",
+        deprecation_version="0.7.11",
+        removal_version="0.8.0",
+        dedupe=True,
+    )
+
+    func_params = list(inspect.signature(func).parameters.items())
+    annotations = {param[0]: param[1].annotation for param in func_params}
 
     @wraps(func)
     def wrapper(*args, **kwargs):
-        func_params = list(inspect.signature(func).parameters.items())
-        annotations = {param[0]: param[1].annotation for param in func_params}
-
         # validate args
         for param, arg in zip(annotations, args, strict=False):
             if annotations[param] is inspect.Parameter.empty:

+ 13 - 1
reflex/vars/base.py

@@ -15,6 +15,7 @@ import string
 import uuid
 import warnings
 from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
+from decimal import Decimal
 from types import CodeType, FunctionType
 from typing import (  # noqa: UP035
     TYPE_CHECKING,
@@ -630,6 +631,14 @@ class Var(Generic[VAR_TYPE]):
         _var_data: VarData | None = None,
     ) -> LiteralNumberVar[float]: ...
 
+    @overload
+    @classmethod
+    def create(
+        cls,
+        value: Decimal,
+        _var_data: VarData | None = None,
+    ) -> LiteralNumberVar[Decimal]: ...
+
     @overload
     @classmethod
     def create(  # pyright: ignore [reportOverlappingOverload]
@@ -743,7 +752,10 @@ class Var(Generic[VAR_TYPE]):
     def to(self, output: type[int]) -> NumberVar[int]: ...
 
     @overload
-    def to(self, output: type[int] | type[float]) -> NumberVar: ...
+    def to(self, output: type[float]) -> NumberVar[float]: ...
+
+    @overload
+    def to(self, output: type[Decimal]) -> NumberVar[Decimal]: ...
 
     @overload
     def to(

+ 16 - 8
reflex/vars/number.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import dataclasses
+import decimal
 import json
 import math
 from collections.abc import Callable
@@ -30,7 +31,10 @@ from .base import (
 )
 
 NUMBER_T = TypeVarExt(
-    "NUMBER_T", bound=(int | float), default=(int | float), covariant=True
+    "NUMBER_T",
+    bound=(int | float | decimal.Decimal),
+    default=(int | float | decimal.Decimal),
+    covariant=True,
 )
 
 if TYPE_CHECKING:
@@ -54,7 +58,7 @@ def raise_unsupported_operand_types(
     )
 
 
-class NumberVar(Var[NUMBER_T], python_types=(int, float)):
+class NumberVar(Var[NUMBER_T], python_types=(int, float, decimal.Decimal)):
     """Base class for immutable number vars."""
 
     def __add__(self, other: number_types) -> NumberVar:
@@ -285,13 +289,13 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
 
         return number_exponent_operation(+other, self)
 
-    def __neg__(self):
+    def __neg__(self) -> NumberVar:
         """Negate the number.
 
         Returns:
             The number negation operation.
         """
-        return number_negate_operation(self)
+        return number_negate_operation(self)  # pyright: ignore [reportReturnType]
 
     def __invert__(self):
         """Boolean NOT the number.
@@ -943,7 +947,7 @@ def boolean_not_operation(value: BooleanVar):
 class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
     """Base class for immutable literal number vars."""
 
-    _var_value: float | int = dataclasses.field(default=0)
+    _var_value: float | int | decimal.Decimal = dataclasses.field(default=0)
 
     def json(self) -> str:
         """Get the JSON representation of the var.
@@ -954,6 +958,8 @@ class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
         Raises:
             PrimitiveUnserializableToJSONError: If the var is unserializable to JSON.
         """
+        if isinstance(self._var_value, decimal.Decimal):
+            return json.dumps(float(self._var_value))
         if math.isinf(self._var_value) or math.isnan(self._var_value):
             raise PrimitiveUnserializableToJSONError(
                 f"No valid JSON representation for {self}"
@@ -969,7 +975,9 @@ class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
         return hash((type(self).__name__, self._var_value))
 
     @classmethod
-    def create(cls, value: float | int, _var_data: VarData | None = None):
+    def create(
+        cls, value: float | int | decimal.Decimal, _var_data: VarData | None = None
+    ):
         """Create the number var.
 
         Args:
@@ -1039,7 +1047,7 @@ class LiteralBooleanVar(LiteralVar, BooleanVar):
         )
 
 
-number_types = NumberVar | int | float
+number_types = NumberVar | int | float | decimal.Decimal
 boolean_types = BooleanVar | bool
 
 
@@ -1112,4 +1120,4 @@ def ternary_operation(
     return value
 
 
-NUMBER_TYPES = (int, float, NumberVar)
+NUMBER_TYPES = (int, float, decimal.Decimal, NumberVar)

+ 2 - 1
reflex/vars/sequence.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import collections.abc
 import dataclasses
+import decimal
 import inspect
 import json
 import re
@@ -1558,7 +1559,7 @@ def is_tuple_type(t: GenericType) -> bool:
 
 
 def _determine_value_of_array_index(
-    var_type: GenericType, index: int | float | None = None
+    var_type: GenericType, index: int | float | decimal.Decimal | None = None
 ):
     """Determine the value of an array index.
 

Fișier diff suprimat deoarece este prea mare
+ 0 - 0
tests/benchmarks/fixtures.py


+ 1 - 2
tests/integration/test_connection_banner.py

@@ -1,6 +1,5 @@
 """Test case for displaying the connection banner when the websocket drops."""
 
-import functools
 from collections.abc import Generator
 
 import pytest
@@ -77,7 +76,7 @@ def connection_banner(
 
     with AppHarness.create(
         root=tmp_path,
-        app_source=functools.partial(ConnectionBanner),
+        app_source=ConnectionBanner,
         app_name=(
             "connection_banner_reflex_cloud"
             if simulate_compile_context == constants.CompileContext.DEPLOY

+ 57 - 6
tests/integration/test_lifespan.py

@@ -1,5 +1,6 @@
 """Test cases for the Starlette lifespan integration."""
 
+import functools
 from collections.abc import Generator
 
 import pytest
@@ -10,8 +11,15 @@ from reflex.testing import AppHarness
 from .utils import SessionStorage
 
 
-def LifespanApp():
-    """App with lifespan tasks and context."""
+def LifespanApp(
+    mount_cached_fastapi: bool = False, mount_api_transformer: bool = False
+) -> None:
+    """App with lifespan tasks and context.
+
+    Args:
+        mount_cached_fastapi: Whether to mount the cached FastAPI app.
+        mount_api_transformer: Whether to mount the API transformer.
+    """
     import asyncio
     from contextlib import asynccontextmanager
 
@@ -72,25 +80,68 @@ def LifespanApp():
             ),
         )
 
-    app = rx.App()
+    from fastapi import FastAPI
+
+    app = rx.App(api_transformer=FastAPI() if mount_api_transformer else None)
+
+    if mount_cached_fastapi:
+        assert app.api is not None
+
     app.register_lifespan_task(lifespan_task)
     app.register_lifespan_task(lifespan_context, inc=2)
     app.add_page(index)
 
 
+@pytest.fixture(
+    params=[False, True], ids=["no_api_transformer", "mount_api_transformer"]
+)
+def mount_api_transformer(request: pytest.FixtureRequest) -> bool:
+    """Whether to use api_transformer in the app.
+
+    Args:
+        request: pytest fixture request object
+
+    Returns:
+        bool: Whether to use api_transformer
+    """
+    return request.param
+
+
+@pytest.fixture(params=[False, True], ids=["no_fastapi", "mount_cached_fastapi"])
+def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool:
+    """Whether to use cached FastAPI in the app (app.api).
+
+    Args:
+        request: pytest fixture request object
+
+    Returns:
+        Whether to use cached FastAPI
+    """
+    return request.param
+
+
 @pytest.fixture()
-def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
+def lifespan_app(
+    tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool
+) -> Generator[AppHarness, None, None]:
     """Start LifespanApp app at tmp_path via AppHarness.
 
     Args:
         tmp_path: pytest tmp_path fixture
+        mount_api_transformer: Whether to mount the API transformer.
+        mount_cached_fastapi: Whether to mount the cached FastAPI app.
 
     Yields:
         running AppHarness instance
     """
     with AppHarness.create(
         root=tmp_path,
-        app_source=LifespanApp,
+        app_source=functools.partial(
+            LifespanApp,
+            mount_cached_fastapi=mount_cached_fastapi,
+            mount_api_transformer=mount_api_transformer,
+        ),
+        app_name=f"lifespanapp_fastapi{mount_cached_fastapi}_transformer{mount_api_transformer}",
     ) as harness:
         yield harness
 
@@ -112,7 +163,7 @@ async def test_lifespan(lifespan_app: AppHarness):
     context_global = driver.find_element(By.ID, "context_global")
     task_global = driver.find_element(By.ID, "task_global")
 
-    assert context_global.text == "2"
+    assert lifespan_app.poll_for_content(context_global, exp_not_equal="0") == "2"
     assert lifespan_app.app_module.lifespan_context_global == 2
 
     original_task_global_text = task_global.text

+ 4 - 4
tests/units/components/base/test_bare.py

@@ -9,10 +9,10 @@ STATE_VAR = Var(_js_expr="default_state.name")
 @pytest.mark.parametrize(
     "contents,expected",
     [
-        ("hello", '{"hello"}'),
-        ("{}", '{"{}"}'),
-        (None, '{""}'),
-        (STATE_VAR, "{default_state.name}"),
+        ("hello", '"hello"'),
+        ("{}", '"{}"'),
+        (None, '""'),
+        (STATE_VAR, "default_state.name"),
     ],
 )
 def test_fstrings(contents, expected):

+ 4 - 6
tests/units/components/base/test_link.py

@@ -3,13 +3,11 @@ from reflex.components.base.link import RawLink, ScriptTag
 
 def test_raw_link():
     raw_link = RawLink.create("https://example.com").render()
-    assert raw_link["name"] == "link"
-    assert raw_link["children"][0]["contents"] == '{"https://example.com"}'
+    assert raw_link["name"] == '"link"'
+    assert raw_link["children"][0]["contents"] == '"https://example.com"'
 
 
 def test_script_tag():
     script_tag = ScriptTag.create("console.log('Hello, world!');").render()
-    assert script_tag["name"] == "script"
-    assert (
-        script_tag["children"][0]["contents"] == "{\"console.log('Hello, world!');\"}"
-    )
+    assert script_tag["name"] == '"script"'
+    assert script_tag["children"][0]["contents"] == "\"console.log('Hello, world!');\""

+ 5 - 5
tests/units/components/base/test_script.py

@@ -14,7 +14,7 @@ def test_script_inline():
     assert render_dict["name"] == "Script"
     assert not render_dict["contents"]
     assert len(render_dict["children"]) == 1
-    assert render_dict["children"][0]["contents"] == '{"let x = 42"}'
+    assert render_dict["children"][0]["contents"] == '"let x = 42"'
 
 
 def test_script_src():
@@ -24,7 +24,7 @@ def test_script_src():
     assert render_dict["name"] == "Script"
     assert not render_dict["contents"]
     assert not render_dict["children"]
-    assert 'src={"foo.js"}' in render_dict["props"]
+    assert 'src:"foo.js"' in render_dict["props"]
 
 
 def test_script_neither():
@@ -62,14 +62,14 @@ def test_script_event_handler():
     )
     render_dict = component.render()
     assert (
-        f'onReady={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{  }}), ({{  }})))], args, ({{  }}))))}}'
+        f'onReady:((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{  }}), ({{  }})))], args, ({{  }}))))'
         in render_dict["props"]
     )
     assert (
-        f'onLoad={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_load", ({{  }}), ({{  }})))], args, ({{  }}))))}}'
+        f'onLoad:((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_load", ({{  }}), ({{  }})))], args, ({{  }}))))'
         in render_dict["props"]
     )
     assert (
-        f'onError={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_error", ({{  }}), ({{  }})))], args, ({{  }}))))}}'
+        f'onError:((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_error", ({{  }}), ({{  }})))], args, ({{  }}))))'
         in render_dict["props"]
     )

+ 19 - 12
tests/units/components/core/test_colors.py

@@ -9,10 +9,10 @@ from reflex.vars.base import LiteralVar
 class ColorState(rx.State):
     """Test color state."""
 
-    color: str = "mint"
-    color_part: str = "tom"
-    shade: int = 4
-    alpha: bool = False
+    color: rx.Field[str] = rx.field("mint")
+    color_part: rx.Field[str] = rx.field("tom")
+    shade: rx.Field[int] = rx.field(4)
+    alpha: rx.Field[bool] = rx.field(False)
 
 
 color_state_name = ColorState.get_full_name().replace(".", "__")
@@ -22,6 +22,12 @@ def create_color_var(color):
     return LiteralVar.create(color)
 
 
+color_with_fstring = rx.color(
+    f"{ColorState.color}",  # pyright: ignore [reportArgumentType]
+    ColorState.shade,
+)
+
+
 @pytest.mark.parametrize(
     "color, expected, expected_type",
     [
@@ -41,26 +47,27 @@ def create_color_var(color):
             Color,
         ),
         (
-            create_color_var(rx.color(f"{ColorState.color}", f"{ColorState.shade}")),
-            f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
+            create_color_var(color_with_fstring),
+            f'("var(--"+{color_state_name!s}.color+"-"+(((__to_string) => __to_string.toString())({color_state_name!s}.shade))+")")',
             Color,
         ),
         (
             create_color_var(
-                rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}")
+                rx.color(
+                    f"{ColorState.color_part}ato",  # pyright: ignore [reportArgumentType]
+                    ColorState.shade,
+                )
             ),
-            f'("var(--"+({color_state_name!s}.color_part+"ato")+"-"+{color_state_name!s}.shade+")")',
+            f'("var(--"+({color_state_name!s}.color_part+"ato")+"-"+(((__to_string) => __to_string.toString())({color_state_name!s}.shade))+")")',
             Color,
         ),
         (
-            create_color_var(f"{rx.color(ColorState.color, f'{ColorState.shade}')}"),
+            create_color_var(f"{rx.color(ColorState.color, ColorState.shade)}"),
             f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
             str,
         ),
         (
-            create_color_var(
-                f"{rx.color(f'{ColorState.color}', f'{ColorState.shade}')}"
-            ),
+            create_color_var(f"{color_with_fstring}"),
             f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
             str,
         ),

+ 2 - 2
tests/units/components/core/test_cond.py

@@ -57,7 +57,7 @@ def test_validate_cond(cond_state: BaseState):
 
     [true_value_text] = true_value["children"]
     assert true_value_text["name"] == "RadixThemesText"
-    assert true_value_text["children"][0]["contents"] == '{"cond is True"}'
+    assert true_value_text["children"][0]["contents"] == '"cond is True"'
 
     # false value
     false_value = condition["false_value"]
@@ -65,7 +65,7 @@ def test_validate_cond(cond_state: BaseState):
 
     [false_value_text] = false_value["children"]
     assert false_value_text["name"] == "RadixThemesText"
-    assert false_value_text["children"][0]["contents"] == '{"cond is False"}'
+    assert false_value_text["children"][0]["contents"] == '"cond is False"'
 
 
 @pytest.mark.parametrize(

+ 1 - 1
tests/units/components/core/test_foreach.py

@@ -282,7 +282,7 @@ def test_foreach_component_styles():
         )
     )
     component._add_style_recursive({box: {"color": "red"}})
-    assert 'css={({ ["color"] : "red" })}' in str(component)
+    assert 'css:({ ["color"] : "red" })' in str(component)
 
 
 def test_foreach_component_state():

+ 5 - 3
tests/units/components/core/test_html.py

@@ -19,7 +19,7 @@ def test_html_create():
     assert str(html.dangerouslySetInnerHTML) == '({ ["__html"] : "<p>Hello !</p>" })'  # pyright: ignore [reportAttributeAccessIssue]
     assert (
         str(html)
-        == '<div className={"rx-Html"} dangerouslySetInnerHTML={({ ["__html"] : "<p>Hello !</p>" })}/>'
+        == 'jsx("div",{className:"rx-Html",dangerouslySetInnerHTML:({ ["__html"] : "<p>Hello !</p>" })},)\n'
     )
 
 
@@ -31,11 +31,13 @@ def test_html_fstring_create():
 
     html = Html.create(f"<p>Hello {TestState.myvar}!</p>")
 
+    html_dangerouslySetInnerHTML = html.dangerouslySetInnerHTML  # pyright: ignore [reportAttributeAccessIssue]
+
     assert (
-        str(html.dangerouslySetInnerHTML)  # pyright: ignore [reportAttributeAccessIssue]
+        str(html_dangerouslySetInnerHTML)
         == f'({{ ["__html"] : ("<p>Hello "+{TestState.myvar!s}+"!</p>") }})'
     )
     assert (
         str(html)
-        == f'<div className={{"rx-Html"}} dangerouslySetInnerHTML={{{html.dangerouslySetInnerHTML!s}}}/>'  # pyright: ignore [reportAttributeAccessIssue]
+        == f'jsx("div",{{className:"rx-Html",dangerouslySetInnerHTML:{html_dangerouslySetInnerHTML!s}}},)\n'
     )

+ 10 - 9
tests/units/components/core/test_match.py

@@ -1,3 +1,4 @@
+import re
 from collections.abc import Mapping, Sequence
 
 import pytest
@@ -47,7 +48,7 @@ def test_match_components():
     assert match_cases[0][0]._var_type is int
     first_return_value_render = match_cases[0][1]
     assert first_return_value_render["name"] == "RadixThemesText"
-    assert first_return_value_render["children"][0]["contents"] == '{"first value"}'
+    assert first_return_value_render["children"][0]["contents"] == '"first value"'
 
     assert match_cases[1][0]._js_expr == "2"
     assert match_cases[1][0]._var_type is int
@@ -55,36 +56,36 @@ def test_match_components():
     assert match_cases[1][1]._var_type is int
     second_return_value_render = match_cases[1][2]
     assert second_return_value_render["name"] == "RadixThemesText"
-    assert second_return_value_render["children"][0]["contents"] == '{"second value"}'
+    assert second_return_value_render["children"][0]["contents"] == '"second value"'
 
     assert match_cases[2][0]._js_expr == "[1, 2]"
     assert match_cases[2][0]._var_type == Sequence[int]
     third_return_value_render = match_cases[2][1]
     assert third_return_value_render["name"] == "RadixThemesText"
-    assert third_return_value_render["children"][0]["contents"] == '{"third value"}'
+    assert third_return_value_render["children"][0]["contents"] == '"third value"'
 
     assert match_cases[3][0]._js_expr == '"random"'
     assert match_cases[3][0]._var_type is str
     fourth_return_value_render = match_cases[3][1]
     assert fourth_return_value_render["name"] == "RadixThemesText"
-    assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
+    assert fourth_return_value_render["children"][0]["contents"] == '"fourth value"'
 
     assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
     assert match_cases[4][0]._var_type == Mapping[str, str]
     fifth_return_value_render = match_cases[4][1]
     assert fifth_return_value_render["name"] == "RadixThemesText"
-    assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'
+    assert fifth_return_value_render["children"][0]["contents"] == '"fifth value"'
 
     assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)"
     assert match_cases[5][0]._var_type is int
     fifth_return_value_render = match_cases[5][1]
     assert fifth_return_value_render["name"] == "RadixThemesText"
-    assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}'
+    assert fifth_return_value_render["children"][0]["contents"] == '"sixth value"'
 
     default = match_child["default"]
 
     assert default["name"] == "RadixThemesText"
-    assert default["children"][0]["contents"] == '{"default value"}'
+    assert default["children"][0]["contents"] == '"default value"'
 
 
 @pytest.mark.parametrize(
@@ -264,7 +265,7 @@ def test_match_case_tuple_elements(match_case):
                 ([1, 2], rx.text("third value")),
                 rx.text("default value"),
             ),
-            'Match cases should have the same return types. Case 3 with return value `<RadixThemesText as={"p"}> {"first value"} </RadixThemesText>` '
+            'Match cases should have the same return types. Case 3 with return value `jsx( RadixThemesText, {as:"p"}, "first value" ,)` '
             "of type <class 'reflex.components.radix.themes.typography.text.Text'> is not <class 'reflex.vars.base.Var'>",
         ),
     ],
@@ -276,7 +277,7 @@ def test_match_different_return_types(cases: tuple, error_msg: str):
         cases: The match cases.
         error_msg: Expected error message.
     """
-    with pytest.raises(MatchTypeError, match=error_msg):
+    with pytest.raises(MatchTypeError, match=re.escape(error_msg)):
         Match.create(MatchState.value, *cases)
 
 

+ 2 - 2
tests/units/components/datadisplay/test_dataeditor.py

@@ -4,8 +4,8 @@ from reflex.components.datadisplay.dataeditor import DataEditor
 def test_dataeditor():
     editor_wrapper = DataEditor.create().render()
     editor = editor_wrapper["children"][0]
-    assert editor_wrapper["name"] == "div"
+    assert editor_wrapper["name"] == '"div"'
     assert editor_wrapper["props"] == [
-        'css={({ ["width"] : "100%", ["height"] : "100%" })}'
+        'css:({ ["width"] : "100%", ["height"] : "100%" })'
     ]
     assert editor["name"] == "DataEditor"

+ 2 - 2
tests/units/components/datadisplay/test_datatable.py

@@ -47,8 +47,8 @@ def test_validate_data_table(data_table_state: rx.State, expected):
     expected = f"{state_name}.{expected}" if expected else state_name
 
     assert data_table_dict["props"] == [
-        f"columns={{{expected}.columns}}",
-        f"data={{{expected}.data}}",
+        f"columns:{expected}.columns",
+        f"data:{expected}.data",
     ]
 
 

+ 12 - 12
tests/units/components/el/test_svg.py

@@ -16,59 +16,59 @@ from reflex.components.el.elements.media import (
 
 def test_circle():
     circle = Circle.create().render()
-    assert circle["name"] == "circle"
+    assert circle["name"] == '"circle"'
 
 
 def test_defs():
     defs = Defs.create().render()
-    assert defs["name"] == "defs"
+    assert defs["name"] == '"defs"'
 
 
 def test_ellipse():
     ellipse = Ellipse.create().render()
-    assert ellipse["name"] == "ellipse"
+    assert ellipse["name"] == '"ellipse"'
 
 
 def test_line():
     line = Line.create().render()
-    assert line["name"] == "line"
+    assert line["name"] == '"line"'
 
 
 def test_linear_gradient():
     linear_gradient = LinearGradient.create().render()
-    assert linear_gradient["name"] == "linearGradient"
+    assert linear_gradient["name"] == '"linearGradient"'
 
 
 def test_path():
     path = Path.create().render()
-    assert path["name"] == "path"
+    assert path["name"] == '"path"'
 
 
 def test_polygon():
     polygon = Polygon.create().render()
-    assert polygon["name"] == "polygon"
+    assert polygon["name"] == '"polygon"'
 
 
 def test_radial_gradient():
     radial_gradient = RadialGradient.create().render()
-    assert radial_gradient["name"] == "radialGradient"
+    assert radial_gradient["name"] == '"radialGradient"'
 
 
 def test_rect():
     rect = Rect.create().render()
-    assert rect["name"] == "rect"
+    assert rect["name"] == '"rect"'
 
 
 def test_svg():
     svg = Svg.create().render()
-    assert svg["name"] == "svg"
+    assert svg["name"] == '"svg"'
 
 
 def test_text():
     text = Text.create().render()
-    assert text["name"] == "text"
+    assert text["name"] == '"text"'
 
 
 def test_stop():
     stop = Stop.create().render()
-    assert stop["name"] == "stop"
+    assert stop["name"] == '"stop"'

+ 1 - 1
tests/units/components/forms/test_form.py

@@ -11,7 +11,7 @@ def test_render_on_submit():
     )
     f = Form.create(on_submit=submit_it)
     exp_submit_name = f"handleSubmit_{f.handle_submit_unique_name}"  # pyright: ignore [reportAttributeAccessIssue]
-    assert f"onSubmit={{{exp_submit_name}}}" in f.render()["props"]
+    assert f"onSubmit:{exp_submit_name}" in f.render()["props"]
 
 
 def test_render_no_on_submit():

Fișier diff suprimat deoarece este prea mare
+ 0 - 0
tests/units/components/markdown/test_markdown.py


+ 38 - 20
tests/units/components/test_component.py

@@ -686,14 +686,14 @@ def test_component_create_unallowed_types(children, test_component):
                 "children": [
                     {
                         "name": "RadixThemesText",
-                        "props": ['as={"p"}'],
+                        "props": ['as:"p"'],
                         "contents": "",
                         "special_props": [],
                         "children": [
                             {
                                 "name": "",
                                 "props": [],
-                                "contents": '{"first_text"}',
+                                "contents": '"first_text"',
                                 "special_props": [],
                                 "children": [],
                                 "autofocus": False,
@@ -716,7 +716,7 @@ def test_component_create_unallowed_types(children, test_component):
                             {
                                 "autofocus": False,
                                 "children": [],
-                                "contents": '{"first_text"}',
+                                "contents": '"first_text"',
                                 "name": "",
                                 "props": [],
                                 "special_props": [],
@@ -724,7 +724,7 @@ def test_component_create_unallowed_types(children, test_component):
                         ],
                         "contents": "",
                         "name": "RadixThemesText",
-                        "props": ['as={"p"}'],
+                        "props": ['as:"p"'],
                         "special_props": [],
                     },
                     {
@@ -733,7 +733,7 @@ def test_component_create_unallowed_types(children, test_component):
                             {
                                 "autofocus": False,
                                 "children": [],
-                                "contents": '{"second_text"}',
+                                "contents": '"second_text"',
                                 "name": "",
                                 "props": [],
                                 "special_props": [],
@@ -741,7 +741,7 @@ def test_component_create_unallowed_types(children, test_component):
                         ],
                         "contents": "",
                         "name": "RadixThemesText",
-                        "props": ['as={"p"}'],
+                        "props": ['as:"p"'],
                         "special_props": [],
                     },
                 ],
@@ -762,7 +762,7 @@ def test_component_create_unallowed_types(children, test_component):
                             {
                                 "autofocus": False,
                                 "children": [],
-                                "contents": '{"first_text"}',
+                                "contents": '"first_text"',
                                 "name": "",
                                 "props": [],
                                 "special_props": [],
@@ -770,7 +770,7 @@ def test_component_create_unallowed_types(children, test_component):
                         ],
                         "contents": "",
                         "name": "RadixThemesText",
-                        "props": ['as={"p"}'],
+                        "props": ['as:"p"'],
                         "special_props": [],
                     },
                     {
@@ -785,7 +785,7 @@ def test_component_create_unallowed_types(children, test_component):
                                             {
                                                 "autofocus": False,
                                                 "children": [],
-                                                "contents": '{"second_text"}',
+                                                "contents": '"second_text"',
                                                 "name": "",
                                                 "props": [],
                                                 "special_props": [],
@@ -793,7 +793,7 @@ def test_component_create_unallowed_types(children, test_component):
                                         ],
                                         "contents": "",
                                         "name": "RadixThemesText",
-                                        "props": ['as={"p"}'],
+                                        "props": ['as:"p"'],
                                         "special_props": [],
                                     }
                                 ],
@@ -1163,10 +1163,10 @@ def test_component_with_only_valid_children(fixture, request):
 @pytest.mark.parametrize(
     "component,rendered",
     [
-        (rx.text("hi"), '<RadixThemesText as={"p"}>\n\n{"hi"}\n</RadixThemesText>'),
+        (rx.text("hi"), 'jsx(\nRadixThemesText,\n{as:"p"},\n"hi"\n,)'),
         (
             rx.box(rx.heading("test", size="3")),
-            '<RadixThemesBox>\n\n<RadixThemesHeading size={"3"}>\n\n{"test"}\n</RadixThemesHeading>\n</RadixThemesBox>',
+            'jsx(\nRadixThemesBox,\n{},\njsx(\nRadixThemesHeading,\n{size:"3"},\n"test"\n,),)',
         ),
     ],
 )
@@ -1691,6 +1691,24 @@ def test_validate_invalid_children():
             rx.fragment(invalid_component()),
         )
 
+    with pytest.raises(ValueError):
+        rx.el.p(rx.el.p("what"))
+
+    with pytest.raises(ValueError):
+        rx.el.p(rx.el.div("what"))
+
+    with pytest.raises(ValueError):
+        rx.el.button(rx.el.button("what"))
+
+    with pytest.raises(ValueError):
+        rx.el.p(rx.el.ol(rx.el.li("what")))
+
+    with pytest.raises(ValueError):
+        rx.el.p(rx.el.ul(rx.el.li("what")))
+
+    with pytest.raises(ValueError):
+        rx.el.a(rx.el.a("what"))
+
     with pytest.raises(ValueError):
         valid_component2(
             rx.fragment(
@@ -1771,14 +1789,14 @@ def test_rename_props():
 
     c1 = C1.create(prop1="prop1_1", prop2="prop2_1")
     rendered_c1 = c1.render()
-    assert 'renamed_prop1={"prop1_1"}' in rendered_c1["props"]
-    assert 'renamed_prop2={"prop2_1"}' in rendered_c1["props"]
+    assert 'renamed_prop1:"prop1_1"' in rendered_c1["props"]
+    assert 'renamed_prop2:"prop2_1"' in rendered_c1["props"]
 
     c2 = C2.create(prop1="prop1_2", prop2="prop2_2", prop3="prop3_2")
     rendered_c2 = c2.render()
-    assert 'renamed_prop1={"prop1_2"}' in rendered_c2["props"]
-    assert 'subclass_prop2={"prop2_2"}' in rendered_c2["props"]
-    assert 'renamed_prop3={"prop3_2"}' in rendered_c2["props"]
+    assert 'renamed_prop1:"prop1_2"' in rendered_c2["props"]
+    assert 'subclass_prop2:"prop2_2"' in rendered_c2["props"]
+    assert 'renamed_prop3:"prop3_2"' in rendered_c2["props"]
 
 
 def test_custom_component_get_imports():
@@ -2165,7 +2183,7 @@ def test_add_style_embedded_vars(test_state: BaseState):
     assert "useParent" in page._get_all_hooks_internal()
     assert (
         str(page).count(
-            f'css={{({{ ["fakeParent"] : "parent", ["color"] : "var(--plum-10)", ["fake"] : "text", ["margin"] : ({test_state.get_name()}.num+"%") }})}}'
+            f'css:({{ ["fakeParent"] : "parent", ["color"] : "var(--plum-10)", ["fake"] : "text", ["margin"] : ({test_state.get_name()}.num+"%") }})'
         )
         == 1
     )
@@ -2186,10 +2204,10 @@ def test_add_style_foreach():
     assert len(page.children[0].children) == 1
 
     # Expect the style to be added to the child of the foreach
-    assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0])
+    assert 'css:({ ["color"] : "red" })' in str(page.children[0].children[0])
 
     # Expect only one instance of this CSS dict in the rendered page
-    assert str(page).count('css={({ ["color"] : "red" })}') == 1
+    assert str(page).count('css:({ ["color"] : "red" })') == 1
 
 
 class TriggerState(rx.State):

+ 4 - 4
tests/units/components/test_tag.py

@@ -8,10 +8,10 @@ from reflex.vars.base import LiteralVar, Var
     "props,test_props",
     [
         ({}, []),
-        ({"key-hypen": 1}, ["key-hypen={1}"]),
-        ({"key": 1}, ["key={1}"]),
-        ({"key": "value"}, ['key={"value"}']),
-        ({"key": True, "key2": "value2"}, ["key={true}", 'key2={"value2"}']),
+        ({"key-hypen": 1}, ['"key-hypen":1']),
+        ({"key": 1}, ["key:1"]),
+        ({"key": "value"}, ['key:"value"']),
+        ({"key": True, "key2": "value2"}, ["key:true", 'key2:"value2"']),
     ],
 )
 def test_format_props(props: dict[str, Var], test_props: list):

+ 30 - 31
tests/units/test_app.py

@@ -14,6 +14,7 @@ from unittest.mock import AsyncMock
 
 import pytest
 import sqlmodel
+from fastapi.responses import StreamingResponse
 from pytest_mock import MockerFixture
 from starlette.applications import Starlette
 from starlette.datastructures import UploadFile
@@ -830,6 +831,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
 
     upload_fn = upload(app)
     streaming_response = await upload_fn(request_mock)
+    assert isinstance(streaming_response, StreamingResponse)
     async for state_update in streaming_response.body_iterator:
         assert (
             state_update
@@ -1373,24 +1375,23 @@ def test_app_wrap_compile_theme(
         line.strip() for line in app_js_contents.splitlines() if line.strip()
     ]
     lines = "".join(app_js_lines)
-    assert (
+    expected = (
         "function AppWrap({children}) {"
         "return ("
-        + ("<StrictMode>" if react_strict_mode else "")
-        + "<RadixThemesColorModeProvider>"
-        "<RadixThemesTheme accentColor={\"plum\"} css={{...theme.styles.global[':root'], ...theme.styles.global.body}}>"
-        "<Fragment>"
-        "<MemoizedToastProvider/>"
-        "<Fragment>"
-        "{children}"
-        "</Fragment>"
-        "</Fragment>"
-        "</RadixThemesTheme>"
-        "</RadixThemesColorModeProvider>"
-        + ("</StrictMode>" if react_strict_mode else "")
-        + ")"
+        + ("jsx(StrictMode,{}," if react_strict_mode else "")
+        + "jsx(RadixThemesColorModeProvider,{},"
+        "jsx(RadixThemesTheme,{accentColor:\"plum\",css:{...theme.styles.global[':root'], ...theme.styles.global.body}},"
+        "jsx(Fragment,{},"
+        "jsx(MemoizedToastProvider,{},),"
+        "jsx(Fragment,{},"
+        "children,"
+        "),"
+        "),"
+        "),"
+        ")" + (",)" if react_strict_mode else "") + ")"
         "}"
-    ) in lines
+    )
+    assert expected in lines
 
 
 @pytest.mark.parametrize(
@@ -1440,23 +1441,21 @@ def test_app_wrap_priority(
         line.strip() for line in app_js_contents.splitlines() if line.strip()
     ]
     lines = "".join(app_js_lines)
-    assert (
+    expected = (
         "function AppWrap({children}) {"
-        "return (" + ("<StrictMode>" if react_strict_mode else "") + "<RadixThemesBox>"
-        '<RadixThemesText as={"p"}>'
-        "<RadixThemesColorModeProvider>"
-        "<Fragment2>"
-        "<Fragment>"
-        "<MemoizedToastProvider/>"
-        "<Fragment>"
-        "{children}"
-        "</Fragment>"
-        "</Fragment>"
-        "</Fragment2>"
-        "</RadixThemesColorModeProvider>"
-        "</RadixThemesText>"
-        "</RadixThemesBox>" + ("</StrictMode>" if react_strict_mode else "")
-    ) in lines
+        "return ("
+        + ("jsx(StrictMode,{}," if react_strict_mode else "")
+        + "jsx(RadixThemesBox,{},"
+        'jsx(RadixThemesText,{as:"p"},'
+        "jsx(RadixThemesColorModeProvider,{},"
+        "jsx(Fragment2,{},"
+        "jsx(Fragment,{},"
+        "jsx(MemoizedToastProvider,{},),"
+        "jsx(Fragment,{},"
+        "children"
+        ",),),),),)" + (",)" if react_strict_mode else "")
+    )
+    assert expected in lines
 
 
 def test_app_state_determination():

+ 44 - 0
tests/units/test_event.py

@@ -483,3 +483,47 @@ def test_event_bound_method() -> None:
 
     w = Wrapper()
     _ = rx.input(on_change=w.get_handler)
+
+
+def test_decentralized_event_with_args():
+    """Test the decentralized event."""
+
+    class S(BaseState):
+        field: Field[str] = field("")
+
+    @event
+    def e(s: S, arg: str):
+        s.field = arg
+
+    _ = rx.input(on_change=e("foo"))
+
+
+def test_decentralized_event_no_args():
+    """Test the decentralized event with no args."""
+
+    class S(BaseState):
+        field: Field[str] = field("")
+
+    @event
+    def e(s: S):
+        s.field = "foo"
+
+    _ = rx.input(on_change=e())
+    _ = rx.input(on_change=e)
+
+
+class GlobalState(BaseState):
+    """Global state for testing decentralized events."""
+
+    field: Field[str] = field("")
+
+
+@event
+def f(s: GlobalState, arg: str):
+    s.field = arg
+
+
+def test_decentralized_event_global_state():
+    """Test the decentralized event with a global state."""
+    _ = rx.input(on_change=f("foo"))
+    _ = rx.input(on_change=f)

+ 48 - 10
tests/units/test_state.py

@@ -27,18 +27,20 @@ from reflex.app import App
 from reflex.base import Base
 from reflex.constants import CompileVars, RouteVar, SocketEvent
 from reflex.event import Event, EventHandler
+from reflex.istate.manager import (
+    LockExpiredError,
+    StateManager,
+    StateManagerDisk,
+    StateManagerMemory,
+    StateManagerRedis,
+)
 from reflex.state import (
     BaseState,
     ImmutableStateError,
-    LockExpiredError,
     MutableProxy,
     OnLoadInternalState,
     RouterData,
     State,
-    StateManager,
-    StateManagerDisk,
-    StateManagerMemory,
-    StateManagerRedis,
     StateProxy,
     StateUpdate,
     _substate_key,
@@ -1505,6 +1507,10 @@ def test_setattr_of_mutable_types(mutable_state: MutableTestState):
     assert isinstance(array[1], list)
     assert isinstance(array[2], MutableProxy)
     assert isinstance(array[2], dict)
+    assert isinstance(array[:], list)
+    assert not isinstance(array[:], MutableProxy)
+    assert isinstance(array[:][1], MutableProxy)
+    assert isinstance(array[:][1], list)
 
     assert isinstance(hashmap, MutableProxy)
     assert isinstance(hashmap, dict)
@@ -1773,7 +1779,7 @@ def substate_token_redis(state_manager_redis, token):
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_expire(
-    state_manager_redis: StateManager, token: str, substate_token_redis: str
+    state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
 ):
     """Test that the state manager lock expires and raises exception exiting context.
 
@@ -1795,7 +1801,7 @@ async def test_state_manager_lock_expire(
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_expire_contend(
-    state_manager_redis: StateManager, token: str, substate_token_redis: str
+    state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
 ):
     """Test that the state manager lock expires and queued waiters proceed.
 
@@ -1840,7 +1846,10 @@ async def test_state_manager_lock_expire_contend(
 
 @pytest.mark.asyncio
 async def test_state_manager_lock_warning_threshold_contend(
-    state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
+    state_manager_redis: StateManagerRedis,
+    token: str,
+    substate_token_redis: str,
+    mocker,
 ):
     """Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
 
@@ -3350,7 +3359,8 @@ config = rx.Config(
     with chdir(proj_root):
         # reload config for each parameter to avoid stale values
         reflex.config.get_config(reload=True)
-        from reflex.state import State, StateManager
+        from reflex.istate.manager import StateManager
+        from reflex.state import State
 
         state_manager = StateManager.create(state=State)
         assert state_manager.lock_expiration == expected_values[0]  # pyright: ignore [reportAttributeAccessIssue]
@@ -3388,13 +3398,41 @@ config = rx.Config(
     with chdir(proj_root):
         # reload config for each parameter to avoid stale values
         reflex.config.get_config(reload=True)
-        from reflex.state import State, StateManager
+        from reflex.istate.manager import StateManager
+        from reflex.state import State
 
         with pytest.raises(InvalidLockWarningThresholdError):
             StateManager.create(state=State)
         del sys.modules[constants.Config.MODULE]
 
 
+def test_auto_setters_off(tmp_path):
+    proj_root = tmp_path / "project1"
+    proj_root.mkdir()
+
+    config_string = """
+import reflex as rx
+config = rx.Config(
+    app_name="project1",
+    state_auto_setters=False,
+)
+    """
+
+    (proj_root / "rxconfig.py").write_text(dedent(config_string))
+
+    with chdir(proj_root):
+        # reload config for each parameter to avoid stale values
+        reflex.config.get_config(reload=True)
+        from reflex.state import State
+
+        class TestState(State):
+            """A test state."""
+
+            num: int = 0
+
+        assert list(TestState.event_handlers) == ["setvar"]
+
+
 class MixinState(State, mixin=True):
     """A mixin state for testing."""
 

+ 41 - 0
tests/units/test_var.py

@@ -1,3 +1,4 @@
+import decimal
 import json
 import math
 import typing
@@ -1920,3 +1921,43 @@ def test_str_var_in_components(mocker):
     rx.vstack(
         str(StateWithVar.field),
     )
+
+
+def test_decimal_number_operations():
+    """Test that decimal.Decimal values work with NumberVar operations."""
+    dec_num = Var.create(decimal.Decimal("123.456"))
+    assert isinstance(dec_num._var_value, decimal.Decimal)
+    assert str(dec_num) == "123.456"
+
+    result = dec_num + 10
+    assert str(result) == "(123.456 + 10)"
+
+    result = dec_num * 2
+    assert str(result) == "(123.456 * 2)"
+
+    result = dec_num / 2
+    assert str(result) == "(123.456 / 2)"
+
+    result = dec_num > 100
+    assert str(result) == "(123.456 > 100)"
+
+    result = dec_num < 200
+    assert str(result) == "(123.456 < 200)"
+
+    assert dec_num.json() == "123.456"
+
+
+def test_decimal_var_type_compatibility():
+    """Test that decimal.Decimal values are compatible with NumberVar type system."""
+    dec_num = Var.create(decimal.Decimal("123.456"))
+    int_num = Var.create(42)
+    float_num = Var.create(3.14)
+
+    result = dec_num + int_num
+    assert str(result) == "(123.456 + 42)"
+
+    result = dec_num * float_num
+    assert str(result) == "(123.456 * 3.14)"
+
+    result = (dec_num + int_num) / float_num
+    assert str(result) == "((123.456 + 42) / 3.14)"

+ 3 - 3
tests/units/utils/test_format.py

@@ -457,7 +457,7 @@ def test_format_match(
                     _js_expr=f"(({{node, ...props}}) => <Heading {{...props}} {''.join(Tag(name='', props=Style({'as_': 'h1'})).format_props())} />)"
                 ),
             },
-            '({ ["h1"] : (({node, ...props}) => <Heading {...props} as={"h1"} />) })',
+            '({ ["h1"] : (({node, ...props}) => <Heading {...props} as:"h1" />) })',
         ),
     ],
 )
@@ -475,9 +475,9 @@ def test_format_prop(prop: Var, formatted: str):
     "single_props,key_value_props,output",
     [
         (
-            [Var(_js_expr="{...props}")],
+            [Var(_js_expr="props")],
             {"key": 42},
-            ["key={42}", "{...props}"],
+            ["key:42", "...props"],
         ),
     ],
 )

+ 6 - 0
tests/units/utils/test_serializers.py

@@ -1,4 +1,5 @@
 import datetime
+import decimal
 import json
 from enum import Enum
 from pathlib import Path
@@ -188,6 +189,9 @@ class BaseSubclass(Base):
         (Color(color="slate", shade=1), "var(--slate-1)"),
         (Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),
         (Color(color="accent", shade=1, alpha=True), "var(--accent-a1)"),
+        (decimal.Decimal("123.456"), 123.456),
+        (decimal.Decimal("-0.5"), -0.5),
+        (decimal.Decimal("0"), 0.0),
     ],
 )
 def test_serialize(value: Any, expected: str):
@@ -226,6 +230,8 @@ def test_serialize(value: Any, expected: str):
         (Color(color="slate", shade=1), '"var(--slate-1)"', True),
         (BaseSubclass, '"BaseSubclass"', True),
         (Path(), '"."', True),
+        (decimal.Decimal("123.456"), "123.456", True),
+        (decimal.Decimal("-0.5"), "-0.5", True),
     ],
 )
 def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):

Fișier diff suprimat deoarece este prea mare
+ 447 - 447
uv.lock


Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff