فهرست منبع

[REF-889] useContext per substate (#2149)

Masen Furer 1 سال پیش
والد
کامیت
1603144c7d
65فایلهای تغییر یافته به همراه1257 افزوده شده و 455 حذف شده
  1. 21 0
      integration/test_var_operations.py
  2. 7 5
      reflex/.templates/jinja/web/pages/_app.js.jinja2
  3. 0 26
      reflex/.templates/jinja/web/pages/index.js.jinja2
  4. 38 9
      reflex/.templates/jinja/web/utils/context.js.jinja2
  5. 21 38
      reflex/.templates/web/utils/state.js
  6. 1 1
      reflex/app.py
  7. 2 32
      reflex/compiler/compiler.py
  8. 2 1
      reflex/compiler/templates.py
  9. 4 3
      reflex/compiler/utils.py
  10. 16 2
      reflex/components/base/bare.py
  11. 192 16
      reflex/components/component.py
  12. 2 1
      reflex/components/datadisplay/code.py
  13. 2 1
      reflex/components/datadisplay/code.pyi
  14. 2 1
      reflex/components/datadisplay/dataeditor.py
  15. 2 1
      reflex/components/datadisplay/dataeditor.pyi
  16. 6 6
      reflex/components/datadisplay/datatable.py
  17. 1 1
      reflex/components/datadisplay/datatable.pyi
  18. 2 2
      reflex/components/datadisplay/moment.py
  19. 1 1
      reflex/components/datadisplay/moment.pyi
  20. 2 2
      reflex/components/forms/colormodeswitch.py
  21. 2 2
      reflex/components/forms/colormodeswitch.pyi
  22. 13 1
      reflex/components/forms/debounce.py
  23. 2 1
      reflex/components/forms/debounce.pyi
  24. 2 1
      reflex/components/forms/editor.py
  25. 2 1
      reflex/components/forms/editor.pyi
  26. 8 2
      reflex/components/forms/form.py
  27. 1 1
      reflex/components/forms/form.pyi
  28. 2 2
      reflex/components/forms/input.py
  29. 1 1
      reflex/components/forms/input.pyi
  30. 3 1
      reflex/components/forms/pininput.py
  31. 27 10
      reflex/components/forms/upload.py
  32. 2 1
      reflex/components/forms/upload.pyi
  33. 44 7
      reflex/components/layout/cond.py
  34. 3 3
      reflex/components/layout/html.py
  35. 5 2
      reflex/components/layout/html.pyi
  36. 10 10
      reflex/components/libs/chakra.py
  37. 1 1
      reflex/components/libs/chakra.pyi
  38. 7 8
      reflex/components/navigation/client_side_routing.py
  39. 3 3
      reflex/components/navigation/client_side_routing.pyi
  40. 13 8
      reflex/components/overlay/banner.py
  41. 4 3
      reflex/components/overlay/banner.pyi
  42. 2 2
      reflex/components/radix/themes/base.py
  43. 1 1
      reflex/components/radix/themes/base.pyi
  44. 2 1
      reflex/components/typography/markdown.py
  45. 2 1
      reflex/components/typography/markdown.pyi
  46. 4 0
      reflex/constants/__init__.py
  47. 2 0
      reflex/constants/base.py
  48. 25 0
      reflex/constants/compiler.py
  49. 1 1
      reflex/middleware/hydrate_middleware.py
  50. 9 5
      reflex/state.py
  51. 71 7
      reflex/style.py
  52. 19 6
      reflex/utils/format.py
  53. 41 4
      reflex/utils/imports.py
  54. 1 0
      reflex/utils/types.py
  55. 298 120
      reflex/vars.py
  56. 19 19
      reflex/vars.pyi
  57. 1 1
      tests/compiler/test_compiler.py
  58. 1 1
      tests/components/layout/test_cond.py
  59. 162 2
      tests/components/test_component.py
  60. 3 3
      tests/middleware/test_hydrate_middleware.py
  61. 1 3
      tests/test_app.py
  62. 24 15
      tests/test_state.py
  63. 2 1
      tests/test_style.py
  64. 50 16
      tests/test_var.py
  65. 37 28
      tests/utils/test_format.py

+ 21 - 0
integration/test_var_operations.py

@@ -28,6 +28,7 @@ def VarOperations():
         str_var4: str = "a long string"
         str_var4: str = "a long string"
         dict1: dict = {1: 2}
         dict1: dict = {1: 2}
         dict2: dict = {3: 4}
         dict2: dict = {3: 4}
+        html_str: str = "<div>hello</div>"
 
 
     app = rx.App(state=VarOperationState)
     app = rx.App(state=VarOperationState)
 
 
@@ -522,6 +523,19 @@ def VarOperations():
             rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
             rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
             rx.text(VarOperationState.list3.join(""), id="list_join"),
             rx.text(VarOperationState.list3.join(""), id="list_join"),
             rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
             rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
+            # Index from an op var
+            rx.text(
+                VarOperationState.list3[VarOperationState.int_var1 % 3],
+                id="list_index_mod",
+            ),
+            rx.html(
+                VarOperationState.html_str,
+                id="html_str",
+            ),
+            rx.highlight(
+                "second",
+                query=[VarOperationState.str_var2],
+            ),
             rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
             rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
             rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
             rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
             rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
             rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
@@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness):
         ("dict_eq_dict", "false"),
         ("dict_eq_dict", "false"),
         ("dict_neq_dict", "true"),
         ("dict_neq_dict", "true"),
         ("dict_contains", "true"),
         ("dict_contains", "true"),
+        # index from an op var
+        ("list_index_mod", "second"),
+        # html component with var
+        ("html_str", "hello"),
     ]
     ]
 
 
     for tag, expected in tests:
     for tag, expected in tests:
         assert driver.find_element(By.ID, tag).text == expected
         assert driver.find_element(By.ID, tag).text == expected
+
+    # Highlight component with var query (does not plumb ID)
+    assert driver.find_element(By.TAG_NAME, "mark").text == "second"

+ 7 - 5
reflex/.templates/jinja/web/pages/_app.js.jinja2

@@ -1,7 +1,7 @@
 {% extends "web/pages/base_page.js.jinja2" %}
 {% extends "web/pages/base_page.js.jinja2" %}
 
 
 {% block declaration %}
 {% block declaration %}
-import { EventLoopProvider } from "/utils/context.js";
+import { EventLoopProvider, StateProvider } from "/utils/context.js";
 import { ThemeProvider } from 'next-themes'
 import { ThemeProvider } from 'next-themes'
 
 
 {% for custom_code in custom_codes %}
 {% for custom_code in custom_codes %}
@@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) {
   return (
   return (
     <ThemeProvider defaultTheme="light" storageKey="chakra-ui-color-mode" attribute="class">
     <ThemeProvider defaultTheme="light" storageKey="chakra-ui-color-mode" attribute="class">
       <AppWrap>
       <AppWrap>
-        <EventLoopProvider>
-          <Component {...pageProps} />
-        </EventLoopProvider>
+        <StateProvider>
+          <EventLoopProvider>
+            <Component {...pageProps} />
+          </EventLoopProvider>
+        </StateProvider>
       </AppWrap>
       </AppWrap>
     </ThemeProvider>
     </ThemeProvider>
   );
   );
 }
 }
 
 
-{% endblock %}
+{% endblock %}

+ 0 - 26
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -8,32 +8,6 @@
 
 
 {% block export %}
 {% block export %}
 export default function Component() {
 export default function Component() {
-{% if state_name %}
-  const {{state_name}} = useContext(StateContext)
-{% endif %}
-  const {{const.router}} = useRouter()
-  const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
-  const focusRef = useRef();
-  
-  // Main event loop.
-  const [addEvents, connectError] = useContext(EventLoopContext)
-
-  // Set focus to the specified element.
-  useEffect(() => {
-    if (focusRef.current) {
-      focusRef.current.focus();
-    }
-  })
-
-  // Route after the initial page hydration.
-  useEffect(() => {
-    const change_complete = () => addEvents(initialEvents())
-    {{const.router}}.events.on('routeChangeComplete', change_complete)
-    return () => {
-      {{const.router}}.events.off('routeChangeComplete', change_complete)
-    }
-  }, [{{const.router}}])
-
   {% for hook in hooks %}
   {% for hook in hooks %}
   {{ hook }}
   {{ hook }}
   {% endfor %}
   {% endfor %}

+ 38 - 9
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -1,5 +1,5 @@
-import { createContext, useState } from "react"
-import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
+import { createContext, useContext, useMemo, useReducer, useState } from "react"
+import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
 
 
 {% if initial_state %}
 {% if initial_state %}
 export const initialState = {{ initial_state|json_dumps }}
 export const initialState = {{ initial_state|json_dumps }}
@@ -8,7 +8,12 @@ export const initialState = {}
 {% endif %}
 {% endif %}
 
 
 export const ColorModeContext = createContext(null);
 export const ColorModeContext = createContext(null);
-export const StateContext = createContext(null);
+export const DispatchContext = createContext(null);
+export const StateContexts = {
+  {% for state_name in initial_state %}
+  {{state_name|var_name}}: createContext(null),
+  {% endfor %}
+}
 export const EventLoopContext = createContext(null);
 export const EventLoopContext = createContext(null);
 {% if client_storage %}
 {% if client_storage %}
 export const clientStorage = {{ client_storage|json_dumps }}
 export const clientStorage = {{ client_storage|json_dumps }}
@@ -27,16 +32,40 @@ export const initialEvents = () => []
 export const isDevMode = {{ is_dev_mode|json_dumps }}
 export const isDevMode = {{ is_dev_mode|json_dumps }}
 
 
 export function EventLoopProvider({ children }) {
 export function EventLoopProvider({ children }) {
-  const [state, addEvents, connectError] = useEventLoop(
-    initialState,
+  const dispatch = useContext(DispatchContext)
+  const [addEvents, connectError] = useEventLoop(
+    dispatch,
     initialEvents,
     initialEvents,
     clientStorage,
     clientStorage,
   )
   )
   return (
   return (
     <EventLoopContext.Provider value={[addEvents, connectError]}>
     <EventLoopContext.Provider value={[addEvents, connectError]}>
-      <StateContext.Provider value={state}>
-        {children}
-      </StateContext.Provider>
+      {children}
     </EventLoopContext.Provider>
     </EventLoopContext.Provider>
   )
   )
-}
+}
+
+export function StateProvider({ children }) {
+  {% for state_name in initial_state %}
+  const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"])
+  {% endfor %}
+  const dispatchers = useMemo(() => {
+    return {
+      {% for state_name in initial_state %}
+      "{{state_name}}": dispatch_{{state_name|var_name}},
+      {% endfor %}
+    }
+  }, [])
+
+  return (
+    {% for state_name in initial_state %}
+    <StateContexts.{{state_name|var_name}}.Provider value={ {{state_name|var_name}} }>
+    {% endfor %}
+      <DispatchContext.Provider value={dispatchers}>
+        {children}
+      </DispatchContext.Provider>
+    {% for state_name in initial_state|reverse %}
+    </StateContexts.{{state_name|var_name}}.Provider>
+    {% endfor %}
+  )
+}

+ 21 - 38
reflex/.templates/web/utils/state.js

@@ -6,7 +6,7 @@ import env from "env.json";
 import Cookies from "universal-cookie";
 import Cookies from "universal-cookie";
 import { useEffect, useReducer, useRef, useState } from "react";
 import { useEffect, useReducer, useRef, useState } from "react";
 import Router, { useRouter } from "next/router";
 import Router, { useRouter } from "next/router";
-import { initialEvents } from "utils/context.js"
+import { initialEvents, initialState } from "utils/context.js"
 
 
 // Endpoint URLs.
 // Endpoint URLs.
 const EVENTURL = env.EVENT
 const EVENTURL = env.EVENT
@@ -100,37 +100,10 @@ export const getEventURL = () => {
  * @param delta The delta to apply.
  * @param delta The delta to apply.
  */
  */
 export const applyDelta = (state, delta) => {
 export const applyDelta = (state, delta) => {
-  const new_state = { ...state }
-  for (const substate in delta) {
-    let s = new_state;
-    const path = substate.split(".").slice(1);
-    while (path.length > 0) {
-      s = s[path.shift()];
-    }
-    for (const key in delta[substate]) {
-      s[key] = delta[substate][key];
-    }
-  }
-  return new_state
+  return { ...state, ...delta }
 };
 };
 
 
 
 
-/**
- * Get all local storage items in a key-value object.
- * @returns object of items in local storage.
- */
-export const getAllLocalStorageItems = () => {
-  var localStorageItems = {};
-
-  for (var i = 0, len = localStorage.length; i < len; i++) {
-    var key = localStorage.key(i);
-    localStorageItems[key] = localStorage.getItem(key);
-  }
-
-  return localStorageItems;
-}
-
-
 /**
 /**
  * Handle frontend event or send the event to the backend via Websocket.
  * Handle frontend event or send the event to the backend via Websocket.
  * @param event The event to send.
  * @param event The event to send.
@@ -346,7 +319,9 @@ export const connect = async (
   // On each received message, queue the updates and events.
   // On each received message, queue the updates and events.
   socket.current.on("event", message => {
   socket.current.on("event", message => {
     const update = JSON5.parse(message)
     const update = JSON5.parse(message)
-    dispatch(update.delta)
+    for (const substate in update.delta) {
+      dispatch[substate](update.delta[substate])
+    }
     applyClientStorageDelta(client_storage, update.delta)
     applyClientStorageDelta(client_storage, update.delta)
     event_processing = !update.final
     event_processing = !update.final
     if (update.events) {
     if (update.events) {
@@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => {
 
 
 /**
 /**
  * Establish websocket event loop for a NextJS page.
  * Establish websocket event loop for a NextJS page.
- * @param initial_state The initial app state.
- * @param initial_events Function that returns the initial app events.
+ * @param dispatch The reducer dispatch function to update state.
+ * @param initial_events The initial app events.
  * @param client_storage The client storage object from context.js
  * @param client_storage The client storage object from context.js
  *
  *
- * @returns [state, addEvents, connectError] -
- *   state is a reactive dict,
+ * @returns [addEvents, connectError] -
  *   addEvents is used to queue an event, and
  *   addEvents is used to queue an event, and
  *   connectError is a reactive js error from the websocket connection (or null if connected).
  *   connectError is a reactive js error from the websocket connection (or null if connected).
  */
  */
 export const useEventLoop = (
 export const useEventLoop = (
-  initial_state = {},
+  dispatch,
   initial_events = () => [],
   initial_events = () => [],
   client_storage = {},
   client_storage = {},
 ) => {
 ) => {
   const socket = useRef(null)
   const socket = useRef(null)
   const router = useRouter()
   const router = useRouter()
-  const [state, dispatch] = useReducer(applyDelta, initial_state)
   const [connectError, setConnectError] = useState(null)
   const [connectError, setConnectError] = useState(null)
 
 
   // Function to add new events to the event queue.
   // Function to add new events to the event queue.
@@ -570,7 +543,7 @@ export const useEventLoop = (
       return;
       return;
     }
     }
     // only use websockets if state is present
     // only use websockets if state is present
-    if (Object.keys(state).length > 0) {
+    if (Object.keys(initialState).length > 0) {
       // Initialize the websocket connection.
       // Initialize the websocket connection.
       if (!socket.current) {
       if (!socket.current) {
         connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
         connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
@@ -583,7 +556,17 @@ export const useEventLoop = (
       })()
       })()
     }
     }
   })
   })
-  return [state, addEvents, connectError]
+
+  // Route after the initial page hydration.
+  useEffect(() => {
+    const change_complete = () => addEvents(initial_events())
+    router.events.on('routeChangeComplete', change_complete)
+    return () => {
+      router.events.off('routeChangeComplete', change_complete)
+    }
+  }, [router])
+
+  return [addEvents, connectError]
 }
 }
 
 
 /***
 /***

+ 1 - 1
reflex/app.py

@@ -63,7 +63,7 @@ from reflex.state import (
     StateUpdate,
     StateUpdate,
 )
 )
 from reflex.utils import console, format, prerequisites, types
 from reflex.utils import console, format, prerequisites, types
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportVar
 
 
 # Define custom types.
 # Define custom types.
 ComponentCallable = Callable[[], Component]
 ComponentCallable = Callable[[], Component]

+ 2 - 32
reflex/compiler/compiler.py

@@ -10,40 +10,10 @@ from reflex.compiler import templates, utils
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.config import get_config
 from reflex.config import get_config
 from reflex.state import State
 from reflex.state import State
-from reflex.utils import imports
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
 
 
 # Imports to be included in every Reflex app.
 # Imports to be included in every Reflex app.
-DEFAULT_IMPORTS: imports.ImportDict = {
-    "react": [
-        ImportVar(tag="Fragment"),
-        ImportVar(tag="useEffect"),
-        ImportVar(tag="useRef"),
-        ImportVar(tag="useState"),
-        ImportVar(tag="useContext"),
-    ],
-    "next/router": [ImportVar(tag="useRouter")],
-    f"/{constants.Dirs.STATE_PATH}": [
-        ImportVar(tag="uploadFiles"),
-        ImportVar(tag="Event"),
-        ImportVar(tag="isTrue"),
-        ImportVar(tag="spreadArraysOrObjects"),
-        ImportVar(tag="preventDefault"),
-        ImportVar(tag="refs"),
-        ImportVar(tag="getRefValue"),
-        ImportVar(tag="getRefValues"),
-        ImportVar(tag="getAllLocalStorageItems"),
-        ImportVar(tag="useEventLoop"),
-    ],
-    "/utils/context.js": [
-        ImportVar(tag="EventLoopContext"),
-        ImportVar(tag="initialEvents"),
-        ImportVar(tag="StateContext"),
-        ImportVar(tag="ColorModeContext"),
-    ],
-    "/utils/helpers/range.js": [
-        ImportVar(tag="range", is_default=True),
-    ],
+DEFAULT_IMPORTS: ImportDict = {
     "": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
     "": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
 }
 }
 
 

+ 2 - 1
reflex/compiler/templates.py

@@ -3,7 +3,7 @@
 from jinja2 import Environment, FileSystemLoader, Template
 from jinja2 import Environment, FileSystemLoader, Template
 
 
 from reflex import constants
 from reflex import constants
-from reflex.utils.format import json_dumps
+from reflex.utils.format import format_state_name, json_dumps
 
 
 
 
 class ReflexJinjaEnvironment(Environment):
 class ReflexJinjaEnvironment(Environment):
@@ -19,6 +19,7 @@ class ReflexJinjaEnvironment(Environment):
         )
         )
         self.filters["json_dumps"] = json_dumps
         self.filters["json_dumps"] = json_dumps
         self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
         self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
+        self.filters["var_name"] = format_state_name
         self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
         self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
         self.globals["const"] = {
         self.globals["const"] = {
             "socket": constants.CompileVars.SOCKET,
             "socket": constants.CompileVars.SOCKET,

+ 4 - 3
reflex/compiler/utils.py

@@ -24,13 +24,12 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
 from reflex.state import Cookie, LocalStorage, State
 from reflex.state import Cookie, LocalStorage, State
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
 from reflex.utils import console, format, imports, path_ops
-from reflex.vars import ImportVar
 
 
 # To re-export this function.
 # To re-export this function.
 merge_imports = imports.merge_imports
 merge_imports = imports.merge_imports
 
 
 
 
-def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
+def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
     """Compile an import statement.
     """Compile an import statement.
 
 
     Args:
     Args:
@@ -343,7 +342,9 @@ def get_context_path() -> str:
     Returns:
     Returns:
         The path of the context module.
         The path of the context module.
     """
     """
-    return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS)
+    return os.path.join(
+        constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
+    )
 
 
 
 
 def get_components_path() -> str:
 def get_components_path() -> str:

+ 16 - 2
reflex/components/base/bare.py

@@ -1,7 +1,7 @@
 """A bare component."""
 """A bare component."""
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any
+from typing import Any, Iterator
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
@@ -24,7 +24,21 @@ class Bare(Component):
         Returns:
         Returns:
             The component.
             The component.
         """
         """
-        return cls(contents=str(contents))  # type: ignore
+        if isinstance(contents, Var) and contents._var_data:
+            contents = contents.to(str)
+        else:
+            contents = str(contents)
+        return cls(contents=contents)  # type: ignore
 
 
     def _render(self) -> Tag:
     def _render(self) -> Tag:
         return Tagless(contents=str(self.contents))
         return Tagless(contents=str(self.contents))
+
+    def _get_vars(self) -> Iterator[Var]:
+        """Walk all Vars used in this component.
+
+        Yields:
+            The contents if it is a Var, otherwise nothing.
+        """
+        if isinstance(self.contents, Var):
+            # Fast path for Bare text components.
+            yield self.contents

+ 192 - 16
reflex/components/component.py

@@ -5,11 +5,11 @@ from __future__ import annotations
 import typing
 import typing
 from abc import ABC
 from abc import ABC
 from functools import lru_cache, wraps
 from functools import lru_cache, wraps
-from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
-from reflex.constants import Dirs, EventTriggers
+from reflex.constants import Dirs, EventTriggers, Hooks, Imports
 from reflex.event import (
 from reflex.event import (
     EventChain,
     EventChain,
     EventHandler,
     EventHandler,
@@ -20,8 +20,9 @@ from reflex.event import (
 )
 )
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, format, imports, types
 from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
 from reflex.utils.serializers import serializer
 from reflex.utils.serializers import serializer
-from reflex.vars import BaseVar, ImportVar, Var
+from reflex.vars import BaseVar, Var
 
 
 
 
 class Component(Base, ABC):
 class Component(Base, ABC):
@@ -388,7 +389,11 @@ class Component(Base, ABC):
             props = props.copy()
             props = props.copy()
 
 
         props.update(
         props.update(
-            self.event_triggers,
+            **{
+                trigger: handler
+                for trigger, handler in self.event_triggers.items()
+                if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT}
+            },
             key=self.key,
             key=self.key,
             id=self.id,
             id=self.id,
             class_name=self.class_name,
             class_name=self.class_name,
@@ -488,7 +493,7 @@ class Component(Base, ABC):
         """
         """
         if type(self) in style:
         if type(self) in style:
             # Extract the style for this component.
             # Extract the style for this component.
-            component_style = Style(style[type(self)])
+            component_style = style[type(self)]
 
 
             # Only add style props that are not overridden.
             # Only add style props that are not overridden.
             component_style = {
             component_style = {
@@ -564,6 +569,78 @@ class Component(Base, ABC):
             if self._valid_children:
             if self._valid_children:
                 validate_valid_child(name)
                 validate_valid_child(name)
 
 
+    @staticmethod
+    def _get_vars_from_event_triggers(
+        event_triggers: dict[str, EventChain | Var],
+    ) -> Iterator[tuple[str, list[Var]]]:
+        """Get the Vars associated with each event trigger.
+
+        Args:
+            event_triggers: The event triggers from the component instance.
+
+        Yields:
+            tuple of (event_name, event_vars)
+        """
+        for event_trigger, event in event_triggers.items():
+            if isinstance(event, Var):
+                yield event_trigger, [event]
+            elif isinstance(event, EventChain):
+                event_args = []
+                for spec in event.events:
+                    for args in spec.args:
+                        event_args.extend(args)
+                yield event_trigger, event_args
+
+    def _get_vars(self) -> list[Var]:
+        """Walk all Vars used in this component.
+
+        Returns:
+            Each var referenced by the component (props, styles, event handlers).
+        """
+        vars = getattr(self, "__vars", None)
+        if vars is not None:
+            return vars
+        vars = self.__vars = []
+        # Get Vars associated with event trigger arguments.
+        for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers):
+            vars.extend(event_vars)
+
+        # Get Vars associated with component props.
+        for prop in self.get_props():
+            prop_var = getattr(self, prop)
+            if isinstance(prop_var, Var):
+                vars.append(prop_var)
+
+        # Style keeps track of its own VarData instance, so embed in a temp Var that is yielded.
+        if self.style:
+            vars.append(
+                BaseVar(
+                    _var_name="style",
+                    _var_type=str,
+                    _var_data=self.style._var_data,
+                )
+            )
+
+        # Special props are always Var instances.
+        vars.extend(self.special_props)
+
+        # Get Vars associated with common Component props.
+        for comp_prop in (
+            self.class_name,
+            self.id,
+            self.key,
+            self.autofocus,
+            *self.custom_attrs.values(),
+        ):
+            if isinstance(comp_prop, Var):
+                vars.append(comp_prop)
+            elif isinstance(comp_prop, str):
+                # Collapse VarData encoded in f-strings.
+                var = Var.create_safe(comp_prop)
+                if var._var_data is not None:
+                    vars.append(var)
+        return vars
+
     def _get_custom_code(self) -> str | None:
     def _get_custom_code(self) -> str | None:
         """Get custom code for the component.
         """Get custom code for the component.
 
 
@@ -644,6 +721,33 @@ class Component(Base, ABC):
             dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
             dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
         }
         }
 
 
+    def _get_hooks_imports(self) -> imports.ImportDict:
+        """Get the imports required by certain hooks.
+
+        Returns:
+            The imports required for all selected hooks.
+        """
+        _imports = {}
+
+        if self._get_ref_hook():
+            # Handle hooks needed for attaching react refs to DOM nodes.
+            _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
+            _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
+
+        if self._get_mount_lifecycle_hook():
+            # Handle hooks for `on_mount` / `on_unmount`.
+            _imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
+
+        if self._get_special_hooks():
+            # Handle additional internal hooks (autofocus, etc).
+            _imports.setdefault("react", set()).update(
+                {
+                    ImportVar(tag="useRef"),
+                    ImportVar(tag="useEffect"),
+                },
+            )
+        return _imports
+
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.
         """Get all the libraries and fields that are used by the component.
 
 
@@ -651,13 +755,26 @@ class Component(Base, ABC):
             The imports needed by the component.
             The imports needed by the component.
         """
         """
         _imports = {}
         _imports = {}
+
+        # Import this component's tag from the main library.
         if self.library is not None and self.tag is not None:
         if self.library is not None and self.tag is not None:
             _imports[self.library] = {self.import_var}
             _imports[self.library] = {self.import_var}
 
 
+        # Get static imports required for event processing.
+        event_imports = Imports.EVENTS if self.event_triggers else {}
+
+        # Collect imports from Vars used directly by this component.
+        var_imports = [
+            var._var_data.imports for var in self._get_vars() if var._var_data
+        ]
+
         return imports.merge_imports(
         return imports.merge_imports(
             *self._get_props_imports(),
             *self._get_props_imports(),
             self._get_dependencies_imports(),
             self._get_dependencies_imports(),
+            self._get_hooks_imports(),
             _imports,
             _imports,
+            event_imports,
+            *var_imports,
         )
         )
 
 
     def get_imports(self) -> imports.ImportDict:
     def get_imports(self) -> imports.ImportDict:
@@ -678,13 +795,13 @@ class Component(Base, ABC):
         """
         """
         # pop on_mount and on_unmount from event_triggers since these are handled by
         # pop on_mount and on_unmount from event_triggers since these are handled by
         # hooks, not as actually props in the component
         # hooks, not as actually props in the component
-        on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None)
-        on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None)
-        if on_mount:
+        on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None)
+        on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None)
+        if on_mount is not None:
             on_mount = format.format_event_chain(on_mount)
             on_mount = format.format_event_chain(on_mount)
-        if on_unmount:
+        if on_unmount is not None:
             on_unmount = format.format_event_chain(on_unmount)
             on_unmount = format.format_event_chain(on_unmount)
-        if on_mount or on_unmount:
+        if on_mount is not None or on_unmount is not None:
             return f"""
             return f"""
                 useEffect(() => {{
                 useEffect(() => {{
                     {on_mount or ""}
                     {on_mount or ""}
@@ -703,6 +820,47 @@ class Component(Base, ABC):
         if ref is not None:
         if ref is not None:
             return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
             return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
 
 
+    def _get_vars_hooks(self) -> set[str]:
+        """Get the hooks required by vars referenced in this component.
+
+        Returns:
+            The hooks for the vars.
+        """
+        vars_hooks = set()
+        for var in self._get_vars():
+            if var._var_data:
+                vars_hooks.update(var._var_data.hooks)
+        return vars_hooks
+
+    def _get_events_hooks(self) -> set[str]:
+        """Get the hooks required by events referenced in this component.
+
+        Returns:
+            The hooks for the events.
+        """
+        if self.event_triggers:
+            return {Hooks.EVENTS}
+        return set()
+
+    def _get_special_hooks(self) -> set[str]:
+        """Get the hooks required by special actions referenced in this component.
+
+        Returns:
+            The hooks for special actions.
+        """
+        if self.autofocus:
+            return {
+                """
+                // Set focus to the specified element.
+                const focusRef = useRef(null)
+                useEffect(() => {
+                  if (focusRef.current) {
+                    focusRef.current.focus();
+                  }
+                })""",
+            }
+        return set()
+
     def _get_hooks_internal(self) -> Set[str]:
     def _get_hooks_internal(self) -> Set[str]:
         """Get the React hooks for this component managed by the framework.
         """Get the React hooks for this component managed by the framework.
 
 
@@ -712,10 +870,15 @@ class Component(Base, ABC):
         Returns:
         Returns:
             Set of internally managed hooks.
             Set of internally managed hooks.
         """
         """
-        return set(
-            hook
-            for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
-            if hook
+        return (
+            set(
+                hook
+                for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
+                if hook
+            )
+            | self._get_vars_hooks()
+            | self._get_events_hooks()
+            | self._get_special_hooks()
         )
         )
 
 
     def _get_hooks(self) -> str | None:
     def _get_hooks(self) -> str | None:
@@ -1018,11 +1181,24 @@ class NoSSRComponent(Component):
     """A dynamic component that is not rendered on the server."""
     """A dynamic component that is not rendered on the server."""
 
 
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
-        dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}}
+        """Get the imports for the component.
+
+        Returns:
+            The imports for dynamically importing the component at module load time.
+        """
+        # Next.js dynamic import mechanism.
+        dynamic_import = {"next/dynamic": [ImportVar(tag="dynamic", is_default=True)]}
+
+        # The normal imports for this component.
+        _imports = super()._get_imports()
+
+        # Do NOT import the main library/tag statically.
+        if self.library is not None:
+            _imports[self.library] = [imports.ImportVar(tag=None, render=False)]
 
 
         return imports.merge_imports(
         return imports.merge_imports(
             dynamic_import,
             dynamic_import,
-            {self.library: {ImportVar(tag=None, render=False)}},
+            _imports,
             self._get_dependencies_imports(),
             self._get_dependencies_imports(),
         )
         )
 
 

+ 2 - 1
reflex/components/datadisplay/code.py

@@ -12,7 +12,8 @@ from reflex.components.media import Icon
 from reflex.event import set_clipboard
 from reflex.event import set_clipboard
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format, imports
 from reflex.utils import format, imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 LiteralCodeBlockTheme = Literal[
 LiteralCodeBlockTheme = Literal[
     "a11y-dark",
     "a11y-dark",

+ 2 - 1
reflex/components/datadisplay/code.pyi

@@ -16,7 +16,8 @@ from reflex.components.media import Icon
 from reflex.event import set_clipboard
 from reflex.event import set_clipboard
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format, imports
 from reflex.utils import format, imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 LiteralCodeBlockTheme = Literal[
 LiteralCodeBlockTheme = Literal[
     "a11y-dark",
     "a11y-dark",

+ 2 - 1
reflex/components/datadisplay/dataeditor.py

@@ -8,8 +8,9 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.literals import LiteralRowMarker
 from reflex.components.literals import LiteralRowMarker
 from reflex.utils import console, format, imports, types
 from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
 from reflex.utils.serializers import serializer
 from reflex.utils.serializers import serializer
-from reflex.vars import ImportVar, Var, get_unique_variable_name
+from reflex.vars import Var, get_unique_variable_name
 
 
 
 
 # TODO: Fix the serialization issue for custom types.
 # TODO: Fix the serialization issue for custom types.

+ 2 - 1
reflex/components/datadisplay/dataeditor.pyi

@@ -13,8 +13,9 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.literals import LiteralRowMarker
 from reflex.components.literals import LiteralRowMarker
 from reflex.utils import console, format, imports, types
 from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
 from reflex.utils.serializers import serializer
 from reflex.utils.serializers import serializer
-from reflex.vars import ImportVar, Var, get_unique_variable_name
+from reflex.vars import Var, get_unique_variable_name
 
 
 class GridColumnIcons(Enum):
 class GridColumnIcons(Enum):
     Array = "array"
     Array = "array"

+ 6 - 6
reflex/components/datadisplay/datatable.py

@@ -8,7 +8,7 @@ from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
 from reflex.utils import imports, types
 from reflex.utils import imports, types
 from reflex.utils.serializers import serialize, serializer
 from reflex.utils.serializers import serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var
 
 
 
 
 class Gridjs(Component):
 class Gridjs(Component):
@@ -105,7 +105,7 @@ class DataTable(Gridjs):
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         return imports.merge_imports(
         return imports.merge_imports(
             super()._get_imports(),
             super()._get_imports(),
-            {"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
+            {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
         )
         )
 
 
     def _render(self) -> Tag:
     def _render(self) -> Tag:
@@ -113,13 +113,13 @@ class DataTable(Gridjs):
             self.columns = BaseVar(
             self.columns = BaseVar(
                 _var_name=f"{self.data._var_name}.columns",
                 _var_name=f"{self.data._var_name}.columns",
                 _var_type=List[Any],
                 _var_type=List[Any],
-                _var_state=self.data._var_state,
-            )
+                _var_full_name_needs_state_prefix=True,
+            )._replace(merge_var_data=self.data._var_data)
             self.data = BaseVar(
             self.data = BaseVar(
                 _var_name=f"{self.data._var_name}.data",
                 _var_name=f"{self.data._var_name}.data",
                 _var_type=List[List[Any]],
                 _var_type=List[List[Any]],
-                _var_state=self.data._var_state,
-            )
+                _var_full_name_needs_state_prefix=True,
+            )._replace(merge_var_data=self.data._var_data)
         if types.is_dataframe(type(self.data)):
         if types.is_dataframe(type(self.data)):
             # If given a pandas df break up the data and columns
             # If given a pandas df break up the data and columns
             data = serialize(self.data)
             data = serialize(self.data)

+ 1 - 1
reflex/components/datadisplay/datatable.pyi

@@ -12,7 +12,7 @@ from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
 from reflex.utils import imports, types
 from reflex.utils import imports, types
 from reflex.utils.serializers import serialize, serializer
 from reflex.utils.serializers import serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var
 
 
 class Gridjs(Component):
 class Gridjs(Component):
     @overload
     @overload

+ 2 - 2
reflex/components/datadisplay/moment.py

@@ -4,7 +4,7 @@ from typing import Any, Dict, List
 
 
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.component import Component, NoSSRComponent
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 
 
 class Moment(NoSSRComponent):
 class Moment(NoSSRComponent):
@@ -78,7 +78,7 @@ class Moment(NoSSRComponent):
         if self.tz is not None:
         if self.tz is not None:
             merged_imports = imports.merge_imports(
             merged_imports = imports.merge_imports(
                 merged_imports,
                 merged_imports,
-                {"moment-timezone": {ImportVar(tag="")}},
+                {"moment-timezone": {imports.ImportVar(tag="")}},
             )
             )
         return merged_imports
         return merged_imports
 
 

+ 1 - 1
reflex/components/datadisplay/moment.pyi

@@ -10,7 +10,7 @@ from reflex.style import Style
 from typing import Any, Dict, List
 from typing import Any, Dict, List
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.component import Component, NoSSRComponent
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 class Moment(NoSSRComponent):
 class Moment(NoSSRComponent):
     def get_event_triggers(self) -> Dict[str, Any]: ...
     def get_event_triggers(self) -> Dict[str, Any]: ...

+ 2 - 2
reflex/components/forms/colormodeswitch.py

@@ -22,7 +22,7 @@ from reflex.components.component import Component
 from reflex.components.layout.cond import Cond, cond
 from reflex.components.layout.cond import Cond, cond
 from reflex.components.media.icon import Icon
 from reflex.components.media.icon import Icon
 from reflex.style import color_mode, toggle_color_mode
 from reflex.style import color_mode, toggle_color_mode
-from reflex.vars import BaseVar
+from reflex.vars import Var
 
 
 from .button import Button
 from .button import Button
 from .switch import Switch
 from .switch import Switch
@@ -32,7 +32,7 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
 DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
 DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
 
 
 
 
-def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component:
+def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
     """Create a component or Prop based on color_mode.
     """Create a component or Prop based on color_mode.
 
 
     Args:
     Args:

+ 2 - 2
reflex/components/forms/colormodeswitch.pyi

@@ -12,7 +12,7 @@ from reflex.components.component import Component
 from reflex.components.layout.cond import Cond, cond
 from reflex.components.layout.cond import Cond, cond
 from reflex.components.media.icon import Icon
 from reflex.components.media.icon import Icon
 from reflex.style import color_mode, toggle_color_mode
 from reflex.style import color_mode, toggle_color_mode
-from reflex.vars import BaseVar
+from reflex.vars import Var
 from .button import Button
 from .button import Button
 from .switch import Switch
 from .switch import Switch
 
 
@@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str
 DEFAULT_LIGHT_ICON: Icon
 DEFAULT_LIGHT_ICON: Icon
 DEFAULT_DARK_ICON: Icon
 DEFAULT_DARK_ICON: Icon
 
 
-def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ...
+def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ...
 
 
 class ColorModeIcon(Cond):
 class ColorModeIcon(Cond):
     @overload
     @overload

+ 13 - 1
reflex/components/forms/debounce.py

@@ -1,10 +1,11 @@
 """Wrapper around react-debounce-input."""
 """Wrapper around react-debounce-input."""
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any
+from typing import Any, Set
 
 
 from reflex.components import Component
 from reflex.components import Component
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
+from reflex.utils import imports
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 
 
@@ -77,6 +78,17 @@ class DebounceInput(Component):
         object.__setattr__(child, "render", lambda: "")
         object.__setattr__(child, "render", lambda: "")
         return tag
         return tag
 
 
+    def _get_imports(self) -> imports.ImportDict:
+        return imports.merge_imports(
+            super()._get_imports(), *[c._get_imports() for c in self.children]
+        )
+
+    def _get_hooks_internal(self) -> Set[str]:
+        hooks = super()._get_hooks_internal()
+        for child in self.children:
+            hooks.update(child._get_hooks_internal())
+        return hooks
+
 
 
 def props_not_none(c: Component) -> dict[str, Any]:
 def props_not_none(c: Component) -> dict[str, Any]:
     """Get all properties of the component that are not None.
     """Get all properties of the component that are not None.

+ 2 - 1
reflex/components/forms/debounce.pyi

@@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
-from typing import Any
+from typing import Any, Set
 from reflex.components import Component
 from reflex.components import Component
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
+from reflex.utils import imports
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 class DebounceInput(Component):
 class DebounceInput(Component):

+ 2 - 1
reflex/components/forms/editor.py

@@ -8,7 +8,8 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.component import Component, NoSSRComponent
 from reflex.constants import EventTriggers
 from reflex.constants import EventTriggers
 from reflex.utils.format import to_camel_case
 from reflex.utils.format import to_camel_case
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 
 
 class EditorButtonList(list, enum.Enum):
 class EditorButtonList(list, enum.Enum):

+ 2 - 1
reflex/components/forms/editor.pyi

@@ -13,7 +13,8 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.component import Component, NoSSRComponent
 from reflex.constants import EventTriggers
 from reflex.constants import EventTriggers
 from reflex.utils.format import to_camel_case
 from reflex.utils.format import to_camel_case
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 class EditorButtonList(list, enum.Enum):
 class EditorButtonList(list, enum.Enum):
     BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]]
     BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]]

+ 8 - 2
reflex/components/forms/form.py

@@ -8,7 +8,7 @@ from jinja2 import Environment
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
-from reflex.constants import EventTriggers
+from reflex.constants import Dirs, EventTriggers
 from reflex.event import EventChain
 from reflex.event import EventChain
 from reflex.utils import imports
 from reflex.utils import imports
 from reflex.utils.format import format_event_chain, to_camel_case
 from reflex.utils.format import format_event_chain, to_camel_case
@@ -65,7 +65,13 @@ class Form(ChakraComponent):
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         return imports.merge_imports(
         return imports.merge_imports(
             super()._get_imports(),
             super()._get_imports(),
-            {"react": {imports.ImportVar(tag="useCallback")}},
+            {
+                "react": {imports.ImportVar(tag="useCallback")},
+                f"/{Dirs.STATE_PATH}": {
+                    imports.ImportVar(tag="getRefValue"),
+                    imports.ImportVar(tag="getRefValues"),
+                },
+            },
         )
         )
 
 
     def _get_hooks(self) -> str | None:
     def _get_hooks(self) -> str | None:

+ 1 - 1
reflex/components/forms/form.pyi

@@ -12,7 +12,7 @@ from jinja2 import Environment
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
-from reflex.constants import EventTriggers
+from reflex.constants import Dirs, EventTriggers
 from reflex.event import EventChain
 from reflex.event import EventChain
 from reflex.utils import imports
 from reflex.utils import imports
 from reflex.utils.format import format_event_chain, to_camel_case
 from reflex.utils.format import format_event_chain, to_camel_case

+ 2 - 2
reflex/components/forms/input.py

@@ -11,7 +11,7 @@ from reflex.components.libs.chakra import (
 )
 )
 from reflex.constants import EventTriggers
 from reflex.constants import EventTriggers
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 
 
 class Input(ChakraComponent):
 class Input(ChakraComponent):
@@ -61,7 +61,7 @@ class Input(ChakraComponent):
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         return imports.merge_imports(
         return imports.merge_imports(
             super()._get_imports(),
             super()._get_imports(),
-            {"/utils/state": {ImportVar(tag="set_val")}},
+            {"/utils/state": {imports.ImportVar(tag="set_val")}},
         )
         )
 
 
     def get_event_triggers(self) -> Dict[str, Any]:
     def get_event_triggers(self) -> Dict[str, Any]:

+ 1 - 1
reflex/components/forms/input.pyi

@@ -17,7 +17,7 @@ from reflex.components.libs.chakra import (
 )
 )
 from reflex.constants import EventTriggers
 from reflex.constants import EventTriggers
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 class Input(ChakraComponent):
 class Input(ChakraComponent):
     def get_event_triggers(self) -> Dict[str, Any]: ...
     def get_event_triggers(self) -> Dict[str, Any]: ...

+ 3 - 1
reflex/components/forms/pininput.py

@@ -68,9 +68,11 @@ class PinInput(ChakraComponent):
         Returns:
         Returns:
             The merged import dict.
             The merged import dict.
         """
         """
+        range_var = Var.range(0)
         return merge_imports(
         return merge_imports(
             super()._get_imports(),
             super()._get_imports(),
             PinInputField().get_imports(),  # type: ignore
             PinInputField().get_imports(),  # type: ignore
+            range_var._var_data.imports if range_var._var_data is not None else {},
         )
         )
 
 
     def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
     def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
@@ -117,7 +119,7 @@ class PinInput(ChakraComponent):
             )
             )
             refs_declaration._var_is_local = True
             refs_declaration._var_is_local = True
             if ref:
             if ref:
-                return f"const {ref} = {refs_declaration}"
+                return f"const {ref} = {str(refs_declaration)}"
             return super()._get_ref_hook()
             return super()._get_ref_hook()
 
 
     def _render(self) -> Tag:
     def _render(self) -> Tag:

+ 27 - 10
reflex/components/forms/upload.py

@@ -7,9 +7,10 @@ from reflex import constants
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
+from reflex.constants import Dirs
 from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
 from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import BaseVar, CallableVar, ImportVar, Var
+from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 
 DEFAULT_UPLOAD_ID: str = "default"
 DEFAULT_UPLOAD_ID: str = "default"
 
 
@@ -30,6 +31,13 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
     return BaseVar(
     return BaseVar(
         _var_name=f"e => upload_files.{id_}[1]((files) => e)",
         _var_name=f"e => upload_files.{id_}[1]((files) => e)",
         _var_type=EventChain,
         _var_type=EventChain,
+        _var_data=VarData(  # type: ignore
+            imports={
+                f"/{Dirs.STATE_PATH}": {
+                    imports.ImportVar(tag="upload_files"),
+                },
+            },
+        ),
     )
     )
 
 
 
 
@@ -46,6 +54,13 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
     return BaseVar(
     return BaseVar(
         _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
         _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
         _var_type=List[str],
         _var_type=List[str],
+        _var_data=VarData(  # type: ignore
+            imports={
+                f"/{Dirs.STATE_PATH}": {
+                    imports.ImportVar(tag="upload_files"),
+                },
+            },
+        ),
     )
     )
 
 
 
 
@@ -166,14 +181,16 @@ class Upload(Component):
 
 
     def _get_hooks(self) -> str | None:
     def _get_hooks(self) -> str | None:
         return (
         return (
-            (super()._get_hooks() or "")
-            + f"""
-        upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
-        """
-        )
+            super()._get_hooks() or ""
+        ) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);"
 
 
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
-        return {
-            **super()._get_imports(),
-            f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
-        }
+        return imports.merge_imports(
+            super()._get_imports(),
+            {
+                "react": {imports.ImportVar(tag="useState")},
+                f"/{constants.Dirs.STATE_PATH}": [
+                    imports.ImportVar(tag="upload_files")
+                ],
+            },
+        )

+ 2 - 1
reflex/components/forms/upload.pyi

@@ -12,9 +12,10 @@ from reflex import constants
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
+from reflex.constants import Dirs
 from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
 from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import BaseVar, CallableVar, ImportVar, Var
+from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 
 DEFAULT_UPLOAD_ID: str
 DEFAULT_UPLOAD_ID: str
 
 

+ 44 - 7
reflex/components/layout/cond.py

@@ -1,13 +1,18 @@
 """Create a list of components from an iterable."""
 """Create a list of components from an iterable."""
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, overload
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.layout.fragment import Fragment
 from reflex.components.layout.fragment import Fragment
 from reflex.components.tags import CondTag, Tag
 from reflex.components.tags import CondTag, Tag
-from reflex.utils import format
-from reflex.vars import Var
+from reflex.constants import Dirs
+from reflex.utils import format, imports
+from reflex.vars import BaseVar, Var, VarData
+
+_IS_TRUE_IMPORT = {
+    f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
+}
 
 
 
 
 class Cond(Component):
 class Cond(Component):
@@ -88,6 +93,28 @@ class Cond(Component):
             cond_state=f"isTrue({self.cond._var_full_name})",
             cond_state=f"isTrue({self.cond._var_full_name})",
         )
         )
 
 
+    def _get_imports(self) -> imports.ImportDict:
+        return imports.merge_imports(
+            super()._get_imports(),
+            getattr(self.cond._var_data, "imports", {}),
+            _IS_TRUE_IMPORT,
+        )
+
+
+@overload
+def cond(condition: Any, c1: Component, c2: Any) -> Component:
+    ...
+
+
+@overload
+def cond(condition: Any, c1: Component) -> Component:
+    ...
+
+
+@overload
+def cond(condition: Any, c1: Any, c2: Any) -> Var:
+    ...
+
 
 
 def cond(condition: Any, c1: Any, c2: Any = None):
 def cond(condition: Any, c1: Any, c2: Any = None):
     """Create a conditional component or Prop.
     """Create a conditional component or Prop.
@@ -103,8 +130,11 @@ def cond(condition: Any, c1: Any, c2: Any = None):
     Raises:
     Raises:
         ValueError: If the arguments are invalid.
         ValueError: If the arguments are invalid.
     """
     """
-    # Import here to avoid circular imports.
-    from reflex.vars import BaseVar, Var
+    var_datas: list[VarData | None] = [
+        VarData(  # type: ignore
+            imports=_IS_TRUE_IMPORT,
+        ),
+    ]
 
 
     # Convert the condition to a Var.
     # Convert the condition to a Var.
     cond_var = Var.create(condition)
     cond_var = Var.create(condition)
@@ -116,16 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None):
             c2, Component
             c2, Component
         ), "Both arguments must be components."
         ), "Both arguments must be components."
         return Cond.create(cond_var, c1, c2)
         return Cond.create(cond_var, c1, c2)
+    if isinstance(c1, Var):
+        var_datas.append(c1._var_data)
 
 
-    # Otherwise, create a conditionl Var.
+    # Otherwise, create a conditional Var.
     # Check that the second argument is valid.
     # Check that the second argument is valid.
     if isinstance(c2, Component):
     if isinstance(c2, Component):
         raise ValueError("Both arguments must be props.")
         raise ValueError("Both arguments must be props.")
     if c2 is None:
     if c2 is None:
         raise ValueError("For conditional vars, the second argument must be set.")
         raise ValueError("For conditional vars, the second argument must be set.")
+    if isinstance(c2, Var):
+        var_datas.append(c2._var_data)
 
 
     # Create the conditional var.
     # Create the conditional var.
-    return BaseVar(
+    return cond_var._replace(
         _var_name=format.format_cond(
         _var_name=format.format_cond(
             cond=cond_var._var_full_name,
             cond=cond_var._var_full_name,
             true_value=c1,
             true_value=c1,
@@ -133,4 +167,7 @@ def cond(condition: Any, c1: Any, c2: Any = None):
             is_prop=True,
             is_prop=True,
         ),
         ),
         _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
         _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
+        _var_is_local=False,
+        _var_full_name_needs_state_prefix=False,
+        merge_var_data=VarData.merge(*var_datas),
     )
     )

+ 3 - 3
reflex/components/layout/html.py

@@ -1,8 +1,8 @@
 """A html component."""
 """A html component."""
-
-from typing import Any
+from typing import Dict
 
 
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
+from reflex.vars import Var
 
 
 
 
 class Html(Box):
 class Html(Box):
@@ -13,7 +13,7 @@ class Html(Box):
     """
     """
 
 
     # The HTML to render.
     # The HTML to render.
-    dangerouslySetInnerHTML: Any
+    dangerouslySetInnerHTML: Var[Dict[str, str]]
 
 
     @classmethod
     @classmethod
     def create(cls, *children, **props):
     def create(cls, *children, **props):

+ 5 - 2
reflex/components/layout/html.pyi

@@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
-from typing import Any
+from typing import Dict
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
+from reflex.vars import Var
 
 
 class Html(Box):
 class Html(Box):
     @overload
     @overload
@@ -16,7 +17,9 @@ class Html(Box):
     def create(  # type: ignore
     def create(  # type: ignore
         cls,
         cls,
         *children,
         *children,
-        dangerouslySetInnerHTML: Optional[Any] = None,
+        dangerouslySetInnerHTML: Optional[
+            Union[Var[Dict[str, str]], Dict[str, str]]
+        ] = None,
         element: Optional[Union[Var[str], str]] = None,
         element: Optional[Union[Var[str], str]] = None,
         src: Optional[Union[Var[str], str]] = None,
         src: Optional[Union[Var[str], str]] = None,
         alt: Optional[Union[Var[str], str]] = None,
         alt: Optional[Union[Var[str], str]] = None,

+ 10 - 10
reflex/components/libs/chakra.py

@@ -6,7 +6,7 @@ from typing import List, Literal
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 
 
 class ChakraComponent(Component):
 class ChakraComponent(Component):
@@ -34,7 +34,7 @@ class ChakraComponent(Component):
             The dependencies imports of the component.
             The dependencies imports of the component.
         """
         """
         return {
         return {
-            dep: [ImportVar(tag=None, render=False)]
+            dep: [imports.ImportVar(tag=None, render=False)]
             for dep in [
             for dep in [
                 "@chakra-ui/system@2.5.7",
                 "@chakra-ui/system@2.5.7",
                 "framer-motion@10.16.4",
                 "framer-motion@10.16.4",
@@ -75,17 +75,17 @@ class ChakraProvider(ChakraComponent):
         )
         )
 
 
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
-        imports = super()._get_imports()
-        imports.setdefault(self.__fields__["library"].default, []).append(
-            ImportVar(tag="extendTheme", is_default=False),
+        _imports = super()._get_imports()
+        _imports.setdefault(self.__fields__["library"].default, []).append(
+            imports.ImportVar(tag="extendTheme", is_default=False),
         )
         )
-        imports.setdefault("/utils/theme.js", []).append(
-            ImportVar(tag="theme", is_default=True),
+        _imports.setdefault("/utils/theme.js", []).append(
+            imports.ImportVar(tag="theme", is_default=True),
         )
         )
-        imports.setdefault(Global.__fields__["library"].default, []).append(
-            ImportVar(tag="css", is_default=False),
+        _imports.setdefault(Global.__fields__["library"].default, []).append(
+            imports.ImportVar(tag="css", is_default=False),
         )
         )
-        return imports
+        return _imports
 
 
     def _get_custom_code(self) -> str | None:
     def _get_custom_code(self) -> str | None:
         return """
         return """

+ 1 - 1
reflex/components/libs/chakra.pyi

@@ -11,7 +11,7 @@ from functools import lru_cache
 from typing import List, Literal
 from typing import List, Literal
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 class ChakraComponent(Component):
 class ChakraComponent(Component):
     @overload
     @overload

+ 7 - 8
reflex/components/navigation/client_side_routing.py

@@ -10,10 +10,9 @@ routeNotFound becomes true.
 from __future__ import annotations
 from __future__ import annotations
 
 
 from reflex import constants
 from reflex import constants
-
-from ...vars import Var
-from ..component import Component
-from ..layout.cond import Cond
+from reflex.components.component import Component
+from reflex.components.layout.cond import cond
+from reflex.vars import Var
 
 
 route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND)
 route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND)
 
 
@@ -52,10 +51,10 @@ def wait_for_client_redirect(component) -> Component:
     Returns:
     Returns:
         The conditionally rendered component.
         The conditionally rendered component.
     """
     """
-    return Cond.create(
-        cond=route_not_found,
-        comp1=component,
-        comp2=ClientSideRouting.create(),
+    return cond(
+        condition=route_not_found,
+        c1=component,
+        c2=ClientSideRouting.create(),
     )
     )
 
 
 
 

+ 3 - 3
reflex/components/navigation/client_side_routing.pyi

@@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
 from reflex import constants
 from reflex import constants
-from ...vars import Var
-from ..component import Component
-from ..layout.cond import Cond
+from reflex.components.component import Component
+from reflex.components.layout.cond import cond
+from reflex.vars import Var
 
 
 route_not_found: Var
 route_not_found: Var
 
 

+ 13 - 8
reflex/components/overlay/banner.py

@@ -5,22 +5,27 @@ from typing import Optional
 
 
 from reflex.components.base.bare import Bare
 from reflex.components.base.bare import Bare
 from reflex.components.component import Component
 from reflex.components.component import Component
-from reflex.components.layout import Box, Cond
+from reflex.components.layout import Box, cond
 from reflex.components.overlay.modal import Modal
 from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
 from reflex.components.typography import Text
+from reflex.constants import Hooks, Imports
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var, VarData
+
+connect_error_var_data: VarData = VarData(  # type: ignore
+    imports=Imports.EVENTS,
+    hooks={Hooks.EVENTS},
+)
 
 
 connection_error: Var = Var.create_safe(
 connection_error: Var = Var.create_safe(
     value="(connectError !== null) ? connectError.message : ''",
     value="(connectError !== null) ? connectError.message : ''",
     _var_is_local=False,
     _var_is_local=False,
     _var_is_string=False,
     _var_is_string=False,
-)
+)._replace(merge_var_data=connect_error_var_data)
 has_connection_error: Var = Var.create_safe(
 has_connection_error: Var = Var.create_safe(
     value="connectError !== null",
     value="connectError !== null",
     _var_is_string=False,
     _var_is_string=False,
-)
-has_connection_error._var_type = bool
+)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
 
 
 
 
 class WebsocketTargetURL(Bare):
 class WebsocketTargetURL(Bare):
@@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare):
 
 
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         return {
         return {
-            "/utils/state.js": [ImportVar(tag="getEventURL")],
+            "/utils/state.js": [imports.ImportVar(tag="getEventURL")],
         }
         }
 
 
     @classmethod
     @classmethod
@@ -78,7 +83,7 @@ class ConnectionBanner(Component):
                 textAlign="center",
                 textAlign="center",
             )
             )
 
 
-        return Cond.create(has_connection_error, comp)
+        return cond(has_connection_error, comp)
 
 
 
 
 class ConnectionModal(Component):
 class ConnectionModal(Component):
@@ -96,7 +101,7 @@ class ConnectionModal(Component):
         """
         """
         if not comp:
         if not comp:
             comp = Text.create(*default_connection_error())
             comp = Text.create(*default_connection_error())
-        return Cond.create(
+        return cond(
             has_connection_error,
             has_connection_error,
             Modal.create(
             Modal.create(
                 header="Connection Error",
                 header="Connection Error",

+ 4 - 3
reflex/components/overlay/banner.pyi

@@ -10,15 +10,16 @@ from reflex.style import Style
 from typing import Optional
 from typing import Optional
 from reflex.components.base.bare import Bare
 from reflex.components.base.bare import Bare
 from reflex.components.component import Component
 from reflex.components.component import Component
-from reflex.components.layout import Box, Cond
+from reflex.components.layout import Box, cond
 from reflex.components.overlay.modal import Modal
 from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
 from reflex.components.typography import Text
+from reflex.constants import Hooks, Imports
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var, VarData
 
 
+connect_error_var_data: VarData
 connection_error: Var
 connection_error: Var
 has_connection_error: Var
 has_connection_error: Var
-has_connection_error._var_type = bool
 
 
 class WebsocketTargetURL(Bare):
 class WebsocketTargetURL(Bare):
     @overload
     @overload

+ 2 - 2
reflex/components/radix/themes/base.py

@@ -6,7 +6,7 @@ from typing import Literal
 
 
 from reflex.components import Component
 from reflex.components import Component
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
 LiteralJustify = Literal["start", "center", "end", "between"]
 LiteralJustify = Literal["start", "center", "end", "between"]
@@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         return {
         return {
             **super()._get_imports(),
             **super()._get_imports(),
-            "": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
+            "": [imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
         }
         }
 
 
 
 

+ 1 - 1
reflex/components/radix/themes/base.pyi

@@ -10,7 +10,7 @@ from reflex.style import Style
 from typing import Literal
 from typing import Literal
 from reflex.components import Component
 from reflex.components import Component
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
 LiteralJustify = Literal["start", "center", "end", "between"]
 LiteralJustify = Literal["start", "center", "end", "between"]

+ 2 - 1
reflex/components/typography/markdown.py

@@ -14,7 +14,8 @@ from reflex.components.typography.heading import Heading
 from reflex.components.typography.text import Text
 from reflex.components.typography.text import Text
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, imports, types
 from reflex.utils import console, imports, types
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 # Special vars used in the component map.
 # Special vars used in the component map.
 _CHILDREN = Var.create_safe("children", _var_is_local=False)
 _CHILDREN = Var.create_safe("children", _var_is_local=False)

+ 2 - 1
reflex/components/typography/markdown.pyi

@@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading
 from reflex.components.typography.text import Text
 from reflex.components.typography.text import Text
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import console, imports, types
 from reflex.utils import console, imports, types
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 _CHILDREN = Var.create_safe("children", _var_is_local=False)
 _CHILDREN = Var.create_safe("children", _var_is_local=False)
 _PROPS = Var.create_safe("...props", _var_is_local=False)
 _PROPS = Var.create_safe("...props", _var_is_local=False)

+ 4 - 0
reflex/constants/__init__.py

@@ -22,6 +22,8 @@ from .compiler import (
     CompileVars,
     CompileVars,
     ComponentName,
     ComponentName,
     Ext,
     Ext,
+    Hooks,
+    Imports,
     PageNames,
     PageNames,
 )
 )
 from .config import (
 from .config import (
@@ -68,7 +70,9 @@ __ALL__ = [
     Ext,
     Ext,
     Fnm,
     Fnm,
     GitIgnore,
     GitIgnore,
+    Hooks,
     RequirementsTxt,
     RequirementsTxt,
+    Imports,
     IS_WINDOWS,
     IS_WINDOWS,
     LOCAL_STORAGE,
     LOCAL_STORAGE,
     LogLevel,
     LogLevel,

+ 2 - 0
reflex/constants/base.py

@@ -29,6 +29,8 @@ class Dirs(SimpleNamespace):
     STATE_PATH = "/".join([UTILS, "state"])
     STATE_PATH = "/".join([UTILS, "state"])
     # The name of the components file.
     # The name of the components file.
     COMPONENTS_PATH = "/".join([UTILS, "components"])
     COMPONENTS_PATH = "/".join([UTILS, "components"])
+    # The name of the contexts file.
+    CONTEXTS_PATH = "/".join([UTILS, "context"])
     # The directory where the app pages are compiled to.
     # The directory where the app pages are compiled to.
     WEB_PAGES = os.path.join(WEB, "pages")
     WEB_PAGES = os.path.join(WEB, "pages")
     # The directory where the static build is located.
     # The directory where the static build is located.

+ 25 - 0
reflex/constants/compiler.py

@@ -2,6 +2,9 @@
 from enum import Enum
 from enum import Enum
 from types import SimpleNamespace
 from types import SimpleNamespace
 
 
+from reflex.constants import Dirs
+from reflex.utils.imports import ImportVar
+
 # The prefix used to create setters for state vars.
 # The prefix used to create setters for state vars.
 SETTER_PREFIX = "set_"
 SETTER_PREFIX = "set_"
 
 
@@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace):
     HYDRATE = "hydrate"
     HYDRATE = "hydrate"
     # The name of the is_hydrated variable.
     # The name of the is_hydrated variable.
     IS_HYDRATED = "is_hydrated"
     IS_HYDRATED = "is_hydrated"
+    # The name of the function to add events to the queue.
+    ADD_EVENTS = "addEvents"
+    # The name of the var storing any connection error.
+    CONNECT_ERROR = "connectError"
+    # The name of the function for converting a dict to an event.
+    TO_EVENT = "Event"
 
 
 
 
 class PageNames(SimpleNamespace):
 class PageNames(SimpleNamespace):
@@ -77,3 +86,19 @@ class ComponentName(Enum):
             The lower-case filename with zip extension.
             The lower-case filename with zip extension.
         """
         """
         return self.value.lower() + Ext.ZIP
         return self.value.lower() + Ext.ZIP
+
+
+class Imports(SimpleNamespace):
+    """Common sets of import vars."""
+
+    EVENTS = {
+        "react": {ImportVar(tag="useContext")},
+        f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
+        f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
+    }
+
+
+class Hooks(SimpleNamespace):
+    """Common sets of hook declarations."""
+
+    EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"

+ 1 - 1
reflex/middleware/hydrate_middleware.py

@@ -48,7 +48,7 @@ class HydrateMiddleware(Middleware):
                     setattr(var_state, var_name, value)
                     setattr(var_state, var_name, value)
 
 
         # Get the initial state.
         # Get the initial state.
-        delta = format.format_state({state.get_name(): state.dict()})
+        delta = format.format_state(state.dict())
         # since a full dict was captured, clean any dirtiness
         # since a full dict was captured, clean any dirtiness
         state._clean()
         state._clean()
 
 

+ 9 - 5
reflex/state.py

@@ -1211,12 +1211,16 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             if include_computed
             if include_computed
             else {}
             else {}
         )
         )
-        substate_vars = {
-            k: v.dict(include_computed=include_computed, **kwargs)
-            for k, v in self.substates.items()
+        variables = {**base_vars, **computed_vars}
+        d = {
+            self.get_full_name(): {k: variables[k] for k in sorted(variables)},
         }
         }
-        variables = {**base_vars, **computed_vars, **substate_vars}
-        return {k: variables[k] for k in sorted(variables)}
+        for substate_d in [
+            v.dict(include_computed=include_computed, **kwargs)
+            for v in self.substates.values()
+        ]:
+            d.update(substate_d)
+        return d
 
 
     async def __aenter__(self) -> State:
     async def __aenter__(self) -> State:
         """Enter the async context manager protocol.
         """Enter the async context manager protocol.

+ 71 - 7
reflex/style.py

@@ -2,13 +2,38 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+from typing import Any
+
 from reflex import constants
 from reflex import constants
 from reflex.event import EventChain
 from reflex.event import EventChain
 from reflex.utils import format
 from reflex.utils import format
-from reflex.vars import BaseVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import BaseVar, Var, VarData
+
+VarData.update_forward_refs()  # Ensure all type definitions are resolved
 
 
-color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str")
-toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain)
+# Reference the global ColorModeContext
+color_mode_var_data = VarData(  # type: ignore
+    imports={
+        f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
+        "react": {ImportVar(tag="useContext")},
+    },
+    hooks={
+        f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
+    },
+)
+# Var resolves to the current color mode for the app ("light" or "dark")
+color_mode = BaseVar(
+    _var_name=constants.ColorMode.NAME,
+    _var_type="str",
+    _var_data=color_mode_var_data,
+)
+# Var resolves to a function invocation that toggles the color mode
+toggle_color_mode = BaseVar(
+    _var_name=constants.ColorMode.TOGGLE,
+    _var_type=EventChain,
+    _var_data=color_mode_var_data,
+)
 
 
 
 
 def convert(style_dict):
 def convert(style_dict):
@@ -20,16 +45,27 @@ def convert(style_dict):
     Returns:
     Returns:
         The formatted style dictionary.
         The formatted style dictionary.
     """
     """
+    var_data = None  # Track import/hook data from any Vars in the style dict.
     out = {}
     out = {}
     for key, value in style_dict.items():
     for key, value in style_dict.items():
         key = format.to_camel_case(key)
         key = format.to_camel_case(key)
+        new_var_data = None
         if isinstance(value, dict):
         if isinstance(value, dict):
-            out[key] = convert(value)
+            # Recursively format nested style dictionaries.
+            out[key], new_var_data = convert(value)
         elif isinstance(value, Var):
         elif isinstance(value, Var):
+            # If the value is a Var, extract the var_data and cast as str.
+            new_var_data = value._var_data
             out[key] = str(value)
             out[key] = str(value)
         else:
         else:
+            # Otherwise, convert to Var to collapse VarData encoded in f-string.
+            new_var = Var.create(value)
+            if new_var is not None:
+                new_var_data = new_var._var_data
             out[key] = value
             out[key] = value
-    return out
+        # Combine all the collected VarData instances.
+        var_data = VarData.merge(var_data, new_var_data)
+    return out, var_data
 
 
 
 
 class Style(dict):
 class Style(dict):
@@ -41,5 +77,33 @@ class Style(dict):
         Args:
         Args:
             style_dict: The style dictionary.
             style_dict: The style dictionary.
         """
         """
-        style_dict = style_dict or {}
-        super().__init__(convert(style_dict))
+        style_dict, self._var_data = convert(style_dict or {})
+        super().__init__(style_dict)
+
+    def update(self, style_dict: dict | None, **kwargs):
+        """Update the style.
+
+        Args:
+            style_dict: The style dictionary.
+            kwargs: Other key value pairs to apply to the dict update.
+        """
+        if kwargs:
+            style_dict = {**(style_dict or {}), **kwargs}
+        converted_dict = type(self)(style_dict)
+        # Combine our VarData with that of any Vars in the style_dict that was passed.
+        self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
+        super().update(converted_dict)
+
+    def __setitem__(self, key: str, value: Any):
+        """Set an item in the style.
+
+        Args:
+            key: The key to set.
+            value: The value to set.
+        """
+        # Create a Var to collapse VarData encoded in f-string.
+        _var = Var.create(value)
+        if _var is not None:
+            # Carry the imports/hooks when setting a Var as a value.
+            self._var_data = VarData.merge(self._var_data, _var._var_data)
+        super().__setitem__(key, value)

+ 19 - 6
reflex/utils/format.py

@@ -232,9 +232,9 @@ def format_route(route: str, format_case=True) -> str:
 
 
 
 
 def format_cond(
 def format_cond(
-    cond: str,
-    true_value: str,
-    false_value: str = '""',
+    cond: str | Var,
+    true_value: str | Var,
+    false_value: str | Var = '""',
     is_prop=False,
     is_prop=False,
 ) -> str:
 ) -> str:
     """Format a conditional expression.
     """Format a conditional expression.
@@ -248,9 +248,6 @@ def format_cond(
     Returns:
     Returns:
         The formatted conditional expression.
         The formatted conditional expression.
     """
     """
-    # Import here to avoid circular imports.
-    from reflex.vars import Var
-
     # Use Python truthiness.
     # Use Python truthiness.
     cond = f"isTrue({cond})"
     cond = f"isTrue({cond})"
 
 
@@ -266,6 +263,7 @@ def format_cond(
             _var_is_string=type(false_value) is str,
             _var_is_string=type(false_value) is str,
         )
         )
         prop2._var_is_local = True
         prop2._var_is_local = True
+        prop1, prop2 = str(prop1), str(prop2)  # avoid f-string semantics for Var
         return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
         return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
 
 
     # Format component conds.
     # Format component conds.
@@ -517,6 +515,21 @@ def format_state(value: Any) -> Any:
     raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.")
     raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.")
 
 
 
 
+def format_state_name(state_name: str) -> str:
+    """Format a state name, replacing dots with double underscore.
+
+    This allows individual substates to be accessed independently as javascript vars
+    without using dot notation.
+
+    Args:
+        state_name: The state name to format.
+
+    Returns:
+        The formatted state name.
+    """
+    return state_name.replace(".", "__")
+
+
 def format_ref(ref: str) -> str:
 def format_ref(ref: str) -> str:
     """Format a ref.
     """Format a ref.
 
 

+ 41 - 4
reflex/utils/imports.py

@@ -3,11 +3,9 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 from collections import defaultdict
 from collections import defaultdict
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 
-from reflex.vars import ImportVar
-
-ImportDict = Dict[str, List[ImportVar]]
+from reflex.base import Base
 
 
 
 
 def merge_imports(*imports) -> ImportDict:
 def merge_imports(*imports) -> ImportDict:
@@ -24,3 +22,42 @@ def merge_imports(*imports) -> ImportDict:
         for lib, fields in import_dict.items():
         for lib, fields in import_dict.items():
             all_imports[lib].extend(fields)
             all_imports[lib].extend(fields)
     return all_imports
     return all_imports
+
+
+class ImportVar(Base):
+    """An import var."""
+
+    # The name of the import tag.
+    tag: Optional[str]
+
+    # whether the import is default or named.
+    is_default: Optional[bool] = False
+
+    # The tag alias.
+    alias: Optional[str] = None
+
+    # Whether this import need to install the associated lib
+    install: Optional[bool] = True
+
+    # whether this import should be rendered or not
+    render: Optional[bool] = True
+
+    @property
+    def name(self) -> str:
+        """The name of the import.
+
+        Returns:
+            The name(tag name with alias) of tag.
+        """
+        return self.tag if not self.alias else " as ".join([self.tag, self.alias])  # type: ignore
+
+    def __hash__(self) -> int:
+        """Define a hash function for the import var.
+
+        Returns:
+            The hash of the var.
+        """
+        return hash((self.tag, self.is_default, self.alias, self.install, self.render))
+
+
+ImportDict = Dict[str, List[ImportVar]]

+ 1 - 0
reflex/utils/types.py

@@ -27,6 +27,7 @@ from reflex.utils import serializers
 GenericType = Union[Type, _GenericAlias]
 GenericType = Union[Type, _GenericAlias]
 
 
 # Valid state var types.
 # Valid state var types.
+JSONType = {str, int, float, bool}
 PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
 PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
 StateVar = Union[PrimitiveType, Base, None]
 StateVar = Union[PrimitiveType, Base, None]
 StateIterVar = Union[list, set, tuple]
 StateIterVar = Union[list, set, tuple]

+ 298 - 120
reflex/vars.py

@@ -7,6 +7,7 @@ import dis
 import inspect
 import inspect
 import json
 import json
 import random
 import random
+import re
 import string
 import string
 import sys
 import sys
 from types import CodeType, FunctionType
 from types import CodeType, FunctionType
@@ -15,9 +16,11 @@ from typing import (
     Any,
     Any,
     Callable,
     Callable,
     Dict,
     Dict,
+    Iterable,
     List,
     List,
     Literal,
     Literal,
     Optional,
     Optional,
+    Set,
     Tuple,
     Tuple,
     Type,
     Type,
     Union,
     Union,
@@ -30,7 +33,10 @@ from typing import (
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
-from reflex.utils import console, format, serializers, types
+from reflex.utils import console, format, imports, serializers, types
+
+# This module used to export ImportVar itself, so we still import it for export here
+from reflex.utils.imports import ImportDict, ImportVar
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.state import State
     from reflex.state import State
@@ -71,7 +77,7 @@ OPERATION_MAPPING = {
 REPLACED_NAMES = {
 REPLACED_NAMES = {
     "full_name": "_var_full_name",
     "full_name": "_var_full_name",
     "name": "_var_name",
     "name": "_var_name",
-    "state": "_var_state",
+    "state": "_var_data.state",
     "type_": "_var_type",
     "type_": "_var_type",
     "is_local": "_var_is_local",
     "is_local": "_var_is_local",
     "is_string": "_var_is_string",
     "is_string": "_var_is_string",
@@ -93,6 +99,131 @@ def get_unique_variable_name() -> str:
     return get_unique_variable_name()
     return get_unique_variable_name()
 
 
 
 
+class VarData(Base):
+    """Metadata associated with a Var."""
+
+    # The name of the enclosing state.
+    state: str = ""
+
+    # Imports needed to render this var
+    imports: ImportDict = {}
+
+    # Hooks that need to be present in the component to render this var
+    hooks: Set[str] = set()
+
+    @classmethod
+    def merge(cls, *others: VarData | None) -> VarData | None:
+        """Merge multiple var data objects.
+
+        Args:
+            *others: The var data objects to merge.
+
+        Returns:
+            The merged var data object.
+        """
+        state = ""
+        _imports = {}
+        hooks = set()
+        for var_data in others:
+            if var_data is None:
+                continue
+            state = state or var_data.state
+            _imports = imports.merge_imports(_imports, var_data.imports)
+            hooks.update(var_data.hooks)
+        return (
+            cls(
+                state=state,
+                imports=_imports,
+                hooks=hooks,
+            )
+            or None
+        )
+
+    def __bool__(self) -> bool:
+        """Check if the var data is non-empty.
+
+        Returns:
+            True if any field is set to a non-default value.
+        """
+        return bool(self.state or self.imports or self.hooks)
+
+    def dict(self) -> dict:
+        """Convert the var data to a dictionary.
+
+        Returns:
+            The var data dictionary.
+        """
+        return {
+            "state": self.state,
+            "imports": {
+                lib: [import_var.dict() for import_var in import_vars]
+                for lib, import_vars in self.imports.items()
+            },
+            "hooks": list(self.hooks),
+        }
+
+
+def _encode_var(value: Var) -> str:
+    """Encode the state name into a formatted var.
+
+    Args:
+        value: The value to encode the state name into.
+
+    Returns:
+        The encoded var.
+    """
+    if value._var_data:
+        return f"<reflex.Var>{value._var_data.json()}</reflex.Var>" + str(value)
+    return str(value)
+
+
+def _decode_var(value: str) -> tuple[VarData | None, str]:
+    """Decode the state name from a formatted var.
+
+    Args:
+        value: The value to extract the state name from.
+
+    Returns:
+        The extracted state name and the value without the state name.
+    """
+    var_datas = []
+    if isinstance(value, str):
+        # Extract the state name from a formatted var
+        while m := re.match(r"(.*)<reflex.Var>(.*)</reflex.Var>(.*)", value):
+            value = m.group(1) + m.group(3)
+            var_datas.append(VarData.parse_raw(m.group(2)))
+    if var_datas:
+        return VarData.merge(*var_datas), value
+    return None, value
+
+
+def _extract_var_data(value: Iterable) -> list[VarData | None]:
+    """Extract the var imports and hooks from an iterable containing a Var.
+
+    Args:
+        value: The iterable to extract the VarData from
+
+    Returns:
+        The extracted VarDatas.
+    """
+    var_datas = []
+    with contextlib.suppress(TypeError):
+        for sub in value:
+            if isinstance(sub, Var):
+                var_datas.append(sub._var_data)
+            elif not isinstance(sub, str):
+                # Recurse into dict values.
+                if hasattr(sub, "values") and callable(sub.values):
+                    var_datas.extend(_extract_var_data(sub.values()))
+                # Recurse into iterable values (or dict keys).
+                var_datas.extend(_extract_var_data(sub))
+    # Recurse when value is a dict itself.
+    values = getattr(value, "values", None)
+    if callable(values):
+        var_datas.extend(_extract_var_data(values()))
+    return var_datas
+
+
 class Var:
 class Var:
     """An abstract var."""
     """An abstract var."""
 
 
@@ -102,15 +233,18 @@ class Var:
     # The type of the var.
     # The type of the var.
     _var_type: Type
     _var_type: Type
 
 
-    # The name of the enclosing state.
-    _var_state: str
-
     # Whether this is a local javascript variable.
     # Whether this is a local javascript variable.
     _var_is_local: bool
     _var_is_local: bool
 
 
     # Whether the var is a string literal.
     # Whether the var is a string literal.
     _var_is_string: bool
     _var_is_string: bool
 
 
+    # _var_full_name should be prefixed with _var_state
+    _var_full_name_needs_state_prefix: bool
+
+    # Extra metadata associated with the Var
+    _var_data: Optional[VarData]
+
     @classmethod
     @classmethod
     def create(
     def create(
         cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
         cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
@@ -136,9 +270,14 @@ class Var:
         if isinstance(value, Var):
         if isinstance(value, Var):
             return value
             return value
 
 
+        # Try to pull the imports and hooks from contained values.
+        _var_data = None
+        if not isinstance(value, str):
+            _var_data = VarData.merge(*_extract_var_data(value))
+
         # Try to serialize the value.
         # Try to serialize the value.
         type_ = type(value)
         type_ = type(value)
-        name = serializers.serialize(value)
+        name = value if type_ in types.JSONType else serializers.serialize(value)
         if name is None:
         if name is None:
             raise TypeError(
             raise TypeError(
                 f"No JSON serializer found for var {value} of type {type_}."
                 f"No JSON serializer found for var {value} of type {type_}."
@@ -150,6 +289,7 @@ class Var:
             _var_type=type_,
             _var_type=type_,
             _var_is_local=_var_is_local,
             _var_is_local=_var_is_local,
             _var_is_string=_var_is_string,
             _var_is_string=_var_is_string,
+            _var_data=_var_data,
         )
         )
 
 
     @classmethod
     @classmethod
@@ -186,6 +326,39 @@ class Var:
         """
         """
         return _GenericAlias(cls, type_)
         return _GenericAlias(cls, type_)
 
 
+    def __post_init__(self) -> None:
+        """Post-initialize the var."""
+        # Decode any inline Var markup and apply it to the instance
+        _var_data, _var_name = _decode_var(self._var_name)
+        if _var_data:
+            self._var_name = _var_name
+            self._var_data = VarData.merge(self._var_data, _var_data)
+
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
+        """Make a copy of this Var with updated fields.
+
+        Args:
+            merge_var_data: VarData to merge into the existing VarData.
+            **kwargs: Var fields to update.
+
+        Returns:
+            A new BaseVar with the updated fields overwriting the corresponding fields in this Var.
+        """
+        field_values = dict(
+            _var_name=kwargs.pop("_var_name", self._var_name),
+            _var_type=kwargs.pop("_var_type", self._var_type),
+            _var_is_local=kwargs.pop("_var_is_local", self._var_is_local),
+            _var_is_string=kwargs.pop("_var_is_string", self._var_is_string),
+            _var_full_name_needs_state_prefix=kwargs.pop(
+                "_var_full_name_needs_state_prefix",
+                self._var_full_name_needs_state_prefix,
+            ),
+            _var_data=VarData.merge(
+                kwargs.get("_var_data", self._var_data), merge_var_data
+            ),
+        )
+        return BaseVar(**field_values)
+
     def _decode(self) -> Any:
     def _decode(self) -> Any:
         """Decode Var as a python value.
         """Decode Var as a python value.
 
 
@@ -195,8 +368,6 @@ class Var:
         Returns:
         Returns:
             The decoded value or the Var name.
             The decoded value or the Var name.
         """
         """
-        if self._var_state:
-            return self._var_full_name
         if self._var_is_string:
         if self._var_is_string:
             return self._var_name
             return self._var_name
         try:
         try:
@@ -216,8 +387,10 @@ class Var:
         return (
         return (
             self._var_name == other._var_name
             self._var_name == other._var_name
             and self._var_type == other._var_type
             and self._var_type == other._var_type
-            and self._var_state == other._var_state
             and self._var_is_local == other._var_is_local
             and self._var_is_local == other._var_is_local
+            and self._var_full_name_needs_state_prefix
+            == other._var_full_name_needs_state_prefix
+            and self._var_data == other._var_data
         )
         )
 
 
     def to_string(self, json: bool = True) -> Var:
     def to_string(self, json: bool = True) -> Var:
@@ -285,9 +458,11 @@ class Var:
         Returns:
         Returns:
             The formatted var.
             The formatted var.
         """
         """
+        # Encode the _var_data into the formatted output for tracking purposes.
+        str_self = _encode_var(self)
         if self._var_is_local:
         if self._var_is_local:
-            return str(self)
-        return f"${str(self)}"
+            return str_self
+        return f"${str_self}"
 
 
     def __getitem__(self, i: Any) -> Var:
     def __getitem__(self, i: Any) -> Var:
         """Index into a var.
         """Index into a var.
@@ -320,12 +495,7 @@ class Var:
 
 
         # Convert any vars to local vars.
         # Convert any vars to local vars.
         if isinstance(i, Var):
         if isinstance(i, Var):
-            i = BaseVar(
-                _var_name=i._var_name,
-                _var_type=i._var_type,
-                _var_state=i._var_state,
-                _var_is_local=True,
-            )
+            i = i._replace(_var_is_local=True)
 
 
         # Handle list/tuple/str indexing.
         # Handle list/tuple/str indexing.
         if types._issubclass(self._var_type, Union[List, Tuple, str]):
         if types._issubclass(self._var_type, Union[List, Tuple, str]):
@@ -344,11 +514,9 @@ class Var:
                 stop = i.stop or "undefined"
                 stop = i.stop or "undefined"
 
 
                 # Use the slice function.
                 # Use the slice function.
-                return BaseVar(
+                return self._replace(
                     _var_name=f"{self._var_name}.slice({start}, {stop})",
                     _var_name=f"{self._var_name}.slice({start}, {stop})",
-                    _var_type=self._var_type,
-                    _var_state=self._var_state,
-                    _var_is_local=self._var_is_local,
+                    _var_is_string=False,
                 )
                 )
 
 
             # Get the type of the indexed var.
             # Get the type of the indexed var.
@@ -359,11 +527,10 @@ class Var:
             )
             )
 
 
             # Use `at` to support negative indices.
             # Use `at` to support negative indices.
-            return BaseVar(
+            return self._replace(
                 _var_name=f"{self._var_name}.at({i})",
                 _var_name=f"{self._var_name}.at({i})",
                 _var_type=type_,
                 _var_type=type_,
-                _var_state=self._var_state,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
             )
             )
 
 
         # Dictionary / dataframe indexing.
         # Dictionary / dataframe indexing.
@@ -393,11 +560,10 @@ class Var:
         )
         )
 
 
         # Use normal indexing here.
         # Use normal indexing here.
-        return BaseVar(
+        return self._replace(
             _var_name=f"{self._var_name}[{i}]",
             _var_name=f"{self._var_name}[{i}]",
             _var_type=type_,
             _var_type=type_,
-            _var_state=self._var_state,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
         )
         )
 
 
     def __getattr__(self, name: str) -> Var:
     def __getattr__(self, name: str) -> Var:
@@ -423,11 +589,10 @@ class Var:
             type_ = types.get_attribute_access_type(self._var_type, name)
             type_ = types.get_attribute_access_type(self._var_type, name)
 
 
             if type_ is not None:
             if type_ is not None:
-                return BaseVar(
+                return self._replace(
                     _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
                     _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
                     _var_type=type_,
                     _var_type=type_,
-                    _var_state=self._var_state,
-                    _var_is_local=self._var_is_local,
+                    _var_is_string=False,
                 )
                 )
 
 
             if name in REPLACED_NAMES:
             if name in REPLACED_NAMES:
@@ -519,10 +684,12 @@ class Var:
                     else f"{self._var_full_name}.{fn}()"
                     else f"{self._var_full_name}.{fn}()"
                 )
                 )
 
 
-        return BaseVar(
+        return self._replace(
             _var_name=operation_name,
             _var_name=operation_name,
             _var_type=type_,
             _var_type=type_,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
+            _var_full_name_needs_state_prefix=False,
+            merge_var_data=other._var_data if other is not None else None,
         )
         )
 
 
     @staticmethod
     @staticmethod
@@ -602,10 +769,10 @@ class Var:
         """
         """
         if not types._issubclass(self._var_type, List):
         if not types._issubclass(self._var_type, List):
             raise TypeError(f"Cannot get length of non-list var {self}.")
             raise TypeError(f"Cannot get length of non-list var {self}.")
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.length",
+        return self._replace(
+            _var_name=f"{self._var_name}.length",
             _var_type=int,
             _var_type=int,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
         )
         )
 
 
     def __eq__(self, other: Var) -> Var:
     def __eq__(self, other: Var) -> Var:
@@ -692,7 +859,17 @@ class Var:
             types.get_base_class(self._var_type) == list
             types.get_base_class(self._var_type) == list
             and types.get_base_class(other_type) == list
             and types.get_base_class(other_type) == list
         ):
         ):
-            return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip)
+            return self.operation(
+                ",", other, fn="spreadArraysOrObjects", flip=flip
+            )._replace(
+                merge_var_data=VarData(
+                    imports={
+                        f"/{constants.Dirs.STATE_PATH}": [
+                            ImportVar(tag="spreadArraysOrObjects")
+                        ]
+                    },
+                ),
+            )
         return self.operation("+", other, flip=flip)
         return self.operation("+", other, flip=flip)
 
 
     def __radd__(self, other: Var) -> Var:
     def __radd__(self, other: Var) -> Var:
@@ -755,10 +932,11 @@ class Var:
         ]:
         ]:
             other_name = other._var_full_name if isinstance(other, Var) else other
             other_name = other._var_full_name if isinstance(other, Var) else other
             name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()"
             name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()"
-            return BaseVar(
+            return self._replace(
                 _var_name=name,
                 _var_name=name,
                 _var_type=str,
                 _var_type=str,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
+                _var_full_name_needs_state_prefix=False,
             )
             )
 
 
         return self.operation("*", other)
         return self.operation("*", other)
@@ -1003,10 +1181,11 @@ class Var:
         elif not isinstance(other, Var):
         elif not isinstance(other, Var):
             other = Var.create(other)
             other = Var.create(other)
         if types._issubclass(self._var_type, Dict):
         if types._issubclass(self._var_type, Dict):
-            return BaseVar(
-                _var_name=f"{self._var_full_name}.{method}({other._var_full_name})",
+            return self._replace(
+                _var_name=f"{self._var_name}.{method}({other._var_full_name})",
                 _var_type=bool,
                 _var_type=bool,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
+                merge_var_data=other._var_data,
             )
             )
         else:  # str, list, tuple
         else:  # str, list, tuple
             # For strings, the left operand must be a string.
             # For strings, the left operand must be a string.
@@ -1016,10 +1195,11 @@ class Var:
                 raise TypeError(
                 raise TypeError(
                     f"'in <string>' requires string as left operand, not {other._var_type}"
                     f"'in <string>' requires string as left operand, not {other._var_type}"
                 )
                 )
-            return BaseVar(
-                _var_name=f"{self._var_full_name}.includes({other._var_full_name})",
+            return self._replace(
+                _var_name=f"{self._var_name}.includes({other._var_full_name})",
                 _var_type=bool,
                 _var_type=bool,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
+                merge_var_data=other._var_data,
             )
             )
 
 
     def reverse(self) -> Var:
     def reverse(self) -> Var:
@@ -1034,10 +1214,10 @@ class Var:
         if not types._issubclass(self._var_type, list):
         if not types._issubclass(self._var_type, list):
             raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.")
             raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.")
 
 
-        return BaseVar(
+        return self._replace(
             _var_name=f"[...{self._var_full_name}].reverse()",
             _var_name=f"[...{self._var_full_name}].reverse()",
-            _var_type=self._var_type,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
+            _var_full_name_needs_state_prefix=False,
         )
         )
 
 
     def lower(self) -> Var:
     def lower(self) -> Var:
@@ -1054,10 +1234,10 @@ class Var:
                 f"Cannot convert non-string var {self._var_full_name} to lowercase."
                 f"Cannot convert non-string var {self._var_full_name} to lowercase."
             )
             )
 
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.toLowerCase()",
+        return self._replace(
+            _var_name=f"{self._var_name}.toLowerCase()",
+            _var_is_string=False,
             _var_type=str,
             _var_type=str,
-            _var_is_local=self._var_is_local,
         )
         )
 
 
     def upper(self) -> Var:
     def upper(self) -> Var:
@@ -1074,10 +1254,10 @@ class Var:
                 f"Cannot convert non-string var {self._var_full_name} to uppercase."
                 f"Cannot convert non-string var {self._var_full_name} to uppercase."
             )
             )
 
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.toUpperCase()",
+        return self._replace(
+            _var_name=f"{self._var_name}.toUpperCase()",
+            _var_is_string=False,
             _var_type=str,
             _var_type=str,
-            _var_is_local=self._var_is_local,
         )
         )
 
 
     def split(self, other: str | Var[str] = " ") -> Var:
     def split(self, other: str | Var[str] = " ") -> Var:
@@ -1097,10 +1277,11 @@ class Var:
 
 
         other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
         other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
 
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.split({other._var_full_name})",
+        return self._replace(
+            _var_name=f"{self._var_name}.split({other._var_full_name})",
+            _var_is_string=False,
             _var_type=list[str],
             _var_type=list[str],
-            _var_is_local=self._var_is_local,
+            merge_var_data=other._var_data,
         )
         )
 
 
     def join(self, other: str | Var[str] | None = None) -> Var:
     def join(self, other: str | Var[str] | None = None) -> Var:
@@ -1125,10 +1306,11 @@ class Var:
         else:
         else:
             other = Var.create_safe(other)
             other = Var.create_safe(other)
 
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.join({other._var_full_name})",
+        return self._replace(
+            _var_name=f"{self._var_name}.join({other._var_full_name})",
+            _var_is_string=False,
             _var_type=str,
             _var_type=str,
-            _var_is_local=self._var_is_local,
+            merge_var_data=other._var_data,
         )
         )
 
 
     def foreach(self, fn: Callable) -> Var:
     def foreach(self, fn: Callable) -> Var:
@@ -1159,10 +1341,9 @@ class Var:
         fn_signature = inspect.signature(fn)
         fn_signature = inspect.signature(fn)
         fn_args = (arg, index)
         fn_args = (arg, index)
         fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
         fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
-        return BaseVar(
+        return self._replace(
             _var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
             _var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
-            _var_type=self._var_type,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
         )
         )
 
 
     @classmethod
     @classmethod
@@ -1207,6 +1388,18 @@ class Var:
             _var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
             _var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
             _var_type=list[int],
             _var_type=list[int],
             _var_is_local=False,
             _var_is_local=False,
+            _var_data=VarData.merge(
+                v1._var_data,
+                v2._var_data,
+                step._var_data,
+                VarData(
+                    imports={
+                        "/utils/helpers/range.js": [
+                            ImportVar(tag="range", is_default=True),
+                        ],
+                    },
+                ),
+            ),
         )
         )
 
 
     def to(self, type_: Type) -> Var:
     def to(self, type_: Type) -> Var:
@@ -1218,12 +1411,7 @@ class Var:
         Returns:
         Returns:
             The converted var.
             The converted var.
         """
         """
-        return BaseVar(
-            _var_name=self._var_name,
-            _var_type=type_,
-            _var_state=self._var_state,
-            _var_is_local=self._var_is_local,
-        )
+        return self._replace(_var_type=type_)
 
 
     @property
     @property
     def _var_full_name(self) -> str:
     def _var_full_name(self) -> str:
@@ -1232,24 +1420,51 @@ class Var:
         Returns:
         Returns:
             The full name of the var.
             The full name of the var.
         """
         """
+        if not self._var_full_name_needs_state_prefix:
+            return self._var_name
         return (
         return (
             self._var_name
             self._var_name
-            if self._var_state == ""
-            else ".".join([self._var_state, self._var_name])
+            if self._var_data is None or self._var_data.state == ""
+            else ".".join(
+                [format.format_state_name(self._var_data.state), self._var_name]
+            )
         )
         )
 
 
-    def _var_set_state(self, state: Type[State]) -> Any:
+    def _var_set_state(self, state: Type[State] | str) -> Any:
         """Set the state of the var.
         """Set the state of the var.
 
 
         Args:
         Args:
-            state: The state to set.
+            state: The state to set or the full name of the state.
 
 
         Returns:
         Returns:
             The var with the set state.
             The var with the set state.
         """
         """
-        self._var_state = state.get_full_name()
+        state_name = state if isinstance(state, str) else state.get_full_name()
+        new_var_data = VarData(
+            state=state_name,
+            hooks={
+                "const {0} = useContext(StateContexts.{0})".format(
+                    format.format_state_name(state_name)
+                )
+            },
+            imports={
+                f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
+                "react": [ImportVar(tag="useContext")],
+            },
+        )
+        self._var_data = VarData.merge(self._var_data, new_var_data)
+        self._var_full_name_needs_state_prefix = True
         return self
         return self
 
 
+    @property
+    def _var_state(self) -> str:
+        """Compat method for getting the state.
+
+        Returns:
+            The state name associated with the var.
+        """
+        return self._var_data.state if self._var_data else ""
+
 
 
 @dataclasses.dataclass(
 @dataclasses.dataclass(
     eq=False,
     eq=False,
@@ -1264,15 +1479,18 @@ class BaseVar(Var):
     # The type of the var.
     # The type of the var.
     _var_type: Type = dataclasses.field(default=Any)
     _var_type: Type = dataclasses.field(default=Any)
 
 
-    # The name of the enclosing state.
-    _var_state: str = dataclasses.field(default="")
-
     # Whether this is a local javascript variable.
     # Whether this is a local javascript variable.
     _var_is_local: bool = dataclasses.field(default=False)
     _var_is_local: bool = dataclasses.field(default=False)
 
 
     # Whether the var is a string literal.
     # Whether the var is a string literal.
     _var_is_string: bool = dataclasses.field(default=False)
     _var_is_string: bool = dataclasses.field(default=False)
 
 
+    # _var_full_name should be prefixed with _var_state
+    _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False)
+
+    # Extra metadata associated with the Var
+    _var_data: Optional[VarData] = dataclasses.field(default=None)
+
     def __hash__(self) -> int:
     def __hash__(self) -> int:
         """Define a hash function for a var.
         """Define a hash function for a var.
 
 
@@ -1334,9 +1552,11 @@ class BaseVar(Var):
             The name of the setter function.
             The name of the setter function.
         """
         """
         setter = constants.SETTER_PREFIX + self._var_name
         setter = constants.SETTER_PREFIX + self._var_name
-        if not include_state or self._var_state == "":
+        if self._var_data is None:
+            return setter
+        if not include_state or self._var_data.state == "":
             return setter
             return setter
-        return ".".join((self._var_state, setter))
+        return ".".join((self._var_data.state, setter))
 
 
     def get_setter(self) -> Callable[[State, Any], None]:
     def get_setter(self) -> Callable[[State, Any], None]:
         """Get the var's setter function.
         """Get the var's setter function.
@@ -1550,48 +1770,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
     return cvar
     return cvar
 
 
 
 
-class ImportVar(Base):
-    """An import var."""
-
-    # The name of the import tag.
-    tag: Optional[str]
-
-    # whether the import is default or named.
-    is_default: Optional[bool] = False
-
-    # The tag alias.
-    alias: Optional[str] = None
-
-    # Whether this import need to install the associated lib
-    install: Optional[bool] = True
-
-    # whether this import should be rendered or not
-    render: Optional[bool] = True
-
-    @property
-    def name(self) -> str:
-        """The name of the import.
-
-        Returns:
-            The name(tag name with alias) of tag.
-        """
-        return self.tag if not self.alias else " as ".join([self.tag, self.alias])  # type: ignore
-
-    def __hash__(self) -> int:
-        """Define a hash function for the import var.
-
-        Returns:
-            The hash of the var.
-        """
-        return hash((self.tag, self.is_default, self.alias, self.install, self.render))
-
-
-class NoRenderImportVar(ImportVar):
-    """A import that doesn't need to be rendered."""
-
-    render: Optional[bool] = False
-
-
 class CallableVar(BaseVar):
 class CallableVar(BaseVar):
     """Decorate a Var-returning function to act as both a Var and a function.
     """Decorate a Var-returning function to act as both a Var and a function.
 
 

+ 19 - 19
reflex/vars.pyi

@@ -6,11 +6,13 @@ from reflex import constants as constants
 from reflex.base import Base as Base
 from reflex.base import Base as Base
 from reflex.state import State as State
 from reflex.state import State as State
 from reflex.utils import console as console, format as format, types as types
 from reflex.utils import console as console, format as format, types as types
+from reflex.utils.imports import ImportVar
 from types import FunctionType
 from types import FunctionType
 from typing import (
 from typing import (
     Any,
     Any,
     Callable,
     Callable,
     Dict,
     Dict,
+    Iterable,
     List,
     List,
     Optional,
     Optional,
     Set,
     Set,
@@ -22,13 +24,24 @@ from typing import (
 USED_VARIABLES: Incomplete
 USED_VARIABLES: Incomplete
 
 
 def get_unique_variable_name() -> str: ...
 def get_unique_variable_name() -> str: ...
+def _encode_var(value: Var) -> str: ...
+def _decode_var(value: str) -> tuple[VarData, str]: ...
+def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
+
+class VarData(Base):
+    state: str
+    imports: dict[str, set[ImportVar]]
+    hooks: set[str]
+    @classmethod
+    def merge(cls, *others: VarData | None) -> VarData | None: ...
 
 
 class Var:
 class Var:
     _var_name: str
     _var_name: str
     _var_type: Type
     _var_type: Type
-    _var_state: str = ""
     _var_is_local: bool = False
     _var_is_local: bool = False
     _var_is_string: bool = False
     _var_is_string: bool = False
+    _var_full_name_needs_state_prefix: bool = False
+    _var_data: VarData | None = None
     @classmethod
     @classmethod
     def create(
     def create(
         cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
         cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
@@ -38,7 +51,8 @@ class Var:
         cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
         cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
     ) -> Var: ...
     ) -> Var: ...
     @classmethod
     @classmethod
-    def __class_getitem__(cls, type_: str) -> _GenericAlias: ...
+    def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
     def equals(self, other: Var) -> bool: ...
     def equals(self, other: Var) -> bool: ...
     def to_string(self) -> Var: ...
     def to_string(self) -> Var: ...
     def __hash__(self) -> int: ...
     def __hash__(self) -> int: ...
@@ -95,15 +109,16 @@ class Var:
     def to(self, type_: Type) -> Var: ...
     def to(self, type_: Type) -> Var: ...
     @property
     @property
     def _var_full_name(self) -> str: ...
     def _var_full_name(self) -> str: ...
-    def _var_set_state(self, state: Type[State]) -> Any: ...
+    def _var_set_state(self, state: Type[State] | str) -> Any: ...
 
 
 @dataclass(eq=False)
 @dataclass(eq=False)
 class BaseVar(Var):
 class BaseVar(Var):
     _var_name: str
     _var_name: str
     _var_type: Any
     _var_type: Any
-    _var_state: str = ""
     _var_is_local: bool = False
     _var_is_local: bool = False
     _var_is_string: bool = False
     _var_is_string: bool = False
+    _var_full_name_needs_state_prefix: bool = False
+    _var_data: VarData | None = None
     def __hash__(self) -> int: ...
     def __hash__(self) -> int: ...
     def get_default_value(self) -> Any: ...
     def get_default_value(self) -> Any: ...
     def get_setter_name(self, include_state: bool = ...) -> str: ...
     def get_setter_name(self, include_state: bool = ...) -> str: ...
@@ -123,21 +138,6 @@ class ComputedVar(Var):
 
 
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 
 
-class ImportVar(Base):
-    tag: Optional[str]
-    is_default: Optional[bool] = False
-    alias: Optional[str] = None
-    install: Optional[bool] = True
-    render: Optional[bool] = True
-    @property
-    def name(self) -> str: ...
-    def __hash__(self) -> int: ...
-
-class NoRenderImportVar(ImportVar):
-    """A import that doesn't need to be rendered."""
-
-def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
-
 class CallableVar(BaseVar):
 class CallableVar(BaseVar):
     def __init__(self, fn: Callable[..., BaseVar]): ...
     def __init__(self, fn: Callable[..., BaseVar]): ...
     def __call__(self, *args, **kwargs) -> BaseVar: ...
     def __call__(self, *args, **kwargs) -> BaseVar: ...

+ 1 - 1
tests/compiler/test_compiler.py

@@ -5,7 +5,7 @@ import pytest
 
 
 from reflex.compiler import compiler, utils
 from reflex.compiler import compiler, utils
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportVar
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(

+ 1 - 1
tests/components/layout/test_cond.py

@@ -110,7 +110,7 @@ def test_cond_no_else():
 
 
     # Props do not support the use of cond without else
     # Props do not support the use of cond without else
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
-        cond(True, "hello")
+        cond(True, "hello")  # type: ignore
 
 
 
 
 def test_mobile_only():
 def test_mobile_only():

+ 162 - 2
tests/components/test_component.py

@@ -4,14 +4,16 @@ import pytest
 
 
 import reflex as rx
 import reflex as rx
 from reflex.base import Base
 from reflex.base import Base
+from reflex.components.base.bare import Bare
 from reflex.components.component import Component, CustomComponent, custom_component
 from reflex.components.component import Component, CustomComponent, custom_component
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
 from reflex.constants import EventTriggers
 from reflex.constants import EventTriggers
-from reflex.event import EventHandler
+from reflex.event import EventChain, EventHandler
 from reflex.state import State
 from reflex.state import State
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var, VarData
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -600,3 +602,161 @@ def test_format_component(component, rendered):
         rendered: The expected rendered component.
         rendered: The expected rendered component.
     """
     """
     assert str(component) == rendered
     assert str(component) == rendered
+
+
+TEST_VAR = Var.create_safe("test")._replace(
+    merge_var_data=VarData(
+        hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
+    )
+)
+FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
+STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False)
+EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
+ARG_VAR = Var.create("arg")
+
+
+class EventState(rx.State):
+    """State for testing event handlers with _get_vars."""
+
+    v: int = 42
+
+    def handler(self):
+        """A handler that does nothing."""
+
+    def handler2(self, arg):
+        """A handler that takes an arg.
+
+        Args:
+            arg: An arg.
+        """
+
+
+@pytest.mark.parametrize(
+    ("component", "exp_vars"),
+    (
+        pytest.param(
+            Bare.create(TEST_VAR),
+            [TEST_VAR],
+            id="direct-bare",
+        ),
+        pytest.param(
+            Bare.create(f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-bare",
+        ),
+        pytest.param(
+            rx.text(as_=TEST_VAR),
+            [TEST_VAR],
+            id="direct-prop",
+        ),
+        pytest.param(
+            rx.text(as_=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-prop",
+        ),
+        pytest.param(
+            rx.fragment(id=TEST_VAR),
+            [TEST_VAR],
+            id="direct-id",
+        ),
+        pytest.param(
+            rx.fragment(id=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-id",
+        ),
+        pytest.param(
+            rx.fragment(key=TEST_VAR),
+            [TEST_VAR],
+            id="direct-key",
+        ),
+        pytest.param(
+            rx.fragment(key=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-key",
+        ),
+        pytest.param(
+            rx.fragment(class_name=TEST_VAR),
+            [TEST_VAR],
+            id="direct-class_name",
+        ),
+        pytest.param(
+            rx.fragment(class_name=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-class_name",
+        ),
+        pytest.param(
+            rx.fragment(special_props={TEST_VAR}),
+            [TEST_VAR],
+            id="direct-special_props",
+        ),
+        pytest.param(
+            rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}),
+            [FORMATTED_TEST_VAR],
+            id="fstring-special_props",
+        ),
+        pytest.param(
+            # custom_attrs cannot accept a Var directly as a value
+            rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}),
+            [TEST_VAR],
+            id="fstring-custom_attrs-nofmt",
+        ),
+        pytest.param(
+            rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}),
+            [FORMATTED_TEST_VAR],
+            id="fstring-custom_attrs",
+        ),
+        pytest.param(
+            rx.fragment(background_color=TEST_VAR),
+            [STYLE_VAR],
+            id="direct-background_color",
+        ),
+        pytest.param(
+            rx.fragment(background_color=f"foo{TEST_VAR}bar"),
+            [STYLE_VAR],
+            id="fstring-background_color",
+        ),
+        pytest.param(
+            rx.fragment(style={"background_color": TEST_VAR}),  # type: ignore
+            [STYLE_VAR],
+            id="direct-style-background_color",
+        ),
+        pytest.param(
+            rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}),  # type: ignore
+            [STYLE_VAR],
+            id="fstring-style-background_color",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EVENT_CHAIN_VAR),  # type: ignore
+            [EVENT_CHAIN_VAR],
+            id="direct-event-chain",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EventState.handler),
+            [],
+            id="direct-event-handler",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EventState.handler2(TEST_VAR)),  # type: ignore
+            [ARG_VAR, TEST_VAR],
+            id="direct-event-handler-arg",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EventState.handler2(EventState.v)),  # type: ignore
+            [ARG_VAR, EventState.v],
+            id="direct-event-handler-arg2",
+        ),
+        pytest.param(
+            rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)),  # type: ignore
+            [ARG_VAR, TEST_VAR],
+            id="direct-event-handler-lambda",
+        ),
+    ),
+)
+def test_get_vars(component, exp_vars):
+    comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name)
+    assert len(comp_vars) == len(exp_vars)
+    for comp_var, exp_var in zip(
+        comp_vars,
+        sorted(exp_vars, key=lambda v: v._var_name),
+    ):
+        assert comp_var.equals(exp_var)

+ 3 - 3
tests/middleware/test_hydrate_middleware.py

@@ -104,7 +104,7 @@ async def test_preprocess(
         app=app, event=request.getfixturevalue(event_fixture), state=state
         app=app, event=request.getfixturevalue(event_fixture), state=state
     )
     )
     assert isinstance(update, StateUpdate)
     assert isinstance(update, StateUpdate)
-    assert update.delta == {state.get_name(): state.dict()}
+    assert update.delta == state.dict()
     events = update.events
     events = update.events
     assert len(events) == 2
     assert len(events) == 2
 
 
@@ -133,7 +133,7 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
 
 
     update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
     update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
     assert isinstance(update, StateUpdate)
     assert isinstance(update, StateUpdate)
-    assert update.delta == {"test_state": state.dict()}
+    assert update.delta == state.dict()
     assert len(update.events) == 3
     assert len(update.events) == 3
 
 
     # Apply the events.
     # Apply the events.
@@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
         state=state,
         state=state,
     )
     )
     assert isinstance(update, StateUpdate)
     assert isinstance(update, StateUpdate)
-    assert update.delta == {"test_state": state.dict()}
+    assert update.delta == state.dict()
     assert len(update.events) == 1
     assert len(update.events) == 1
     assert isinstance(update, StateUpdate)
     assert isinstance(update, StateUpdate)
 
 

+ 1 - 3
tests/test_app.py

@@ -769,9 +769,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
         )
         )
 
 
     current_state = await app.state_manager.get_state(token)
     current_state = await app.state_manager.get_state(token)
-    state_dict = current_state.dict()
-    for substate in state.get_full_name().split(".")[1:]:
-        state_dict = state_dict[substate]
+    state_dict = current_state.dict()[state.get_full_name()]
     assert state_dict["img_list"] == [
     assert state_dict["img_list"] == [
         "image1.jpg",
         "image1.jpg",
         "image2.jpg",
         "image2.jpg",

+ 24 - 15
tests/test_state.py

@@ -324,11 +324,17 @@ def test_dict(test_state):
     Args:
     Args:
         test_state: A state.
         test_state: A state.
     """
     """
-    substates = {"child_state", "child_state2"}
-    assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
-    assert (
-        set(test_state.dict(include_computed=False).keys())
-        == set(test_state.base_vars) | substates
+    substates = {
+        "test_state",
+        "test_state.child_state",
+        "test_state.child_state.grandchild_state",
+        "test_state.child_state2",
+    }
+    test_state_dict = test_state.dict()
+    assert set(test_state_dict) == substates
+    assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars)
+    assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set(
+        test_state.base_vars
     )
     )
 
 
 
 
@@ -1081,9 +1087,9 @@ def test_computed_var_cached():
             return self.v
             return self.v
 
 
     cs = ComputedState()
     cs = ComputedState()
-    assert cs.dict()["v"] == 0
+    assert cs.dict()[cs.get_full_name()]["v"] == 0
     assert comp_v_calls == 1
     assert comp_v_calls == 1
-    assert cs.dict()["comp_v"] == 0
+    assert cs.dict()[cs.get_full_name()]["comp_v"] == 0
     assert comp_v_calls == 1
     assert comp_v_calls == 1
     assert cs.comp_v == 0
     assert cs.comp_v == 0
     assert comp_v_calls == 1
     assert comp_v_calls == 1
@@ -1156,24 +1162,27 @@ def test_computed_var_depends_on_parent_non_cached():
     assert ps.dirty_vars == set()
     assert ps.dirty_vars == set()
     assert cs.dirty_vars == set()
     assert cs.dirty_vars == set()
 
 
-    assert ps.dict() == {
-        cs.get_name(): {"dep_v": 2},
+    dict1 = ps.dict()
+    assert dict1[ps.get_full_name()] == {
         "no_cache_v": 1,
         "no_cache_v": 1,
         CompileVars.IS_HYDRATED: False,
         CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
         "router": formatted_router,
     }
     }
-    assert ps.dict() == {
-        cs.get_name(): {"dep_v": 4},
+    assert dict1[cs.get_full_name()] == {"dep_v": 2}
+    dict2 = ps.dict()
+    assert dict2[ps.get_full_name()] == {
         "no_cache_v": 3,
         "no_cache_v": 3,
         CompileVars.IS_HYDRATED: False,
         CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
         "router": formatted_router,
     }
     }
-    assert ps.dict() == {
-        cs.get_name(): {"dep_v": 6},
+    assert dict2[cs.get_full_name()] == {"dep_v": 4}
+    dict3 = ps.dict()
+    assert dict3[ps.get_full_name()] == {
         "no_cache_v": 5,
         "no_cache_v": 5,
         CompileVars.IS_HYDRATED: False,
         CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
         "router": formatted_router,
     }
     }
+    assert dict3[cs.get_full_name()] == {"dep_v": 6}
     assert counter == 6
     assert counter == 6
 
 
 
 
@@ -2201,13 +2210,13 @@ def test_json_dumps_with_mutables():
         items: List[Foo] = [Foo()]
         items: List[Foo] = [Foo()]
 
 
     dict_val = MutableContainsBase().dict()
     dict_val = MutableContainsBase().dict()
-    assert isinstance(dict_val["items"][0], dict)
+    assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict)
     val = json_dumps(dict_val)
     val = json_dumps(dict_val)
     f_items = '[{"tags": ["123", "456"]}]'
     f_items = '[{"tags": ["123", "456"]}]'
     f_formatted_router = str(formatted_router).replace("'", '"')
     f_formatted_router = str(formatted_router).replace("'", '"')
     assert (
     assert (
         val
         val
-        == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
+        == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
     )
     )
 
 
 
 

+ 2 - 1
tests/test_style.py

@@ -22,7 +22,8 @@ def test_convert(style_dict, expected):
         style_dict: The style to check.
         style_dict: The style to check.
         expected: The expected formatted style.
         expected: The expected formatted style.
     """
     """
-    assert style.convert(style_dict) == expected
+    converted_dict, _var_data = style.convert(style_dict)
+    assert converted_dict == expected
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(

+ 50 - 16
tests/test_var.py

@@ -7,18 +7,20 @@ from pandas import DataFrame
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.state import State
 from reflex.state import State
+from reflex.utils.imports import ImportVar
 from reflex.vars import (
 from reflex.vars import (
     BaseVar,
     BaseVar,
     ComputedVar,
     ComputedVar,
-    ImportVar,
     Var,
     Var,
 )
 )
 
 
 test_vars = [
 test_vars = [
     BaseVar(_var_name="prop1", _var_type=int),
     BaseVar(_var_name="prop1", _var_type=int),
     BaseVar(_var_name="key", _var_type=str),
     BaseVar(_var_name="key", _var_type=str),
-    BaseVar(_var_name="value", _var_type=str, _var_state="state"),
-    BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True),
+    BaseVar(_var_name="value", _var_type=str)._var_set_state("state"),
+    BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state(
+        "state"
+    ),
     BaseVar(_var_name="local2", _var_type=str, _var_is_local=True),
     BaseVar(_var_name="local2", _var_type=str, _var_is_local=True),
 ]
 ]
 
 
@@ -263,7 +265,7 @@ def test_basic_operations(TestObj):
     assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
     assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
     assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
     assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
     assert (
     assert (
-        str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar)
+        str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar)
         == "{state.foo.bar}"
         == "{state.foo.bar}"
     )
     )
     assert str(abs(v(1))) == "{Math.abs(1)}"
     assert str(abs(v(1))) == "{Math.abs(1)}"
@@ -274,7 +276,7 @@ def test_basic_operations(TestObj):
     assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}"
     assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}"
     assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}'
     assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}'
     assert (
     assert (
-        str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse())
+        str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse())
         == "{[...state.foo].reverse()}"
         == "{[...state.foo].reverse()}"
     )
     )
     assert (
     assert (
@@ -288,11 +290,14 @@ def test_basic_operations(TestObj):
     [
     [
         (v([1, 2, 3]), "[1, 2, 3]"),
         (v([1, 2, 3]), "[1, 2, 3]"),
         (v(["1", "2", "3"]), '["1", "2", "3"]'),
         (v(["1", "2", "3"]), '["1", "2", "3"]'),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"),
+        (BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"),
         (BaseVar(_var_name="foo", _var_type=list), "foo"),
         (BaseVar(_var_name="foo", _var_type=list), "foo"),
         (v((1, 2, 3)), "[1, 2, 3]"),
         (v((1, 2, 3)), "[1, 2, 3]"),
         (v(("1", "2", "3")), '["1", "2", "3"]'),
         (v(("1", "2", "3")), '["1", "2", "3"]'),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"),
+        (
+            BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"),
+            "state.foo",
+        ),
         (BaseVar(_var_name="foo", _var_type=tuple), "foo"),
         (BaseVar(_var_name="foo", _var_type=tuple), "foo"),
     ],
     ],
 )
 )
@@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected):
     assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
     assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
     assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}"
     assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}"
     assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
     assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
-    other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+    other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
     other_var = BaseVar(_var_name="other", _var_type=str)
     other_var = BaseVar(_var_name="other", _var_type=str)
     assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
     assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
     assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
     assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected):
     "var, expected",
     "var, expected",
     [
     [
         (v("123"), json.dumps("123")),
         (v("123"), json.dumps("123")),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"),
+        (BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"),
         (BaseVar(_var_name="foo", _var_type=str), "foo"),
         (BaseVar(_var_name="foo", _var_type=str), "foo"),
     ],
     ],
 )
 )
 def test_str_contains(var, expected):
 def test_str_contains(var, expected):
     assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
     assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
     assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
     assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
-    other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+    other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
     other_var = BaseVar(_var_name="other", _var_type=str)
     other_var = BaseVar(_var_name="other", _var_type=str)
     assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
     assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
     assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
     assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@@ -328,7 +333,7 @@ def test_str_contains(var, expected):
     "var, expected",
     "var, expected",
     [
     [
         (v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'),
         (v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"),
+        (BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"),
         (BaseVar(_var_name="foo", _var_type=dict), "foo"),
         (BaseVar(_var_name="foo", _var_type=dict), "foo"),
     ],
     ],
 )
 )
@@ -337,7 +342,7 @@ def test_dict_contains(var, expected):
     assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}'
     assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}'
     assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}"
     assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}"
     assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}'
     assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}'
-    other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+    other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
     other_var = BaseVar(_var_name="other", _var_type=str)
     other_var = BaseVar(_var_name="other", _var_type=str)
     assert (
     assert (
         str(var.contains(other_state_var))
         str(var.contains(other_state_var))
@@ -548,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index):
     "fixture,full_name",
     "fixture,full_name",
     [
     [
         ("ParentState", "parent_state.var_without_annotation"),
         ("ParentState", "parent_state.var_without_annotation"),
-        ("ChildState", "parent_state.child_state.var_without_annotation"),
+        ("ChildState", "parent_state__child_state.var_without_annotation"),
         (
         (
             "GrandChildState",
             "GrandChildState",
-            "parent_state.child_state.grand_child_state.var_without_annotation",
+            "parent_state__child_state__grand_child_state.var_without_annotation",
         ),
         ),
         ("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
         ("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
     ],
     ],
@@ -630,8 +635,8 @@ def test_import_var(import_var, expected):
     [
     [
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
         (
-            f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}",
-            "testing f-string with ${state.myvar}",
+            f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
+            'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
         ),
         ),
         (
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
@@ -643,6 +648,35 @@ def test_fstrings(out, expected):
     assert out == expected
     assert out == expected
 
 
 
 
+@pytest.mark.parametrize(
+    ("value", "expect_state"),
+    [
+        ([1], ""),
+        ({"a": 1}, ""),
+        ([Var.create_safe(1)._var_set_state("foo")], "foo"),
+        ({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"),
+    ],
+)
+def test_extract_state_from_container(value, expect_state):
+    """Test that _var_state is extracted from containers containing BaseVar.
+
+    Args:
+        value: The value to create a var from.
+        expect_state: The expected state.
+    """
+    assert Var.create_safe(value)._var_state == expect_state
+
+
+def test_fstring_roundtrip():
+    """Test that f-string roundtrip carries state."""
+    var = BaseVar.create_safe("var")._var_set_state("state")
+    rt_var = Var.create_safe(f"{var}")
+    assert var._var_state == rt_var._var_state
+    assert var._var_full_name_needs_state_prefix
+    assert not rt_var._var_full_name_needs_state_prefix
+    assert rt_var._var_name == var._var_full_name
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "var",
     "var",
     [
     [

+ 37 - 28
tests/utils/test_format.py

@@ -8,7 +8,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var
-from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState
+from tests.test_state import (
+    ChildState,
+    ChildState2,
+    DateTimeState,
+    GrandchildState,
+    TestState,
+)
 
 
 
 
 def mock_event(arg):
 def mock_event(arg):
@@ -349,7 +355,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
             BaseVar(
             BaseVar(
                 _var_name="_",
                 _var_name="_",
                 _var_type=Any,
                 _var_type=Any,
-                _var_state="",
                 _var_is_local=True,
                 _var_is_local=True,
                 _var_is_string=False,
                 _var_is_string=False,
             ),
             ),
@@ -515,40 +520,44 @@ formatted_router = {
         (
         (
             TestState().dict(),  # type: ignore
             TestState().dict(),  # type: ignore
             {
             {
-                "array": [1, 2, 3.14],
-                "child_state": {
+                TestState.get_full_name(): {
+                    "array": [1, 2, 3.14],
+                    "complex": {
+                        1: {"prop1": 42, "prop2": "hello"},
+                        2: {"prop1": 42, "prop2": "hello"},
+                    },
+                    "dt": "1989-11-09 18:53:00+01:00",
+                    "fig": [],
+                    "is_hydrated": False,
+                    "key": "",
+                    "map_key": "a",
+                    "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
+                    "num1": 0,
+                    "num2": 3.14,
+                    "obj": {"prop1": 42, "prop2": "hello"},
+                    "sum": 3.14,
+                    "upper": "",
+                    "router": formatted_router,
+                },
+                ChildState.get_full_name(): {
                     "count": 23,
                     "count": 23,
-                    "grandchild_state": {"value2": ""},
                     "value": "",
                     "value": "",
                 },
                 },
-                "child_state2": {"value": ""},
-                "complex": {
-                    1: {"prop1": 42, "prop2": "hello"},
-                    2: {"prop1": 42, "prop2": "hello"},
-                },
-                "dt": "1989-11-09 18:53:00+01:00",
-                "fig": [],
-                "is_hydrated": False,
-                "key": "",
-                "map_key": "a",
-                "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
-                "num1": 0,
-                "num2": 3.14,
-                "obj": {"prop1": 42, "prop2": "hello"},
-                "sum": 3.14,
-                "upper": "",
-                "router": formatted_router,
+                ChildState2.get_full_name(): {"value": ""},
+                GrandchildState.get_full_name(): {"value2": ""},
             },
             },
         ),
         ),
         (
         (
             DateTimeState().dict(),
             DateTimeState().dict(),
             {
             {
-                "d": "1989-11-09",
-                "dt": "1989-11-09 18:53:00+01:00",
-                "is_hydrated": False,
-                "t": "18:53:00+01:00",
-                "td": "11 days, 0:11:00",
-                "router": formatted_router,
+                DateTimeState.get_full_name(): {
+                    "d": "1989-11-09",
+                    "dt": "1989-11-09 18:53:00+01:00",
+                    "is_hydrated": False,
+                    "t": "18:53:00+01:00",
+                    "td": "11 days, 0:11:00",
+                    "router": formatted_router,
+                },
             },
             },
         ),
         ),
     ],
     ],