Quellcode durchsuchen

[REF-889] useContext per substate (#2149)

Masen Furer vor 1 Jahr
Ursprung
Commit
1603144c7d
65 geänderte Dateien mit 1257 neuen und 455 gelöschten Zeilen
  1. 21 0
      integration/test_var_operations.py
  2. 7 5
      reflex/.templates/jinja/web/pages/_app.js.jinja2
  3. 0 26
      reflex/.templates/jinja/web/pages/index.js.jinja2
  4. 38 9
      reflex/.templates/jinja/web/utils/context.js.jinja2
  5. 21 38
      reflex/.templates/web/utils/state.js
  6. 1 1
      reflex/app.py
  7. 2 32
      reflex/compiler/compiler.py
  8. 2 1
      reflex/compiler/templates.py
  9. 4 3
      reflex/compiler/utils.py
  10. 16 2
      reflex/components/base/bare.py
  11. 192 16
      reflex/components/component.py
  12. 2 1
      reflex/components/datadisplay/code.py
  13. 2 1
      reflex/components/datadisplay/code.pyi
  14. 2 1
      reflex/components/datadisplay/dataeditor.py
  15. 2 1
      reflex/components/datadisplay/dataeditor.pyi
  16. 6 6
      reflex/components/datadisplay/datatable.py
  17. 1 1
      reflex/components/datadisplay/datatable.pyi
  18. 2 2
      reflex/components/datadisplay/moment.py
  19. 1 1
      reflex/components/datadisplay/moment.pyi
  20. 2 2
      reflex/components/forms/colormodeswitch.py
  21. 2 2
      reflex/components/forms/colormodeswitch.pyi
  22. 13 1
      reflex/components/forms/debounce.py
  23. 2 1
      reflex/components/forms/debounce.pyi
  24. 2 1
      reflex/components/forms/editor.py
  25. 2 1
      reflex/components/forms/editor.pyi
  26. 8 2
      reflex/components/forms/form.py
  27. 1 1
      reflex/components/forms/form.pyi
  28. 2 2
      reflex/components/forms/input.py
  29. 1 1
      reflex/components/forms/input.pyi
  30. 3 1
      reflex/components/forms/pininput.py
  31. 27 10
      reflex/components/forms/upload.py
  32. 2 1
      reflex/components/forms/upload.pyi
  33. 44 7
      reflex/components/layout/cond.py
  34. 3 3
      reflex/components/layout/html.py
  35. 5 2
      reflex/components/layout/html.pyi
  36. 10 10
      reflex/components/libs/chakra.py
  37. 1 1
      reflex/components/libs/chakra.pyi
  38. 7 8
      reflex/components/navigation/client_side_routing.py
  39. 3 3
      reflex/components/navigation/client_side_routing.pyi
  40. 13 8
      reflex/components/overlay/banner.py
  41. 4 3
      reflex/components/overlay/banner.pyi
  42. 2 2
      reflex/components/radix/themes/base.py
  43. 1 1
      reflex/components/radix/themes/base.pyi
  44. 2 1
      reflex/components/typography/markdown.py
  45. 2 1
      reflex/components/typography/markdown.pyi
  46. 4 0
      reflex/constants/__init__.py
  47. 2 0
      reflex/constants/base.py
  48. 25 0
      reflex/constants/compiler.py
  49. 1 1
      reflex/middleware/hydrate_middleware.py
  50. 9 5
      reflex/state.py
  51. 71 7
      reflex/style.py
  52. 19 6
      reflex/utils/format.py
  53. 41 4
      reflex/utils/imports.py
  54. 1 0
      reflex/utils/types.py
  55. 298 120
      reflex/vars.py
  56. 19 19
      reflex/vars.pyi
  57. 1 1
      tests/compiler/test_compiler.py
  58. 1 1
      tests/components/layout/test_cond.py
  59. 162 2
      tests/components/test_component.py
  60. 3 3
      tests/middleware/test_hydrate_middleware.py
  61. 1 3
      tests/test_app.py
  62. 24 15
      tests/test_state.py
  63. 2 1
      tests/test_style.py
  64. 50 16
      tests/test_var.py
  65. 37 28
      tests/utils/test_format.py

+ 21 - 0
integration/test_var_operations.py

@@ -28,6 +28,7 @@ def VarOperations():
         str_var4: str = "a long string"
         dict1: dict = {1: 2}
         dict2: dict = {3: 4}
+        html_str: str = "<div>hello</div>"
 
     app = rx.App(state=VarOperationState)
 
@@ -522,6 +523,19 @@ def VarOperations():
             rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
             rx.text(VarOperationState.list3.join(""), id="list_join"),
             rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
+            # Index from an op var
+            rx.text(
+                VarOperationState.list3[VarOperationState.int_var1 % 3],
+                id="list_index_mod",
+            ),
+            rx.html(
+                VarOperationState.html_str,
+                id="html_str",
+            ),
+            rx.highlight(
+                "second",
+                query=[VarOperationState.str_var2],
+            ),
             rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
             rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
             rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
@@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness):
         ("dict_eq_dict", "false"),
         ("dict_neq_dict", "true"),
         ("dict_contains", "true"),
+        # index from an op var
+        ("list_index_mod", "second"),
+        # html component with var
+        ("html_str", "hello"),
     ]
 
     for tag, expected in tests:
         assert driver.find_element(By.ID, tag).text == expected
+
+    # Highlight component with var query (does not plumb ID)
+    assert driver.find_element(By.TAG_NAME, "mark").text == "second"

+ 7 - 5
reflex/.templates/jinja/web/pages/_app.js.jinja2

@@ -1,7 +1,7 @@
 {% extends "web/pages/base_page.js.jinja2" %}
 
 {% block declaration %}
-import { EventLoopProvider } from "/utils/context.js";
+import { EventLoopProvider, StateProvider } from "/utils/context.js";
 import { ThemeProvider } from 'next-themes'
 
 {% for custom_code in custom_codes %}
@@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) {
   return (
     <ThemeProvider defaultTheme="light" storageKey="chakra-ui-color-mode" attribute="class">
       <AppWrap>
-        <EventLoopProvider>
-          <Component {...pageProps} />
-        </EventLoopProvider>
+        <StateProvider>
+          <EventLoopProvider>
+            <Component {...pageProps} />
+          </EventLoopProvider>
+        </StateProvider>
       </AppWrap>
     </ThemeProvider>
   );
 }
 
-{% endblock %}
+{% endblock %}

+ 0 - 26
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -8,32 +8,6 @@
 
 {% block export %}
 export default function Component() {
-{% if state_name %}
-  const {{state_name}} = useContext(StateContext)
-{% endif %}
-  const {{const.router}} = useRouter()
-  const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
-  const focusRef = useRef();
-  
-  // Main event loop.
-  const [addEvents, connectError] = useContext(EventLoopContext)
-
-  // Set focus to the specified element.
-  useEffect(() => {
-    if (focusRef.current) {
-      focusRef.current.focus();
-    }
-  })
-
-  // Route after the initial page hydration.
-  useEffect(() => {
-    const change_complete = () => addEvents(initialEvents())
-    {{const.router}}.events.on('routeChangeComplete', change_complete)
-    return () => {
-      {{const.router}}.events.off('routeChangeComplete', change_complete)
-    }
-  }, [{{const.router}}])
-
   {% for hook in hooks %}
   {{ hook }}
   {% endfor %}

+ 38 - 9
reflex/.templates/jinja/web/utils/context.js.jinja2

@@ -1,5 +1,5 @@
-import { createContext, useState } from "react"
-import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
+import { createContext, useContext, useMemo, useReducer, useState } from "react"
+import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
 
 {% if initial_state %}
 export const initialState = {{ initial_state|json_dumps }}
@@ -8,7 +8,12 @@ export const initialState = {}
 {% endif %}
 
 export const ColorModeContext = createContext(null);
-export const StateContext = createContext(null);
+export const DispatchContext = createContext(null);
+export const StateContexts = {
+  {% for state_name in initial_state %}
+  {{state_name|var_name}}: createContext(null),
+  {% endfor %}
+}
 export const EventLoopContext = createContext(null);
 {% if client_storage %}
 export const clientStorage = {{ client_storage|json_dumps }}
@@ -27,16 +32,40 @@ export const initialEvents = () => []
 export const isDevMode = {{ is_dev_mode|json_dumps }}
 
 export function EventLoopProvider({ children }) {
-  const [state, addEvents, connectError] = useEventLoop(
-    initialState,
+  const dispatch = useContext(DispatchContext)
+  const [addEvents, connectError] = useEventLoop(
+    dispatch,
     initialEvents,
     clientStorage,
   )
   return (
     <EventLoopContext.Provider value={[addEvents, connectError]}>
-      <StateContext.Provider value={state}>
-        {children}
-      </StateContext.Provider>
+      {children}
     </EventLoopContext.Provider>
   )
-}
+}
+
+export function StateProvider({ children }) {
+  {% for state_name in initial_state %}
+  const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"])
+  {% endfor %}
+  const dispatchers = useMemo(() => {
+    return {
+      {% for state_name in initial_state %}
+      "{{state_name}}": dispatch_{{state_name|var_name}},
+      {% endfor %}
+    }
+  }, [])
+
+  return (
+    {% for state_name in initial_state %}
+    <StateContexts.{{state_name|var_name}}.Provider value={ {{state_name|var_name}} }>
+    {% endfor %}
+      <DispatchContext.Provider value={dispatchers}>
+        {children}
+      </DispatchContext.Provider>
+    {% for state_name in initial_state|reverse %}
+    </StateContexts.{{state_name|var_name}}.Provider>
+    {% endfor %}
+  )
+}

+ 21 - 38
reflex/.templates/web/utils/state.js

@@ -6,7 +6,7 @@ import env from "env.json";
 import Cookies from "universal-cookie";
 import { useEffect, useReducer, useRef, useState } from "react";
 import Router, { useRouter } from "next/router";
-import { initialEvents } from "utils/context.js"
+import { initialEvents, initialState } from "utils/context.js"
 
 // Endpoint URLs.
 const EVENTURL = env.EVENT
@@ -100,37 +100,10 @@ export const getEventURL = () => {
  * @param delta The delta to apply.
  */
 export const applyDelta = (state, delta) => {
-  const new_state = { ...state }
-  for (const substate in delta) {
-    let s = new_state;
-    const path = substate.split(".").slice(1);
-    while (path.length > 0) {
-      s = s[path.shift()];
-    }
-    for (const key in delta[substate]) {
-      s[key] = delta[substate][key];
-    }
-  }
-  return new_state
+  return { ...state, ...delta }
 };
 
 
-/**
- * Get all local storage items in a key-value object.
- * @returns object of items in local storage.
- */
-export const getAllLocalStorageItems = () => {
-  var localStorageItems = {};
-
-  for (var i = 0, len = localStorage.length; i < len; i++) {
-    var key = localStorage.key(i);
-    localStorageItems[key] = localStorage.getItem(key);
-  }
-
-  return localStorageItems;
-}
-
-
 /**
  * Handle frontend event or send the event to the backend via Websocket.
  * @param event The event to send.
@@ -346,7 +319,9 @@ export const connect = async (
   // On each received message, queue the updates and events.
   socket.current.on("event", message => {
     const update = JSON5.parse(message)
-    dispatch(update.delta)
+    for (const substate in update.delta) {
+      dispatch[substate](update.delta[substate])
+    }
     applyClientStorageDelta(client_storage, update.delta)
     event_processing = !update.final
     if (update.events) {
@@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => {
 
 /**
  * Establish websocket event loop for a NextJS page.
- * @param initial_state The initial app state.
- * @param initial_events Function that returns the initial app events.
+ * @param dispatch The reducer dispatch function to update state.
+ * @param initial_events The initial app events.
  * @param client_storage The client storage object from context.js
  *
- * @returns [state, addEvents, connectError] -
- *   state is a reactive dict,
+ * @returns [addEvents, connectError] -
  *   addEvents is used to queue an event, and
  *   connectError is a reactive js error from the websocket connection (or null if connected).
  */
 export const useEventLoop = (
-  initial_state = {},
+  dispatch,
   initial_events = () => [],
   client_storage = {},
 ) => {
   const socket = useRef(null)
   const router = useRouter()
-  const [state, dispatch] = useReducer(applyDelta, initial_state)
   const [connectError, setConnectError] = useState(null)
 
   // Function to add new events to the event queue.
@@ -570,7 +543,7 @@ export const useEventLoop = (
       return;
     }
     // only use websockets if state is present
-    if (Object.keys(state).length > 0) {
+    if (Object.keys(initialState).length > 0) {
       // Initialize the websocket connection.
       if (!socket.current) {
         connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
@@ -583,7 +556,17 @@ export const useEventLoop = (
       })()
     }
   })
-  return [state, addEvents, connectError]
+
+  // Route after the initial page hydration.
+  useEffect(() => {
+    const change_complete = () => addEvents(initial_events())
+    router.events.on('routeChangeComplete', change_complete)
+    return () => {
+      router.events.off('routeChangeComplete', change_complete)
+    }
+  }, [router])
+
+  return [addEvents, connectError]
 }
 
 /***

+ 1 - 1
reflex/app.py

@@ -63,7 +63,7 @@ from reflex.state import (
     StateUpdate,
 )
 from reflex.utils import console, format, prerequisites, types
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportVar
 
 # Define custom types.
 ComponentCallable = Callable[[], Component]

+ 2 - 32
reflex/compiler/compiler.py

@@ -10,40 +10,10 @@ from reflex.compiler import templates, utils
 from reflex.components.component import Component, ComponentStyle, CustomComponent
 from reflex.config import get_config
 from reflex.state import State
-from reflex.utils import imports
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
 
 # Imports to be included in every Reflex app.
-DEFAULT_IMPORTS: imports.ImportDict = {
-    "react": [
-        ImportVar(tag="Fragment"),
-        ImportVar(tag="useEffect"),
-        ImportVar(tag="useRef"),
-        ImportVar(tag="useState"),
-        ImportVar(tag="useContext"),
-    ],
-    "next/router": [ImportVar(tag="useRouter")],
-    f"/{constants.Dirs.STATE_PATH}": [
-        ImportVar(tag="uploadFiles"),
-        ImportVar(tag="Event"),
-        ImportVar(tag="isTrue"),
-        ImportVar(tag="spreadArraysOrObjects"),
-        ImportVar(tag="preventDefault"),
-        ImportVar(tag="refs"),
-        ImportVar(tag="getRefValue"),
-        ImportVar(tag="getRefValues"),
-        ImportVar(tag="getAllLocalStorageItems"),
-        ImportVar(tag="useEventLoop"),
-    ],
-    "/utils/context.js": [
-        ImportVar(tag="EventLoopContext"),
-        ImportVar(tag="initialEvents"),
-        ImportVar(tag="StateContext"),
-        ImportVar(tag="ColorModeContext"),
-    ],
-    "/utils/helpers/range.js": [
-        ImportVar(tag="range", is_default=True),
-    ],
+DEFAULT_IMPORTS: ImportDict = {
     "": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
 }
 

+ 2 - 1
reflex/compiler/templates.py

@@ -3,7 +3,7 @@
 from jinja2 import Environment, FileSystemLoader, Template
 
 from reflex import constants
-from reflex.utils.format import json_dumps
+from reflex.utils.format import format_state_name, json_dumps
 
 
 class ReflexJinjaEnvironment(Environment):
@@ -19,6 +19,7 @@ class ReflexJinjaEnvironment(Environment):
         )
         self.filters["json_dumps"] = json_dumps
         self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
+        self.filters["var_name"] = format_state_name
         self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
         self.globals["const"] = {
             "socket": constants.CompileVars.SOCKET,

+ 4 - 3
reflex/compiler/utils.py

@@ -24,13 +24,12 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
 from reflex.state import Cookie, LocalStorage, State
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
-from reflex.vars import ImportVar
 
 # To re-export this function.
 merge_imports = imports.merge_imports
 
 
-def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
+def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
     """Compile an import statement.
 
     Args:
@@ -343,7 +342,9 @@ def get_context_path() -> str:
     Returns:
         The path of the context module.
     """
-    return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS)
+    return os.path.join(
+        constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
+    )
 
 
 def get_components_path() -> str:

+ 16 - 2
reflex/components/base/bare.py

@@ -1,7 +1,7 @@
 """A bare component."""
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Iterator
 
 from reflex.components.component import Component
 from reflex.components.tags import Tag
@@ -24,7 +24,21 @@ class Bare(Component):
         Returns:
             The component.
         """
-        return cls(contents=str(contents))  # type: ignore
+        if isinstance(contents, Var) and contents._var_data:
+            contents = contents.to(str)
+        else:
+            contents = str(contents)
+        return cls(contents=contents)  # type: ignore
 
     def _render(self) -> Tag:
         return Tagless(contents=str(self.contents))
+
+    def _get_vars(self) -> Iterator[Var]:
+        """Walk all Vars used in this component.
+
+        Yields:
+            The contents if it is a Var, otherwise nothing.
+        """
+        if isinstance(self.contents, Var):
+            # Fast path for Bare text components.
+            yield self.contents

+ 192 - 16
reflex/components/component.py

@@ -5,11 +5,11 @@ from __future__ import annotations
 import typing
 from abc import ABC
 from functools import lru_cache, wraps
-from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union
 
 from reflex.base import Base
 from reflex.components.tags import Tag
-from reflex.constants import Dirs, EventTriggers
+from reflex.constants import Dirs, EventTriggers, Hooks, Imports
 from reflex.event import (
     EventChain,
     EventHandler,
@@ -20,8 +20,9 @@ from reflex.event import (
 )
 from reflex.style import Style
 from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
 from reflex.utils.serializers import serializer
-from reflex.vars import BaseVar, ImportVar, Var
+from reflex.vars import BaseVar, Var
 
 
 class Component(Base, ABC):
@@ -388,7 +389,11 @@ class Component(Base, ABC):
             props = props.copy()
 
         props.update(
-            self.event_triggers,
+            **{
+                trigger: handler
+                for trigger, handler in self.event_triggers.items()
+                if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT}
+            },
             key=self.key,
             id=self.id,
             class_name=self.class_name,
@@ -488,7 +493,7 @@ class Component(Base, ABC):
         """
         if type(self) in style:
             # Extract the style for this component.
-            component_style = Style(style[type(self)])
+            component_style = style[type(self)]
 
             # Only add style props that are not overridden.
             component_style = {
@@ -564,6 +569,78 @@ class Component(Base, ABC):
             if self._valid_children:
                 validate_valid_child(name)
 
+    @staticmethod
+    def _get_vars_from_event_triggers(
+        event_triggers: dict[str, EventChain | Var],
+    ) -> Iterator[tuple[str, list[Var]]]:
+        """Get the Vars associated with each event trigger.
+
+        Args:
+            event_triggers: The event triggers from the component instance.
+
+        Yields:
+            tuple of (event_name, event_vars)
+        """
+        for event_trigger, event in event_triggers.items():
+            if isinstance(event, Var):
+                yield event_trigger, [event]
+            elif isinstance(event, EventChain):
+                event_args = []
+                for spec in event.events:
+                    for args in spec.args:
+                        event_args.extend(args)
+                yield event_trigger, event_args
+
+    def _get_vars(self) -> list[Var]:
+        """Walk all Vars used in this component.
+
+        Returns:
+            Each var referenced by the component (props, styles, event handlers).
+        """
+        vars = getattr(self, "__vars", None)
+        if vars is not None:
+            return vars
+        vars = self.__vars = []
+        # Get Vars associated with event trigger arguments.
+        for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers):
+            vars.extend(event_vars)
+
+        # Get Vars associated with component props.
+        for prop in self.get_props():
+            prop_var = getattr(self, prop)
+            if isinstance(prop_var, Var):
+                vars.append(prop_var)
+
+        # Style keeps track of its own VarData instance, so embed in a temp Var that is yielded.
+        if self.style:
+            vars.append(
+                BaseVar(
+                    _var_name="style",
+                    _var_type=str,
+                    _var_data=self.style._var_data,
+                )
+            )
+
+        # Special props are always Var instances.
+        vars.extend(self.special_props)
+
+        # Get Vars associated with common Component props.
+        for comp_prop in (
+            self.class_name,
+            self.id,
+            self.key,
+            self.autofocus,
+            *self.custom_attrs.values(),
+        ):
+            if isinstance(comp_prop, Var):
+                vars.append(comp_prop)
+            elif isinstance(comp_prop, str):
+                # Collapse VarData encoded in f-strings.
+                var = Var.create_safe(comp_prop)
+                if var._var_data is not None:
+                    vars.append(var)
+        return vars
+
     def _get_custom_code(self) -> str | None:
         """Get custom code for the component.
 
@@ -644,6 +721,33 @@ class Component(Base, ABC):
             dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
         }
 
+    def _get_hooks_imports(self) -> imports.ImportDict:
+        """Get the imports required by certain hooks.
+
+        Returns:
+            The imports required for all selected hooks.
+        """
+        _imports = {}
+
+        if self._get_ref_hook():
+            # Handle hooks needed for attaching react refs to DOM nodes.
+            _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
+            _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
+
+        if self._get_mount_lifecycle_hook():
+            # Handle hooks for `on_mount` / `on_unmount`.
+            _imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
+
+        if self._get_special_hooks():
+            # Handle additional internal hooks (autofocus, etc).
+            _imports.setdefault("react", set()).update(
+                {
+                    ImportVar(tag="useRef"),
+                    ImportVar(tag="useEffect"),
+                },
+            )
+        return _imports
+
     def _get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.
 
@@ -651,13 +755,26 @@ class Component(Base, ABC):
             The imports needed by the component.
         """
         _imports = {}
+
+        # Import this component's tag from the main library.
         if self.library is not None and self.tag is not None:
             _imports[self.library] = {self.import_var}
 
+        # Get static imports required for event processing.
+        event_imports = Imports.EVENTS if self.event_triggers else {}
+
+        # Collect imports from Vars used directly by this component.
+        var_imports = [
+            var._var_data.imports for var in self._get_vars() if var._var_data
+        ]
+
         return imports.merge_imports(
             *self._get_props_imports(),
             self._get_dependencies_imports(),
+            self._get_hooks_imports(),
             _imports,
+            event_imports,
+            *var_imports,
         )
 
     def get_imports(self) -> imports.ImportDict:
@@ -678,13 +795,13 @@ class Component(Base, ABC):
         """
         # pop on_mount and on_unmount from event_triggers since these are handled by
         # hooks, not as actually props in the component
-        on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None)
-        on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None)
-        if on_mount:
+        on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None)
+        on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None)
+        if on_mount is not None:
             on_mount = format.format_event_chain(on_mount)
-        if on_unmount:
+        if on_unmount is not None:
             on_unmount = format.format_event_chain(on_unmount)
-        if on_mount or on_unmount:
+        if on_mount is not None or on_unmount is not None:
             return f"""
                 useEffect(() => {{
                     {on_mount or ""}
@@ -703,6 +820,47 @@ class Component(Base, ABC):
         if ref is not None:
             return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
 
+    def _get_vars_hooks(self) -> set[str]:
+        """Get the hooks required by vars referenced in this component.
+
+        Returns:
+            The hooks for the vars.
+        """
+        vars_hooks = set()
+        for var in self._get_vars():
+            if var._var_data:
+                vars_hooks.update(var._var_data.hooks)
+        return vars_hooks
+
+    def _get_events_hooks(self) -> set[str]:
+        """Get the hooks required by events referenced in this component.
+
+        Returns:
+            The hooks for the events.
+        """
+        if self.event_triggers:
+            return {Hooks.EVENTS}
+        return set()
+
+    def _get_special_hooks(self) -> set[str]:
+        """Get the hooks required by special actions referenced in this component.
+
+        Returns:
+            The hooks for special actions.
+        """
+        if self.autofocus:
+            return {
+                """
+                // Set focus to the specified element.
+                const focusRef = useRef(null)
+                useEffect(() => {
+                  if (focusRef.current) {
+                    focusRef.current.focus();
+                  }
+                })""",
+            }
+        return set()
+
     def _get_hooks_internal(self) -> Set[str]:
         """Get the React hooks for this component managed by the framework.
 
@@ -712,10 +870,15 @@ class Component(Base, ABC):
         Returns:
             Set of internally managed hooks.
         """
-        return set(
-            hook
-            for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
-            if hook
+        return (
+            set(
+                hook
+                for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
+                if hook
+            )
+            | self._get_vars_hooks()
+            | self._get_events_hooks()
+            | self._get_special_hooks()
         )
 
     def _get_hooks(self) -> str | None:
@@ -1018,11 +1181,24 @@ class NoSSRComponent(Component):
     """A dynamic component that is not rendered on the server."""
 
     def _get_imports(self) -> imports.ImportDict:
-        dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}}
+        """Get the imports for the component.
+
+        Returns:
+            The imports for dynamically importing the component at module load time.
+        """
+        # Next.js dynamic import mechanism.
+        dynamic_import = {"next/dynamic": [ImportVar(tag="dynamic", is_default=True)]}
+
+        # The normal imports for this component.
+        _imports = super()._get_imports()
+
+        # Do NOT import the main library/tag statically.
+        if self.library is not None:
+            _imports[self.library] = [imports.ImportVar(tag=None, render=False)]
 
         return imports.merge_imports(
             dynamic_import,
-            {self.library: {ImportVar(tag=None, render=False)}},
+            _imports,
             self._get_dependencies_imports(),
         )
 

+ 2 - 1
reflex/components/datadisplay/code.py

@@ -12,7 +12,8 @@ from reflex.components.media import Icon
 from reflex.event import set_clipboard
 from reflex.style import Style
 from reflex.utils import format, imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 LiteralCodeBlockTheme = Literal[
     "a11y-dark",

+ 2 - 1
reflex/components/datadisplay/code.pyi

@@ -16,7 +16,8 @@ from reflex.components.media import Icon
 from reflex.event import set_clipboard
 from reflex.style import Style
 from reflex.utils import format, imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 LiteralCodeBlockTheme = Literal[
     "a11y-dark",

+ 2 - 1
reflex/components/datadisplay/dataeditor.py

@@ -8,8 +8,9 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.literals import LiteralRowMarker
 from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
 from reflex.utils.serializers import serializer
-from reflex.vars import ImportVar, Var, get_unique_variable_name
+from reflex.vars import Var, get_unique_variable_name
 
 
 # TODO: Fix the serialization issue for custom types.

+ 2 - 1
reflex/components/datadisplay/dataeditor.pyi

@@ -13,8 +13,9 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.literals import LiteralRowMarker
 from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
 from reflex.utils.serializers import serializer
-from reflex.vars import ImportVar, Var, get_unique_variable_name
+from reflex.vars import Var, get_unique_variable_name
 
 class GridColumnIcons(Enum):
     Array = "array"

+ 6 - 6
reflex/components/datadisplay/datatable.py

@@ -8,7 +8,7 @@ from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.utils import imports, types
 from reflex.utils.serializers import serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var
 
 
 class Gridjs(Component):
@@ -105,7 +105,7 @@ class DataTable(Gridjs):
     def _get_imports(self) -> imports.ImportDict:
         return imports.merge_imports(
             super()._get_imports(),
-            {"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
+            {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
         )
 
     def _render(self) -> Tag:
@@ -113,13 +113,13 @@ class DataTable(Gridjs):
             self.columns = BaseVar(
                 _var_name=f"{self.data._var_name}.columns",
                 _var_type=List[Any],
-                _var_state=self.data._var_state,
-            )
+                _var_full_name_needs_state_prefix=True,
+            )._replace(merge_var_data=self.data._var_data)
             self.data = BaseVar(
                 _var_name=f"{self.data._var_name}.data",
                 _var_type=List[List[Any]],
-                _var_state=self.data._var_state,
-            )
+                _var_full_name_needs_state_prefix=True,
+            )._replace(merge_var_data=self.data._var_data)
         if types.is_dataframe(type(self.data)):
             # If given a pandas df break up the data and columns
             data = serialize(self.data)

+ 1 - 1
reflex/components/datadisplay/datatable.pyi

@@ -12,7 +12,7 @@ from reflex.components.component import Component
 from reflex.components.tags import Tag
 from reflex.utils import imports, types
 from reflex.utils.serializers import serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var
 
 class Gridjs(Component):
     @overload

+ 2 - 2
reflex/components/datadisplay/moment.py

@@ -4,7 +4,7 @@ from typing import Any, Dict, List
 
 from reflex.components.component import Component, NoSSRComponent
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 class Moment(NoSSRComponent):
@@ -78,7 +78,7 @@ class Moment(NoSSRComponent):
         if self.tz is not None:
             merged_imports = imports.merge_imports(
                 merged_imports,
-                {"moment-timezone": {ImportVar(tag="")}},
+                {"moment-timezone": {imports.ImportVar(tag="")}},
             )
         return merged_imports
 

+ 1 - 1
reflex/components/datadisplay/moment.pyi

@@ -10,7 +10,7 @@ from reflex.style import Style
 from typing import Any, Dict, List
 from reflex.components.component import Component, NoSSRComponent
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 class Moment(NoSSRComponent):
     def get_event_triggers(self) -> Dict[str, Any]: ...

+ 2 - 2
reflex/components/forms/colormodeswitch.py

@@ -22,7 +22,7 @@ from reflex.components.component import Component
 from reflex.components.layout.cond import Cond, cond
 from reflex.components.media.icon import Icon
 from reflex.style import color_mode, toggle_color_mode
-from reflex.vars import BaseVar
+from reflex.vars import Var
 
 from .button import Button
 from .switch import Switch
@@ -32,7 +32,7 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
 DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
 
 
-def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component:
+def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
     """Create a component or Prop based on color_mode.
 
     Args:

+ 2 - 2
reflex/components/forms/colormodeswitch.pyi

@@ -12,7 +12,7 @@ from reflex.components.component import Component
 from reflex.components.layout.cond import Cond, cond
 from reflex.components.media.icon import Icon
 from reflex.style import color_mode, toggle_color_mode
-from reflex.vars import BaseVar
+from reflex.vars import Var
 from .button import Button
 from .switch import Switch
 
@@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str
 DEFAULT_LIGHT_ICON: Icon
 DEFAULT_DARK_ICON: Icon
 
-def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ...
+def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ...
 
 class ColorModeIcon(Cond):
     @overload

+ 13 - 1
reflex/components/forms/debounce.py

@@ -1,10 +1,11 @@
 """Wrapper around react-debounce-input."""
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Set
 
 from reflex.components import Component
 from reflex.components.tags import Tag
+from reflex.utils import imports
 from reflex.vars import Var
 
 
@@ -77,6 +78,17 @@ class DebounceInput(Component):
         object.__setattr__(child, "render", lambda: "")
         return tag
 
+    def _get_imports(self) -> imports.ImportDict:
+        return imports.merge_imports(
+            super()._get_imports(), *[c._get_imports() for c in self.children]
+        )
+
+    def _get_hooks_internal(self) -> Set[str]:
+        hooks = super()._get_hooks_internal()
+        for child in self.children:
+            hooks.update(child._get_hooks_internal())
+        return hooks
+
 
 def props_not_none(c: Component) -> dict[str, Any]:
     """Get all properties of the component that are not None.

+ 2 - 1
reflex/components/forms/debounce.pyi

@@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
-from typing import Any
+from typing import Any, Set
 from reflex.components import Component
 from reflex.components.tags import Tag
+from reflex.utils import imports
 from reflex.vars import Var
 
 class DebounceInput(Component):

+ 2 - 1
reflex/components/forms/editor.py

@@ -8,7 +8,8 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.constants import EventTriggers
 from reflex.utils.format import to_camel_case
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 
 class EditorButtonList(list, enum.Enum):

+ 2 - 1
reflex/components/forms/editor.pyi

@@ -13,7 +13,8 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.constants import EventTriggers
 from reflex.utils.format import to_camel_case
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 class EditorButtonList(list, enum.Enum):
     BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]]

+ 8 - 2
reflex/components/forms/form.py

@@ -8,7 +8,7 @@ from jinja2 import Environment
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.tags import Tag
-from reflex.constants import EventTriggers
+from reflex.constants import Dirs, EventTriggers
 from reflex.event import EventChain
 from reflex.utils import imports
 from reflex.utils.format import format_event_chain, to_camel_case
@@ -65,7 +65,13 @@ class Form(ChakraComponent):
     def _get_imports(self) -> imports.ImportDict:
         return imports.merge_imports(
             super()._get_imports(),
-            {"react": {imports.ImportVar(tag="useCallback")}},
+            {
+                "react": {imports.ImportVar(tag="useCallback")},
+                f"/{Dirs.STATE_PATH}": {
+                    imports.ImportVar(tag="getRefValue"),
+                    imports.ImportVar(tag="getRefValues"),
+                },
+            },
         )
 
     def _get_hooks(self) -> str | None:

+ 1 - 1
reflex/components/forms/form.pyi

@@ -12,7 +12,7 @@ from jinja2 import Environment
 from reflex.components.component import Component
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.tags import Tag
-from reflex.constants import EventTriggers
+from reflex.constants import Dirs, EventTriggers
 from reflex.event import EventChain
 from reflex.utils import imports
 from reflex.utils.format import format_event_chain, to_camel_case

+ 2 - 2
reflex/components/forms/input.py

@@ -11,7 +11,7 @@ from reflex.components.libs.chakra import (
 )
 from reflex.constants import EventTriggers
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 class Input(ChakraComponent):
@@ -61,7 +61,7 @@ class Input(ChakraComponent):
     def _get_imports(self) -> imports.ImportDict:
         return imports.merge_imports(
             super()._get_imports(),
-            {"/utils/state": {ImportVar(tag="set_val")}},
+            {"/utils/state": {imports.ImportVar(tag="set_val")}},
         )
 
     def get_event_triggers(self) -> Dict[str, Any]:

+ 1 - 1
reflex/components/forms/input.pyi

@@ -17,7 +17,7 @@ from reflex.components.libs.chakra import (
 )
 from reflex.constants import EventTriggers
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 class Input(ChakraComponent):
     def get_event_triggers(self) -> Dict[str, Any]: ...

+ 3 - 1
reflex/components/forms/pininput.py

@@ -68,9 +68,11 @@ class PinInput(ChakraComponent):
         Returns:
             The merged import dict.
         """
+        range_var = Var.range(0)
         return merge_imports(
             super()._get_imports(),
             PinInputField().get_imports(),  # type: ignore
+            range_var._var_data.imports if range_var._var_data is not None else {},
         )
 
     def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
@@ -117,7 +119,7 @@ class PinInput(ChakraComponent):
             )
             refs_declaration._var_is_local = True
             if ref:
-                return f"const {ref} = {refs_declaration}"
+                return f"const {ref} = {str(refs_declaration)}"
             return super()._get_ref_hook()
 
     def _render(self) -> Tag:

+ 27 - 10
reflex/components/forms/upload.py

@@ -7,9 +7,10 @@ from reflex import constants
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
+from reflex.constants import Dirs
 from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
 from reflex.utils import imports
-from reflex.vars import BaseVar, CallableVar, ImportVar, Var
+from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 DEFAULT_UPLOAD_ID: str = "default"
 
@@ -30,6 +31,13 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
     return BaseVar(
         _var_name=f"e => upload_files.{id_}[1]((files) => e)",
         _var_type=EventChain,
+        _var_data=VarData(  # type: ignore
+            imports={
+                f"/{Dirs.STATE_PATH}": {
+                    imports.ImportVar(tag="upload_files"),
+                },
+            },
+        ),
     )
 
 
@@ -46,6 +54,13 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
     return BaseVar(
         _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
         _var_type=List[str],
+        _var_data=VarData(  # type: ignore
+            imports={
+                f"/{Dirs.STATE_PATH}": {
+                    imports.ImportVar(tag="upload_files"),
+                },
+            },
+        ),
     )
 
 
@@ -166,14 +181,16 @@ class Upload(Component):
 
     def _get_hooks(self) -> str | None:
         return (
-            (super()._get_hooks() or "")
-            + f"""
-        upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
-        """
-        )
+            super()._get_hooks() or ""
+        ) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);"
 
     def _get_imports(self) -> imports.ImportDict:
-        return {
-            **super()._get_imports(),
-            f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
-        }
+        return imports.merge_imports(
+            super()._get_imports(),
+            {
+                "react": {imports.ImportVar(tag="useState")},
+                f"/{constants.Dirs.STATE_PATH}": [
+                    imports.ImportVar(tag="upload_files")
+                ],
+            },
+        )

+ 2 - 1
reflex/components/forms/upload.pyi

@@ -12,9 +12,10 @@ from reflex import constants
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
+from reflex.constants import Dirs
 from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
 from reflex.utils import imports
-from reflex.vars import BaseVar, CallableVar, ImportVar, Var
+from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 DEFAULT_UPLOAD_ID: str
 

+ 44 - 7
reflex/components/layout/cond.py

@@ -1,13 +1,18 @@
 """Create a list of components from an iterable."""
 from __future__ import annotations
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, overload
 
 from reflex.components.component import Component
 from reflex.components.layout.fragment import Fragment
 from reflex.components.tags import CondTag, Tag
-from reflex.utils import format
-from reflex.vars import Var
+from reflex.constants import Dirs
+from reflex.utils import format, imports
+from reflex.vars import BaseVar, Var, VarData
+
+_IS_TRUE_IMPORT = {
+    f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
+}
 
 
 class Cond(Component):
@@ -88,6 +93,28 @@ class Cond(Component):
             cond_state=f"isTrue({self.cond._var_full_name})",
         )
 
+    def _get_imports(self) -> imports.ImportDict:
+        return imports.merge_imports(
+            super()._get_imports(),
+            getattr(self.cond._var_data, "imports", {}),
+            _IS_TRUE_IMPORT,
+        )
+
+
+@overload
+def cond(condition: Any, c1: Component, c2: Any) -> Component:
+    ...
+
+
+@overload
+def cond(condition: Any, c1: Component) -> Component:
+    ...
+
+
+@overload
+def cond(condition: Any, c1: Any, c2: Any) -> Var:
+    ...
+
 
 def cond(condition: Any, c1: Any, c2: Any = None):
     """Create a conditional component or Prop.
@@ -103,8 +130,11 @@ def cond(condition: Any, c1: Any, c2: Any = None):
     Raises:
         ValueError: If the arguments are invalid.
     """
-    # Import here to avoid circular imports.
-    from reflex.vars import BaseVar, Var
+    var_datas: list[VarData | None] = [
+        VarData(  # type: ignore
+            imports=_IS_TRUE_IMPORT,
+        ),
+    ]
 
     # Convert the condition to a Var.
     cond_var = Var.create(condition)
@@ -116,16 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None):
             c2, Component
         ), "Both arguments must be components."
         return Cond.create(cond_var, c1, c2)
+    if isinstance(c1, Var):
+        var_datas.append(c1._var_data)
 
-    # Otherwise, create a conditionl Var.
+    # Otherwise, create a conditional Var.
     # Check that the second argument is valid.
     if isinstance(c2, Component):
         raise ValueError("Both arguments must be props.")
     if c2 is None:
         raise ValueError("For conditional vars, the second argument must be set.")
+    if isinstance(c2, Var):
+        var_datas.append(c2._var_data)
 
     # Create the conditional var.
-    return BaseVar(
+    return cond_var._replace(
         _var_name=format.format_cond(
             cond=cond_var._var_full_name,
             true_value=c1,
@@ -133,4 +167,7 @@ def cond(condition: Any, c1: Any, c2: Any = None):
             is_prop=True,
         ),
         _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
+        _var_is_local=False,
+        _var_full_name_needs_state_prefix=False,
+        merge_var_data=VarData.merge(*var_datas),
     )

+ 3 - 3
reflex/components/layout/html.py

@@ -1,8 +1,8 @@
 """A html component."""
-
-from typing import Any
+from typing import Dict
 
 from reflex.components.layout.box import Box
+from reflex.vars import Var
 
 
 class Html(Box):
@@ -13,7 +13,7 @@ class Html(Box):
     """
 
     # The HTML to render.
-    dangerouslySetInnerHTML: Any
+    dangerouslySetInnerHTML: Var[Dict[str, str]]
 
     @classmethod
     def create(cls, *children, **props):

+ 5 - 2
reflex/components/layout/html.pyi

@@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
-from typing import Any
+from typing import Dict
 from reflex.components.layout.box import Box
+from reflex.vars import Var
 
 class Html(Box):
     @overload
@@ -16,7 +17,9 @@ class Html(Box):
     def create(  # type: ignore
         cls,
         *children,
-        dangerouslySetInnerHTML: Optional[Any] = None,
+        dangerouslySetInnerHTML: Optional[
+            Union[Var[Dict[str, str]], Dict[str, str]]
+        ] = None,
         element: Optional[Union[Var[str], str]] = None,
         src: Optional[Union[Var[str], str]] = None,
         alt: Optional[Union[Var[str], str]] = None,

+ 10 - 10
reflex/components/libs/chakra.py

@@ -6,7 +6,7 @@ from typing import List, Literal
 
 from reflex.components.component import Component
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 
 class ChakraComponent(Component):
@@ -34,7 +34,7 @@ class ChakraComponent(Component):
             The dependencies imports of the component.
         """
         return {
-            dep: [ImportVar(tag=None, render=False)]
+            dep: [imports.ImportVar(tag=None, render=False)]
             for dep in [
                 "@chakra-ui/system@2.5.7",
                 "framer-motion@10.16.4",
@@ -75,17 +75,17 @@ class ChakraProvider(ChakraComponent):
         )
 
     def _get_imports(self) -> imports.ImportDict:
-        imports = super()._get_imports()
-        imports.setdefault(self.__fields__["library"].default, []).append(
-            ImportVar(tag="extendTheme", is_default=False),
+        _imports = super()._get_imports()
+        _imports.setdefault(self.__fields__["library"].default, []).append(
+            imports.ImportVar(tag="extendTheme", is_default=False),
         )
-        imports.setdefault("/utils/theme.js", []).append(
-            ImportVar(tag="theme", is_default=True),
+        _imports.setdefault("/utils/theme.js", []).append(
+            imports.ImportVar(tag="theme", is_default=True),
         )
-        imports.setdefault(Global.__fields__["library"].default, []).append(
-            ImportVar(tag="css", is_default=False),
+        _imports.setdefault(Global.__fields__["library"].default, []).append(
+            imports.ImportVar(tag="css", is_default=False),
         )
-        return imports
+        return _imports
 
     def _get_custom_code(self) -> str | None:
         return """

+ 1 - 1
reflex/components/libs/chakra.pyi

@@ -11,7 +11,7 @@ from functools import lru_cache
 from typing import List, Literal
 from reflex.components.component import Component
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 class ChakraComponent(Component):
     @overload

+ 7 - 8
reflex/components/navigation/client_side_routing.py

@@ -10,10 +10,9 @@ routeNotFound becomes true.
 from __future__ import annotations
 
 from reflex import constants
-
-from ...vars import Var
-from ..component import Component
-from ..layout.cond import Cond
+from reflex.components.component import Component
+from reflex.components.layout.cond import cond
+from reflex.vars import Var
 
 route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND)
 
@@ -52,10 +51,10 @@ def wait_for_client_redirect(component) -> Component:
     Returns:
         The conditionally rendered component.
     """
-    return Cond.create(
-        cond=route_not_found,
-        comp1=component,
-        comp2=ClientSideRouting.create(),
+    return cond(
+        condition=route_not_found,
+        c1=component,
+        c2=ClientSideRouting.create(),
     )
 
 

+ 3 - 3
reflex/components/navigation/client_side_routing.pyi

@@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex import constants
-from ...vars import Var
-from ..component import Component
-from ..layout.cond import Cond
+from reflex.components.component import Component
+from reflex.components.layout.cond import cond
+from reflex.vars import Var
 
 route_not_found: Var
 

+ 13 - 8
reflex/components/overlay/banner.py

@@ -5,22 +5,27 @@ from typing import Optional
 
 from reflex.components.base.bare import Bare
 from reflex.components.component import Component
-from reflex.components.layout import Box, Cond
+from reflex.components.layout import Box, cond
 from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
+from reflex.constants import Hooks, Imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var, VarData
+
+connect_error_var_data: VarData = VarData(  # type: ignore
+    imports=Imports.EVENTS,
+    hooks={Hooks.EVENTS},
+)
 
 connection_error: Var = Var.create_safe(
     value="(connectError !== null) ? connectError.message : ''",
     _var_is_local=False,
     _var_is_string=False,
-)
+)._replace(merge_var_data=connect_error_var_data)
 has_connection_error: Var = Var.create_safe(
     value="connectError !== null",
     _var_is_string=False,
-)
-has_connection_error._var_type = bool
+)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
 
 
 class WebsocketTargetURL(Bare):
@@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare):
 
     def _get_imports(self) -> imports.ImportDict:
         return {
-            "/utils/state.js": [ImportVar(tag="getEventURL")],
+            "/utils/state.js": [imports.ImportVar(tag="getEventURL")],
         }
 
     @classmethod
@@ -78,7 +83,7 @@ class ConnectionBanner(Component):
                 textAlign="center",
             )
 
-        return Cond.create(has_connection_error, comp)
+        return cond(has_connection_error, comp)
 
 
 class ConnectionModal(Component):
@@ -96,7 +101,7 @@ class ConnectionModal(Component):
         """
         if not comp:
             comp = Text.create(*default_connection_error())
-        return Cond.create(
+        return cond(
             has_connection_error,
             Modal.create(
                 header="Connection Error",

+ 4 - 3
reflex/components/overlay/banner.pyi

@@ -10,15 +10,16 @@ from reflex.style import Style
 from typing import Optional
 from reflex.components.base.bare import Bare
 from reflex.components.component import Component
-from reflex.components.layout import Box, Cond
+from reflex.components.layout import Box, cond
 from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
+from reflex.constants import Hooks, Imports
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var, VarData
 
+connect_error_var_data: VarData
 connection_error: Var
 has_connection_error: Var
-has_connection_error._var_type = bool
 
 class WebsocketTargetURL(Bare):
     @overload

+ 2 - 2
reflex/components/radix/themes/base.py

@@ -6,7 +6,7 @@ from typing import Literal
 
 from reflex.components import Component
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
 LiteralJustify = Literal["start", "center", "end", "between"]
@@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
     def _get_imports(self) -> imports.ImportDict:
         return {
             **super()._get_imports(),
-            "": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
+            "": [imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
         }
 
 

+ 1 - 1
reflex/components/radix/themes/base.pyi

@@ -10,7 +10,7 @@ from reflex.style import Style
 from typing import Literal
 from reflex.components import Component
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
 
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
 LiteralJustify = Literal["start", "center", "end", "between"]

+ 2 - 1
reflex/components/typography/markdown.py

@@ -14,7 +14,8 @@ from reflex.components.typography.heading import Heading
 from reflex.components.typography.text import Text
 from reflex.style import Style
 from reflex.utils import console, imports, types
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 # Special vars used in the component map.
 _CHILDREN = Var.create_safe("children", _var_is_local=False)

+ 2 - 1
reflex/components/typography/markdown.pyi

@@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading
 from reflex.components.typography.text import Text
 from reflex.style import Style
 from reflex.utils import console, imports, types
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
 
 _CHILDREN = Var.create_safe("children", _var_is_local=False)
 _PROPS = Var.create_safe("...props", _var_is_local=False)

+ 4 - 0
reflex/constants/__init__.py

@@ -22,6 +22,8 @@ from .compiler import (
     CompileVars,
     ComponentName,
     Ext,
+    Hooks,
+    Imports,
     PageNames,
 )
 from .config import (
@@ -68,7 +70,9 @@ __ALL__ = [
     Ext,
     Fnm,
     GitIgnore,
+    Hooks,
     RequirementsTxt,
+    Imports,
     IS_WINDOWS,
     LOCAL_STORAGE,
     LogLevel,

+ 2 - 0
reflex/constants/base.py

@@ -29,6 +29,8 @@ class Dirs(SimpleNamespace):
     STATE_PATH = "/".join([UTILS, "state"])
     # The name of the components file.
     COMPONENTS_PATH = "/".join([UTILS, "components"])
+    # The name of the contexts file.
+    CONTEXTS_PATH = "/".join([UTILS, "context"])
     # The directory where the app pages are compiled to.
     WEB_PAGES = os.path.join(WEB, "pages")
     # The directory where the static build is located.

+ 25 - 0
reflex/constants/compiler.py

@@ -2,6 +2,9 @@
 from enum import Enum
 from types import SimpleNamespace
 
+from reflex.constants import Dirs
+from reflex.utils.imports import ImportVar
+
 # The prefix used to create setters for state vars.
 SETTER_PREFIX = "set_"
 
@@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace):
     HYDRATE = "hydrate"
     # The name of the is_hydrated variable.
     IS_HYDRATED = "is_hydrated"
+    # The name of the function to add events to the queue.
+    ADD_EVENTS = "addEvents"
+    # The name of the var storing any connection error.
+    CONNECT_ERROR = "connectError"
+    # The name of the function for converting a dict to an event.
+    TO_EVENT = "Event"
 
 
 class PageNames(SimpleNamespace):
@@ -77,3 +86,19 @@ class ComponentName(Enum):
             The lower-case filename with zip extension.
         """
         return self.value.lower() + Ext.ZIP
+
+
+class Imports(SimpleNamespace):
+    """Common sets of import vars."""
+
+    EVENTS = {
+        "react": {ImportVar(tag="useContext")},
+        f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
+        f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
+    }
+
+
+class Hooks(SimpleNamespace):
+    """Common sets of hook declarations."""
+
+    EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"

+ 1 - 1
reflex/middleware/hydrate_middleware.py

@@ -48,7 +48,7 @@ class HydrateMiddleware(Middleware):
                     setattr(var_state, var_name, value)
 
         # Get the initial state.
-        delta = format.format_state({state.get_name(): state.dict()})
+        delta = format.format_state(state.dict())
         # since a full dict was captured, clean any dirtiness
         state._clean()
 

+ 9 - 5
reflex/state.py

@@ -1211,12 +1211,16 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             if include_computed
             else {}
         )
-        substate_vars = {
-            k: v.dict(include_computed=include_computed, **kwargs)
-            for k, v in self.substates.items()
+        variables = {**base_vars, **computed_vars}
+        d = {
+            self.get_full_name(): {k: variables[k] for k in sorted(variables)},
         }
-        variables = {**base_vars, **computed_vars, **substate_vars}
-        return {k: variables[k] for k in sorted(variables)}
+        for substate_d in [
+            v.dict(include_computed=include_computed, **kwargs)
+            for v in self.substates.values()
+        ]:
+            d.update(substate_d)
+        return d
 
     async def __aenter__(self) -> State:
         """Enter the async context manager protocol.

+ 71 - 7
reflex/style.py

@@ -2,13 +2,38 @@
 
 from __future__ import annotations
 
+from typing import Any
+
 from reflex import constants
 from reflex.event import EventChain
 from reflex.utils import format
-from reflex.vars import BaseVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import BaseVar, Var, VarData
+
+VarData.update_forward_refs()  # Ensure all type definitions are resolved
 
-color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str")
-toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain)
+# Reference the global ColorModeContext
+color_mode_var_data = VarData(  # type: ignore
+    imports={
+        f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
+        "react": {ImportVar(tag="useContext")},
+    },
+    hooks={
+        f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
+    },
+)
+# Var resolves to the current color mode for the app ("light" or "dark")
+color_mode = BaseVar(
+    _var_name=constants.ColorMode.NAME,
+    _var_type="str",
+    _var_data=color_mode_var_data,
+)
+# Var resolves to a function invocation that toggles the color mode
+toggle_color_mode = BaseVar(
+    _var_name=constants.ColorMode.TOGGLE,
+    _var_type=EventChain,
+    _var_data=color_mode_var_data,
+)
 
 
 def convert(style_dict):
@@ -20,16 +45,27 @@ def convert(style_dict):
     Returns:
         The formatted style dictionary.
     """
+    var_data = None  # Track import/hook data from any Vars in the style dict.
     out = {}
     for key, value in style_dict.items():
         key = format.to_camel_case(key)
+        new_var_data = None
         if isinstance(value, dict):
-            out[key] = convert(value)
+            # Recursively format nested style dictionaries.
+            out[key], new_var_data = convert(value)
         elif isinstance(value, Var):
+            # If the value is a Var, extract the var_data and cast as str.
+            new_var_data = value._var_data
             out[key] = str(value)
         else:
+            # Otherwise, convert to Var to collapse VarData encoded in f-string.
+            new_var = Var.create(value)
+            if new_var is not None:
+                new_var_data = new_var._var_data
             out[key] = value
-    return out
+        # Combine all the collected VarData instances.
+        var_data = VarData.merge(var_data, new_var_data)
+    return out, var_data
 
 
 class Style(dict):
@@ -41,5 +77,33 @@ class Style(dict):
         Args:
             style_dict: The style dictionary.
         """
-        style_dict = style_dict or {}
-        super().__init__(convert(style_dict))
+        style_dict, self._var_data = convert(style_dict or {})
+        super().__init__(style_dict)
+
+    def update(self, style_dict: dict | None, **kwargs):
+        """Update the style.
+
+        Args:
+            style_dict: The style dictionary.
+            kwargs: Other key value pairs to apply to the dict update.
+        """
+        if kwargs:
+            style_dict = {**(style_dict or {}), **kwargs}
+        converted_dict = type(self)(style_dict)
+        # Combine our VarData with that of any Vars in the style_dict that was passed.
+        self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
+        super().update(converted_dict)
+
+    def __setitem__(self, key: str, value: Any):
+        """Set an item in the style.
+
+        Args:
+            key: The key to set.
+            value: The value to set.
+        """
+        # Create a Var to collapse VarData encoded in f-string.
+        _var = Var.create(value)
+        if _var is not None:
+            # Carry the imports/hooks when setting a Var as a value.
+            self._var_data = VarData.merge(self._var_data, _var._var_data)
+        super().__setitem__(key, value)

+ 19 - 6
reflex/utils/format.py

@@ -232,9 +232,9 @@ def format_route(route: str, format_case=True) -> str:
 
 
 def format_cond(
-    cond: str,
-    true_value: str,
-    false_value: str = '""',
+    cond: str | Var,
+    true_value: str | Var,
+    false_value: str | Var = '""',
     is_prop=False,
 ) -> str:
     """Format a conditional expression.
@@ -248,9 +248,6 @@ def format_cond(
     Returns:
         The formatted conditional expression.
     """
-    # Import here to avoid circular imports.
-    from reflex.vars import Var
-
     # Use Python truthiness.
     cond = f"isTrue({cond})"
 
@@ -266,6 +263,7 @@ def format_cond(
             _var_is_string=type(false_value) is str,
         )
         prop2._var_is_local = True
+        prop1, prop2 = str(prop1), str(prop2)  # avoid f-string semantics for Var
         return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
 
     # Format component conds.
@@ -517,6 +515,21 @@ def format_state(value: Any) -> Any:
     raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.")
 
 
+def format_state_name(state_name: str) -> str:
+    """Format a state name, replacing dots with double underscore.
+
+    This allows individual substates to be accessed independently as javascript vars
+    without using dot notation.
+
+    Args:
+        state_name: The state name to format.
+
+    Returns:
+        The formatted state name.
+    """
+    return state_name.replace(".", "__")
+
+
 def format_ref(ref: str) -> str:
     """Format a ref.
 

+ 41 - 4
reflex/utils/imports.py

@@ -3,11 +3,9 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import Dict, List
+from typing import Dict, List, Optional
 
-from reflex.vars import ImportVar
-
-ImportDict = Dict[str, List[ImportVar]]
+from reflex.base import Base
 
 
 def merge_imports(*imports) -> ImportDict:
@@ -24,3 +22,42 @@ def merge_imports(*imports) -> ImportDict:
         for lib, fields in import_dict.items():
             all_imports[lib].extend(fields)
     return all_imports
+
+
+class ImportVar(Base):
+    """An import var."""
+
+    # The name of the import tag.
+    tag: Optional[str]
+
+    # whether the import is default or named.
+    is_default: Optional[bool] = False
+
+    # The tag alias.
+    alias: Optional[str] = None
+
+    # Whether this import need to install the associated lib
+    install: Optional[bool] = True
+
+    # whether this import should be rendered or not
+    render: Optional[bool] = True
+
+    @property
+    def name(self) -> str:
+        """The name of the import.
+
+        Returns:
+            The name(tag name with alias) of tag.
+        """
+        return self.tag if not self.alias else " as ".join([self.tag, self.alias])  # type: ignore
+
+    def __hash__(self) -> int:
+        """Define a hash function for the import var.
+
+        Returns:
+            The hash of the var.
+        """
+        return hash((self.tag, self.is_default, self.alias, self.install, self.render))
+
+
+ImportDict = Dict[str, List[ImportVar]]

+ 1 - 0
reflex/utils/types.py

@@ -27,6 +27,7 @@ from reflex.utils import serializers
 GenericType = Union[Type, _GenericAlias]
 
 # Valid state var types.
+JSONType = {str, int, float, bool}
 PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
 StateVar = Union[PrimitiveType, Base, None]
 StateIterVar = Union[list, set, tuple]

+ 298 - 120
reflex/vars.py

@@ -7,6 +7,7 @@ import dis
 import inspect
 import json
 import random
+import re
 import string
 import sys
 from types import CodeType, FunctionType
@@ -15,9 +16,11 @@ from typing import (
     Any,
     Callable,
     Dict,
+    Iterable,
     List,
     Literal,
     Optional,
+    Set,
     Tuple,
     Type,
     Union,
@@ -30,7 +33,10 @@ from typing import (
 
 from reflex import constants
 from reflex.base import Base
-from reflex.utils import console, format, serializers, types
+from reflex.utils import console, format, imports, serializers, types
+
+# This module used to export ImportVar itself, so we still import it for export here
+from reflex.utils.imports import ImportDict, ImportVar
 
 if TYPE_CHECKING:
     from reflex.state import State
@@ -71,7 +77,7 @@ OPERATION_MAPPING = {
 REPLACED_NAMES = {
     "full_name": "_var_full_name",
     "name": "_var_name",
-    "state": "_var_state",
+    "state": "_var_data.state",
     "type_": "_var_type",
     "is_local": "_var_is_local",
     "is_string": "_var_is_string",
@@ -93,6 +99,131 @@ def get_unique_variable_name() -> str:
     return get_unique_variable_name()
 
 
+class VarData(Base):
+    """Metadata associated with a Var."""
+
+    # The name of the enclosing state.
+    state: str = ""
+
+    # Imports needed to render this var
+    imports: ImportDict = {}
+
+    # Hooks that need to be present in the component to render this var
+    hooks: Set[str] = set()
+
+    @classmethod
+    def merge(cls, *others: VarData | None) -> VarData | None:
+        """Merge multiple var data objects.
+
+        Args:
+            *others: The var data objects to merge.
+
+        Returns:
+            The merged var data object.
+        """
+        state = ""
+        _imports = {}
+        hooks = set()
+        for var_data in others:
+            if var_data is None:
+                continue
+            state = state or var_data.state
+            _imports = imports.merge_imports(_imports, var_data.imports)
+            hooks.update(var_data.hooks)
+        return (
+            cls(
+                state=state,
+                imports=_imports,
+                hooks=hooks,
+            )
+            or None
+        )
+
+    def __bool__(self) -> bool:
+        """Check if the var data is non-empty.
+
+        Returns:
+            True if any field is set to a non-default value.
+        """
+        return bool(self.state or self.imports or self.hooks)
+
+    def dict(self) -> dict:
+        """Convert the var data to a dictionary.
+
+        Returns:
+            The var data dictionary.
+        """
+        return {
+            "state": self.state,
+            "imports": {
+                lib: [import_var.dict() for import_var in import_vars]
+                for lib, import_vars in self.imports.items()
+            },
+            "hooks": list(self.hooks),
+        }
+
+
+def _encode_var(value: Var) -> str:
+    """Encode the state name into a formatted var.
+
+    Args:
+        value: The value to encode the state name into.
+
+    Returns:
+        The encoded var.
+    """
+    if value._var_data:
+        return f"<reflex.Var>{value._var_data.json()}</reflex.Var>" + str(value)
+    return str(value)
+
+
+def _decode_var(value: str) -> tuple[VarData | None, str]:
+    """Decode the state name from a formatted var.
+
+    Args:
+        value: The value to extract the state name from.
+
+    Returns:
+        The extracted state name and the value without the state name.
+    """
+    var_datas = []
+    if isinstance(value, str):
+        # Extract the state name from a formatted var
+        while m := re.match(r"(.*)<reflex.Var>(.*)</reflex.Var>(.*)", value):
+            value = m.group(1) + m.group(3)
+            var_datas.append(VarData.parse_raw(m.group(2)))
+    if var_datas:
+        return VarData.merge(*var_datas), value
+    return None, value
+
+
+def _extract_var_data(value: Iterable) -> list[VarData | None]:
+    """Extract the var imports and hooks from an iterable containing a Var.
+
+    Args:
+        value: The iterable to extract the VarData from
+
+    Returns:
+        The extracted VarDatas.
+    """
+    var_datas = []
+    with contextlib.suppress(TypeError):
+        for sub in value:
+            if isinstance(sub, Var):
+                var_datas.append(sub._var_data)
+            elif not isinstance(sub, str):
+                # Recurse into dict values.
+                if hasattr(sub, "values") and callable(sub.values):
+                    var_datas.extend(_extract_var_data(sub.values()))
+                # Recurse into iterable values (or dict keys).
+                var_datas.extend(_extract_var_data(sub))
+    # Recurse when value is a dict itself.
+    values = getattr(value, "values", None)
+    if callable(values):
+        var_datas.extend(_extract_var_data(values()))
+    return var_datas
+
+
 class Var:
     """An abstract var."""
 
@@ -102,15 +233,18 @@ class Var:
     # The type of the var.
     _var_type: Type
 
-    # The name of the enclosing state.
-    _var_state: str
-
     # Whether this is a local javascript variable.
     _var_is_local: bool
 
     # Whether the var is a string literal.
     _var_is_string: bool
 
+    # _var_full_name should be prefixed with _var_state
+    _var_full_name_needs_state_prefix: bool
+
+    # Extra metadata associated with the Var
+    _var_data: Optional[VarData]
+
     @classmethod
     def create(
         cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
@@ -136,9 +270,14 @@ class Var:
         if isinstance(value, Var):
             return value
 
+        # Try to pull the imports and hooks from contained values.
+        _var_data = None
+        if not isinstance(value, str):
+            _var_data = VarData.merge(*_extract_var_data(value))
+
         # Try to serialize the value.
         type_ = type(value)
-        name = serializers.serialize(value)
+        name = value if type_ in types.JSONType else serializers.serialize(value)
         if name is None:
             raise TypeError(
                 f"No JSON serializer found for var {value} of type {type_}."
@@ -150,6 +289,7 @@ class Var:
             _var_type=type_,
             _var_is_local=_var_is_local,
             _var_is_string=_var_is_string,
+            _var_data=_var_data,
         )
 
     @classmethod
@@ -186,6 +326,39 @@ class Var:
         """
         return _GenericAlias(cls, type_)
 
+    def __post_init__(self) -> None:
+        """Post-initialize the var."""
+        # Decode any inline Var markup and apply it to the instance
+        _var_data, _var_name = _decode_var(self._var_name)
+        if _var_data:
+            self._var_name = _var_name
+            self._var_data = VarData.merge(self._var_data, _var_data)
+
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
+        """Make a copy of this Var with updated fields.
+
+        Args:
+            merge_var_data: VarData to merge into the existing VarData.
+            **kwargs: Var fields to update.
+
+        Returns:
+            A new BaseVar with the updated fields overwriting the corresponding fields in this Var.
+        """
+        field_values = dict(
+            _var_name=kwargs.pop("_var_name", self._var_name),
+            _var_type=kwargs.pop("_var_type", self._var_type),
+            _var_is_local=kwargs.pop("_var_is_local", self._var_is_local),
+            _var_is_string=kwargs.pop("_var_is_string", self._var_is_string),
+            _var_full_name_needs_state_prefix=kwargs.pop(
+                "_var_full_name_needs_state_prefix",
+                self._var_full_name_needs_state_prefix,
+            ),
+            _var_data=VarData.merge(
+                kwargs.get("_var_data", self._var_data), merge_var_data
+            ),
+        )
+        return BaseVar(**field_values)
+
     def _decode(self) -> Any:
         """Decode Var as a python value.
 
@@ -195,8 +368,6 @@ class Var:
         Returns:
             The decoded value or the Var name.
         """
-        if self._var_state:
-            return self._var_full_name
         if self._var_is_string:
             return self._var_name
         try:
@@ -216,8 +387,10 @@ class Var:
         return (
             self._var_name == other._var_name
             and self._var_type == other._var_type
-            and self._var_state == other._var_state
             and self._var_is_local == other._var_is_local
+            and self._var_full_name_needs_state_prefix
+            == other._var_full_name_needs_state_prefix
+            and self._var_data == other._var_data
         )
 
     def to_string(self, json: bool = True) -> Var:
@@ -285,9 +458,11 @@ class Var:
         Returns:
             The formatted var.
         """
+        # Encode the _var_data into the formatted output for tracking purposes.
+        str_self = _encode_var(self)
         if self._var_is_local:
-            return str(self)
-        return f"${str(self)}"
+            return str_self
+        return f"${str_self}"
 
     def __getitem__(self, i: Any) -> Var:
         """Index into a var.
@@ -320,12 +495,7 @@ class Var:
 
         # Convert any vars to local vars.
         if isinstance(i, Var):
-            i = BaseVar(
-                _var_name=i._var_name,
-                _var_type=i._var_type,
-                _var_state=i._var_state,
-                _var_is_local=True,
-            )
+            i = i._replace(_var_is_local=True)
 
         # Handle list/tuple/str indexing.
         if types._issubclass(self._var_type, Union[List, Tuple, str]):
@@ -344,11 +514,9 @@ class Var:
                 stop = i.stop or "undefined"
 
                 # Use the slice function.
-                return BaseVar(
+                return self._replace(
                     _var_name=f"{self._var_name}.slice({start}, {stop})",
-                    _var_type=self._var_type,
-                    _var_state=self._var_state,
-                    _var_is_local=self._var_is_local,
+                    _var_is_string=False,
                 )
 
             # Get the type of the indexed var.
@@ -359,11 +527,10 @@ class Var:
             )
 
             # Use `at` to support negative indices.
-            return BaseVar(
+            return self._replace(
                 _var_name=f"{self._var_name}.at({i})",
                 _var_type=type_,
-                _var_state=self._var_state,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
             )
 
         # Dictionary / dataframe indexing.
@@ -393,11 +560,10 @@ class Var:
         )
 
         # Use normal indexing here.
-        return BaseVar(
+        return self._replace(
             _var_name=f"{self._var_name}[{i}]",
             _var_type=type_,
-            _var_state=self._var_state,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
         )
 
     def __getattr__(self, name: str) -> Var:
@@ -423,11 +589,10 @@ class Var:
             type_ = types.get_attribute_access_type(self._var_type, name)
 
             if type_ is not None:
-                return BaseVar(
+                return self._replace(
                     _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
                     _var_type=type_,
-                    _var_state=self._var_state,
-                    _var_is_local=self._var_is_local,
+                    _var_is_string=False,
                 )
 
             if name in REPLACED_NAMES:
@@ -519,10 +684,12 @@ class Var:
                     else f"{self._var_full_name}.{fn}()"
                 )
 
-        return BaseVar(
+        return self._replace(
             _var_name=operation_name,
             _var_type=type_,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
+            _var_full_name_needs_state_prefix=False,
+            merge_var_data=other._var_data if other is not None else None,
         )
 
     @staticmethod
@@ -602,10 +769,10 @@ class Var:
         """
         if not types._issubclass(self._var_type, List):
             raise TypeError(f"Cannot get length of non-list var {self}.")
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.length",
+        return self._replace(
+            _var_name=f"{self._var_name}.length",
             _var_type=int,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
         )
 
     def __eq__(self, other: Var) -> Var:
@@ -692,7 +859,17 @@ class Var:
             types.get_base_class(self._var_type) == list
             and types.get_base_class(other_type) == list
         ):
-            return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip)
+            return self.operation(
+                ",", other, fn="spreadArraysOrObjects", flip=flip
+            )._replace(
+                merge_var_data=VarData(
+                    imports={
+                        f"/{constants.Dirs.STATE_PATH}": [
+                            ImportVar(tag="spreadArraysOrObjects")
+                        ]
+                    },
+                ),
+            )
         return self.operation("+", other, flip=flip)
 
     def __radd__(self, other: Var) -> Var:
@@ -755,10 +932,11 @@ class Var:
         ]:
             other_name = other._var_full_name if isinstance(other, Var) else other
             name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()"
-            return BaseVar(
+            return self._replace(
                 _var_name=name,
                 _var_type=str,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
+                _var_full_name_needs_state_prefix=False,
             )
 
         return self.operation("*", other)
@@ -1003,10 +1181,11 @@ class Var:
         elif not isinstance(other, Var):
             other = Var.create(other)
         if types._issubclass(self._var_type, Dict):
-            return BaseVar(
-                _var_name=f"{self._var_full_name}.{method}({other._var_full_name})",
+            return self._replace(
+                _var_name=f"{self._var_name}.{method}({other._var_full_name})",
                 _var_type=bool,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
+                merge_var_data=other._var_data,
             )
         else:  # str, list, tuple
             # For strings, the left operand must be a string.
@@ -1016,10 +1195,11 @@ class Var:
                 raise TypeError(
                     f"'in <string>' requires string as left operand, not {other._var_type}"
                 )
-            return BaseVar(
-                _var_name=f"{self._var_full_name}.includes({other._var_full_name})",
+            return self._replace(
+                _var_name=f"{self._var_name}.includes({other._var_full_name})",
                 _var_type=bool,
-                _var_is_local=self._var_is_local,
+                _var_is_string=False,
+                merge_var_data=other._var_data,
             )
 
     def reverse(self) -> Var:
@@ -1034,10 +1214,10 @@ class Var:
         if not types._issubclass(self._var_type, list):
             raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.")
 
-        return BaseVar(
+        return self._replace(
             _var_name=f"[...{self._var_full_name}].reverse()",
-            _var_type=self._var_type,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
+            _var_full_name_needs_state_prefix=False,
         )
 
     def lower(self) -> Var:
@@ -1054,10 +1234,10 @@ class Var:
                 f"Cannot convert non-string var {self._var_full_name} to lowercase."
             )
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.toLowerCase()",
+        return self._replace(
+            _var_name=f"{self._var_name}.toLowerCase()",
+            _var_is_string=False,
             _var_type=str,
-            _var_is_local=self._var_is_local,
         )
 
     def upper(self) -> Var:
@@ -1074,10 +1254,10 @@ class Var:
                 f"Cannot convert non-string var {self._var_full_name} to uppercase."
             )
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.toUpperCase()",
+        return self._replace(
+            _var_name=f"{self._var_name}.toUpperCase()",
+            _var_is_string=False,
             _var_type=str,
-            _var_is_local=self._var_is_local,
         )
 
     def split(self, other: str | Var[str] = " ") -> Var:
@@ -1097,10 +1277,11 @@ class Var:
 
         other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.split({other._var_full_name})",
+        return self._replace(
+            _var_name=f"{self._var_name}.split({other._var_full_name})",
+            _var_is_string=False,
             _var_type=list[str],
-            _var_is_local=self._var_is_local,
+            merge_var_data=other._var_data,
         )
 
     def join(self, other: str | Var[str] | None = None) -> Var:
@@ -1125,10 +1306,11 @@ class Var:
         else:
             other = Var.create_safe(other)
 
-        return BaseVar(
-            _var_name=f"{self._var_full_name}.join({other._var_full_name})",
+        return self._replace(
+            _var_name=f"{self._var_name}.join({other._var_full_name})",
+            _var_is_string=False,
             _var_type=str,
-            _var_is_local=self._var_is_local,
+            merge_var_data=other._var_data,
         )
 
     def foreach(self, fn: Callable) -> Var:
@@ -1159,10 +1341,9 @@ class Var:
         fn_signature = inspect.signature(fn)
         fn_args = (arg, index)
         fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
-        return BaseVar(
+        return self._replace(
             _var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
-            _var_type=self._var_type,
-            _var_is_local=self._var_is_local,
+            _var_is_string=False,
         )
 
     @classmethod
@@ -1207,6 +1388,18 @@ class Var:
             _var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
             _var_type=list[int],
             _var_is_local=False,
+            _var_data=VarData.merge(
+                v1._var_data,
+                v2._var_data,
+                step._var_data,
+                VarData(
+                    imports={
+                        "/utils/helpers/range.js": [
+                            ImportVar(tag="range", is_default=True),
+                        ],
+                    },
+                ),
+            ),
         )
 
     def to(self, type_: Type) -> Var:
@@ -1218,12 +1411,7 @@ class Var:
         Returns:
             The converted var.
         """
-        return BaseVar(
-            _var_name=self._var_name,
-            _var_type=type_,
-            _var_state=self._var_state,
-            _var_is_local=self._var_is_local,
-        )
+        return self._replace(_var_type=type_)
 
     @property
     def _var_full_name(self) -> str:
@@ -1232,24 +1420,51 @@ class Var:
         Returns:
             The full name of the var.
         """
+        if not self._var_full_name_needs_state_prefix:
+            return self._var_name
         return (
             self._var_name
-            if self._var_state == ""
-            else ".".join([self._var_state, self._var_name])
+            if self._var_data is None or self._var_data.state == ""
+            else ".".join(
+                [format.format_state_name(self._var_data.state), self._var_name]
+            )
         )
 
-    def _var_set_state(self, state: Type[State]) -> Any:
+    def _var_set_state(self, state: Type[State] | str) -> Any:
         """Set the state of the var.
 
         Args:
-            state: The state to set.
+            state: The state to set or the full name of the state.
 
         Returns:
             The var with the set state.
         """
-        self._var_state = state.get_full_name()
+        state_name = state if isinstance(state, str) else state.get_full_name()
+        new_var_data = VarData(
+            state=state_name,
+            hooks={
+                "const {0} = useContext(StateContexts.{0})".format(
+                    format.format_state_name(state_name)
+                )
+            },
+            imports={
+                f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
+                "react": [ImportVar(tag="useContext")],
+            },
+        )
+        self._var_data = VarData.merge(self._var_data, new_var_data)
+        self._var_full_name_needs_state_prefix = True
         return self
 
+    @property
+    def _var_state(self) -> str:
+        """Compat method for getting the state.
+
+        Returns:
+            The state name associated with the var.
+        """
+        return self._var_data.state if self._var_data else ""
+
 
 @dataclasses.dataclass(
     eq=False,
@@ -1264,15 +1479,18 @@ class BaseVar(Var):
     # The type of the var.
     _var_type: Type = dataclasses.field(default=Any)
 
-    # The name of the enclosing state.
-    _var_state: str = dataclasses.field(default="")
-
     # Whether this is a local javascript variable.
     _var_is_local: bool = dataclasses.field(default=False)
 
     # Whether the var is a string literal.
     _var_is_string: bool = dataclasses.field(default=False)
 
+    # _var_full_name should be prefixed with _var_state
+    _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False)
+
+    # Extra metadata associated with the Var
+    _var_data: Optional[VarData] = dataclasses.field(default=None)
+
     def __hash__(self) -> int:
         """Define a hash function for a var.
 
@@ -1334,9 +1552,11 @@ class BaseVar(Var):
             The name of the setter function.
         """
         setter = constants.SETTER_PREFIX + self._var_name
-        if not include_state or self._var_state == "":
+        if self._var_data is None:
+            return setter
+        if not include_state or self._var_data.state == "":
             return setter
-        return ".".join((self._var_state, setter))
+        return ".".join((self._var_data.state, setter))
 
     def get_setter(self) -> Callable[[State, Any], None]:
         """Get the var's setter function.
@@ -1550,48 +1770,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
     return cvar
 
 
-class ImportVar(Base):
-    """An import var."""
-
-    # The name of the import tag.
-    tag: Optional[str]
-
-    # whether the import is default or named.
-    is_default: Optional[bool] = False
-
-    # The tag alias.
-    alias: Optional[str] = None
-
-    # Whether this import need to install the associated lib
-    install: Optional[bool] = True
-
-    # whether this import should be rendered or not
-    render: Optional[bool] = True
-
-    @property
-    def name(self) -> str:
-        """The name of the import.
-
-        Returns:
-            The name(tag name with alias) of tag.
-        """
-        return self.tag if not self.alias else " as ".join([self.tag, self.alias])  # type: ignore
-
-    def __hash__(self) -> int:
-        """Define a hash function for the import var.
-
-        Returns:
-            The hash of the var.
-        """
-        return hash((self.tag, self.is_default, self.alias, self.install, self.render))
-
-
-class NoRenderImportVar(ImportVar):
-    """A import that doesn't need to be rendered."""
-
-    render: Optional[bool] = False
-
-
 class CallableVar(BaseVar):
     """Decorate a Var-returning function to act as both a Var and a function.
 

+ 19 - 19
reflex/vars.pyi

@@ -6,11 +6,13 @@ from reflex import constants as constants
 from reflex.base import Base as Base
 from reflex.state import State as State
 from reflex.utils import console as console, format as format, types as types
+from reflex.utils.imports import ImportVar
 from types import FunctionType
 from typing import (
     Any,
     Callable,
     Dict,
+    Iterable,
     List,
     Optional,
     Set,
@@ -22,13 +24,24 @@ from typing import (
 USED_VARIABLES: Incomplete
 
 def get_unique_variable_name() -> str: ...
+def _encode_var(value: Var) -> str: ...
+def _decode_var(value: str) -> tuple[VarData, str]: ...
+def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
+
+class VarData(Base):
+    state: str
+    imports: dict[str, set[ImportVar]]
+    hooks: set[str]
+    @classmethod
+    def merge(cls, *others: VarData | None) -> VarData | None: ...
 
 class Var:
     _var_name: str
     _var_type: Type
-    _var_state: str = ""
     _var_is_local: bool = False
     _var_is_string: bool = False
+    _var_full_name_needs_state_prefix: bool = False
+    _var_data: VarData | None = None
     @classmethod
     def create(
         cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
@@ -38,7 +51,8 @@ class Var:
         cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
     ) -> Var: ...
     @classmethod
-    def __class_getitem__(cls, type_: str) -> _GenericAlias: ...
+    def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
     def equals(self, other: Var) -> bool: ...
     def to_string(self) -> Var: ...
     def __hash__(self) -> int: ...
@@ -95,15 +109,16 @@ class Var:
     def to(self, type_: Type) -> Var: ...
     @property
     def _var_full_name(self) -> str: ...
-    def _var_set_state(self, state: Type[State]) -> Any: ...
+    def _var_set_state(self, state: Type[State] | str) -> Any: ...
 
 @dataclass(eq=False)
 class BaseVar(Var):
     _var_name: str
     _var_type: Any
-    _var_state: str = ""
     _var_is_local: bool = False
     _var_is_string: bool = False
+    _var_full_name_needs_state_prefix: bool = False
+    _var_data: VarData | None = None
     def __hash__(self) -> int: ...
     def get_default_value(self) -> Any: ...
     def get_setter_name(self, include_state: bool = ...) -> str: ...
@@ -123,21 +138,6 @@ class ComputedVar(Var):
 
 def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
 
-class ImportVar(Base):
-    tag: Optional[str]
-    is_default: Optional[bool] = False
-    alias: Optional[str] = None
-    install: Optional[bool] = True
-    render: Optional[bool] = True
-    @property
-    def name(self) -> str: ...
-    def __hash__(self) -> int: ...
-
-class NoRenderImportVar(ImportVar):
-    """A import that doesn't need to be rendered."""
-
-def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
-
 class CallableVar(BaseVar):
     def __init__(self, fn: Callable[..., BaseVar]): ...
     def __call__(self, *args, **kwargs) -> BaseVar: ...

+ 1 - 1
tests/compiler/test_compiler.py

@@ -5,7 +5,7 @@ import pytest
 
 from reflex.compiler import compiler, utils
 from reflex.utils import imports
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportVar
 
 
 @pytest.mark.parametrize(

+ 1 - 1
tests/components/layout/test_cond.py

@@ -110,7 +110,7 @@ def test_cond_no_else():
 
     # Props do not support the use of cond without else
     with pytest.raises(ValueError):
-        cond(True, "hello")
+        cond(True, "hello")  # type: ignore
 
 
 def test_mobile_only():

+ 162 - 2
tests/components/test_component.py

@@ -4,14 +4,16 @@ import pytest
 
 import reflex as rx
 from reflex.base import Base
+from reflex.components.base.bare import Bare
 from reflex.components.component import Component, CustomComponent, custom_component
 from reflex.components.layout.box import Box
 from reflex.constants import EventTriggers
-from reflex.event import EventHandler
+from reflex.event import EventChain, EventHandler
 from reflex.state import State
 from reflex.style import Style
 from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var, VarData
 
 
 @pytest.fixture
@@ -600,3 +602,161 @@ def test_format_component(component, rendered):
         rendered: The expected rendered component.
     """
     assert str(component) == rendered
+
+
+TEST_VAR = Var.create_safe("test")._replace(
+    merge_var_data=VarData(
+        hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
+    )
+)
+FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
+STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False)
+EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
+ARG_VAR = Var.create("arg")
+
+
+class EventState(rx.State):
+    """State for testing event handlers with _get_vars."""
+
+    v: int = 42
+
+    def handler(self):
+        """A handler that does nothing."""
+
+    def handler2(self, arg):
+        """A handler that takes an arg.
+
+        Args:
+            arg: An arg.
+        """
+
+
+@pytest.mark.parametrize(
+    ("component", "exp_vars"),
+    (
+        pytest.param(
+            Bare.create(TEST_VAR),
+            [TEST_VAR],
+            id="direct-bare",
+        ),
+        pytest.param(
+            Bare.create(f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-bare",
+        ),
+        pytest.param(
+            rx.text(as_=TEST_VAR),
+            [TEST_VAR],
+            id="direct-prop",
+        ),
+        pytest.param(
+            rx.text(as_=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-prop",
+        ),
+        pytest.param(
+            rx.fragment(id=TEST_VAR),
+            [TEST_VAR],
+            id="direct-id",
+        ),
+        pytest.param(
+            rx.fragment(id=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-id",
+        ),
+        pytest.param(
+            rx.fragment(key=TEST_VAR),
+            [TEST_VAR],
+            id="direct-key",
+        ),
+        pytest.param(
+            rx.fragment(key=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-key",
+        ),
+        pytest.param(
+            rx.fragment(class_name=TEST_VAR),
+            [TEST_VAR],
+            id="direct-class_name",
+        ),
+        pytest.param(
+            rx.fragment(class_name=f"foo{TEST_VAR}bar"),
+            [FORMATTED_TEST_VAR],
+            id="fstring-class_name",
+        ),
+        pytest.param(
+            rx.fragment(special_props={TEST_VAR}),
+            [TEST_VAR],
+            id="direct-special_props",
+        ),
+        pytest.param(
+            rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}),
+            [FORMATTED_TEST_VAR],
+            id="fstring-special_props",
+        ),
+        pytest.param(
+            # custom_attrs cannot accept a Var directly as a value
+            rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}),
+            [TEST_VAR],
+            id="fstring-custom_attrs-nofmt",
+        ),
+        pytest.param(
+            rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}),
+            [FORMATTED_TEST_VAR],
+            id="fstring-custom_attrs",
+        ),
+        pytest.param(
+            rx.fragment(background_color=TEST_VAR),
+            [STYLE_VAR],
+            id="direct-background_color",
+        ),
+        pytest.param(
+            rx.fragment(background_color=f"foo{TEST_VAR}bar"),
+            [STYLE_VAR],
+            id="fstring-background_color",
+        ),
+        pytest.param(
+            rx.fragment(style={"background_color": TEST_VAR}),  # type: ignore
+            [STYLE_VAR],
+            id="direct-style-background_color",
+        ),
+        pytest.param(
+            rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}),  # type: ignore
+            [STYLE_VAR],
+            id="fstring-style-background_color",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EVENT_CHAIN_VAR),  # type: ignore
+            [EVENT_CHAIN_VAR],
+            id="direct-event-chain",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EventState.handler),
+            [],
+            id="direct-event-handler",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EventState.handler2(TEST_VAR)),  # type: ignore
+            [ARG_VAR, TEST_VAR],
+            id="direct-event-handler-arg",
+        ),
+        pytest.param(
+            rx.fragment(on_click=EventState.handler2(EventState.v)),  # type: ignore
+            [ARG_VAR, EventState.v],
+            id="direct-event-handler-arg2",
+        ),
+        pytest.param(
+            rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)),  # type: ignore
+            [ARG_VAR, TEST_VAR],
+            id="direct-event-handler-lambda",
+        ),
+    ),
+)
+def test_get_vars(component, exp_vars):
+    comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name)
+    assert len(comp_vars) == len(exp_vars)
+    for comp_var, exp_var in zip(
+        comp_vars,
+        sorted(exp_vars, key=lambda v: v._var_name),
+    ):
+        assert comp_var.equals(exp_var)

+ 3 - 3
tests/middleware/test_hydrate_middleware.py

@@ -104,7 +104,7 @@ async def test_preprocess(
         app=app, event=request.getfixturevalue(event_fixture), state=state
     )
     assert isinstance(update, StateUpdate)
-    assert update.delta == {state.get_name(): state.dict()}
+    assert update.delta == state.dict()
     events = update.events
     assert len(events) == 2
 
@@ -133,7 +133,7 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
 
     update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
     assert isinstance(update, StateUpdate)
-    assert update.delta == {"test_state": state.dict()}
+    assert update.delta == state.dict()
     assert len(update.events) == 3
 
     # Apply the events.
@@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
         state=state,
     )
     assert isinstance(update, StateUpdate)
-    assert update.delta == {"test_state": state.dict()}
+    assert update.delta == state.dict()
     assert len(update.events) == 1
     assert isinstance(update, StateUpdate)
 

+ 1 - 3
tests/test_app.py

@@ -769,9 +769,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
         )
 
     current_state = await app.state_manager.get_state(token)
-    state_dict = current_state.dict()
-    for substate in state.get_full_name().split(".")[1:]:
-        state_dict = state_dict[substate]
+    state_dict = current_state.dict()[state.get_full_name()]
     assert state_dict["img_list"] == [
         "image1.jpg",
         "image2.jpg",

+ 24 - 15
tests/test_state.py

@@ -324,11 +324,17 @@ def test_dict(test_state):
     Args:
         test_state: A state.
     """
-    substates = {"child_state", "child_state2"}
-    assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
-    assert (
-        set(test_state.dict(include_computed=False).keys())
-        == set(test_state.base_vars) | substates
+    substates = {
+        "test_state",
+        "test_state.child_state",
+        "test_state.child_state.grandchild_state",
+        "test_state.child_state2",
+    }
+    test_state_dict = test_state.dict()
+    assert set(test_state_dict) == substates
+    assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars)
+    assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set(
+        test_state.base_vars
     )
 
 
@@ -1081,9 +1087,9 @@ def test_computed_var_cached():
             return self.v
 
     cs = ComputedState()
-    assert cs.dict()["v"] == 0
+    assert cs.dict()[cs.get_full_name()]["v"] == 0
     assert comp_v_calls == 1
-    assert cs.dict()["comp_v"] == 0
+    assert cs.dict()[cs.get_full_name()]["comp_v"] == 0
     assert comp_v_calls == 1
     assert cs.comp_v == 0
     assert comp_v_calls == 1
@@ -1156,24 +1162,27 @@ def test_computed_var_depends_on_parent_non_cached():
     assert ps.dirty_vars == set()
     assert cs.dirty_vars == set()
 
-    assert ps.dict() == {
-        cs.get_name(): {"dep_v": 2},
+    dict1 = ps.dict()
+    assert dict1[ps.get_full_name()] == {
         "no_cache_v": 1,
         CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
     }
-    assert ps.dict() == {
-        cs.get_name(): {"dep_v": 4},
+    assert dict1[cs.get_full_name()] == {"dep_v": 2}
+    dict2 = ps.dict()
+    assert dict2[ps.get_full_name()] == {
         "no_cache_v": 3,
         CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
     }
-    assert ps.dict() == {
-        cs.get_name(): {"dep_v": 6},
+    assert dict2[cs.get_full_name()] == {"dep_v": 4}
+    dict3 = ps.dict()
+    assert dict3[ps.get_full_name()] == {
         "no_cache_v": 5,
         CompileVars.IS_HYDRATED: False,
         "router": formatted_router,
     }
+    assert dict3[cs.get_full_name()] == {"dep_v": 6}
     assert counter == 6
 
 
@@ -2201,13 +2210,13 @@ def test_json_dumps_with_mutables():
         items: List[Foo] = [Foo()]
 
     dict_val = MutableContainsBase().dict()
-    assert isinstance(dict_val["items"][0], dict)
+    assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict)
     val = json_dumps(dict_val)
     f_items = '[{"tags": ["123", "456"]}]'
     f_formatted_router = str(formatted_router).replace("'", '"')
     assert (
         val
-        == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
+        == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
     )
 
 

+ 2 - 1
tests/test_style.py

@@ -22,7 +22,8 @@ def test_convert(style_dict, expected):
         style_dict: The style to check.
         expected: The expected formatted style.
     """
-    assert style.convert(style_dict) == expected
+    converted_dict, _var_data = style.convert(style_dict)
+    assert converted_dict == expected
 
 
 @pytest.mark.parametrize(

+ 50 - 16
tests/test_var.py

@@ -7,18 +7,20 @@ from pandas import DataFrame
 
 from reflex.base import Base
 from reflex.state import State
+from reflex.utils.imports import ImportVar
 from reflex.vars import (
     BaseVar,
     ComputedVar,
-    ImportVar,
     Var,
 )
 
 test_vars = [
     BaseVar(_var_name="prop1", _var_type=int),
     BaseVar(_var_name="key", _var_type=str),
-    BaseVar(_var_name="value", _var_type=str, _var_state="state"),
-    BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True),
+    BaseVar(_var_name="value", _var_type=str)._var_set_state("state"),
+    BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state(
+        "state"
+    ),
     BaseVar(_var_name="local2", _var_type=str, _var_is_local=True),
 ]
 
@@ -263,7 +265,7 @@ def test_basic_operations(TestObj):
     assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
     assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
     assert (
-        str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar)
+        str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar)
         == "{state.foo.bar}"
     )
     assert str(abs(v(1))) == "{Math.abs(1)}"
@@ -274,7 +276,7 @@ def test_basic_operations(TestObj):
     assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}"
     assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}'
     assert (
-        str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse())
+        str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse())
         == "{[...state.foo].reverse()}"
     )
     assert (
@@ -288,11 +290,14 @@ def test_basic_operations(TestObj):
     [
         (v([1, 2, 3]), "[1, 2, 3]"),
         (v(["1", "2", "3"]), '["1", "2", "3"]'),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"),
+        (BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"),
         (BaseVar(_var_name="foo", _var_type=list), "foo"),
         (v((1, 2, 3)), "[1, 2, 3]"),
         (v(("1", "2", "3")), '["1", "2", "3"]'),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"),
+        (
+            BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"),
+            "state.foo",
+        ),
         (BaseVar(_var_name="foo", _var_type=tuple), "foo"),
     ],
 )
@@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected):
     assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
     assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}"
     assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
-    other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+    other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
     other_var = BaseVar(_var_name="other", _var_type=str)
     assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
     assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected):
     "var, expected",
     [
         (v("123"), json.dumps("123")),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"),
+        (BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"),
         (BaseVar(_var_name="foo", _var_type=str), "foo"),
     ],
 )
 def test_str_contains(var, expected):
     assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
     assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
-    other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+    other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
     other_var = BaseVar(_var_name="other", _var_type=str)
     assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
     assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@@ -328,7 +333,7 @@ def test_str_contains(var, expected):
     "var, expected",
     [
         (v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'),
-        (BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"),
+        (BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"),
         (BaseVar(_var_name="foo", _var_type=dict), "foo"),
     ],
 )
@@ -337,7 +342,7 @@ def test_dict_contains(var, expected):
     assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}'
     assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}"
     assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}'
-    other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+    other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
     other_var = BaseVar(_var_name="other", _var_type=str)
     assert (
         str(var.contains(other_state_var))
@@ -548,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index):
     "fixture,full_name",
     [
         ("ParentState", "parent_state.var_without_annotation"),
-        ("ChildState", "parent_state.child_state.var_without_annotation"),
+        ("ChildState", "parent_state__child_state.var_without_annotation"),
         (
             "GrandChildState",
-            "parent_state.child_state.grand_child_state.var_without_annotation",
+            "parent_state__child_state__grand_child_state.var_without_annotation",
         ),
         ("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
     ],
@@ -630,8 +635,8 @@ def test_import_var(import_var, expected):
     [
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
-            f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}",
-            "testing f-string with ${state.myvar}",
+            f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
+            'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
         ),
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
@@ -643,6 +648,35 @@ def test_fstrings(out, expected):
     assert out == expected
 
 
+@pytest.mark.parametrize(
+    ("value", "expect_state"),
+    [
+        ([1], ""),
+        ({"a": 1}, ""),
+        ([Var.create_safe(1)._var_set_state("foo")], "foo"),
+        ({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"),
+    ],
+)
+def test_extract_state_from_container(value, expect_state):
+    """Test that _var_state is extracted from containers containing BaseVar.
+
+    Args:
+        value: The value to create a var from.
+        expect_state: The expected state.
+    """
+    assert Var.create_safe(value)._var_state == expect_state
+
+
+def test_fstring_roundtrip():
+    """Test that f-string roundtrip carries state."""
+    var = BaseVar.create_safe("var")._var_set_state("state")
+    rt_var = Var.create_safe(f"{var}")
+    assert var._var_state == rt_var._var_state
+    assert var._var_full_name_needs_state_prefix
+    assert not rt_var._var_full_name_needs_state_prefix
+    assert rt_var._var_name == var._var_full_name
+
+
 @pytest.mark.parametrize(
     "var",
     [

+ 37 - 28
tests/utils/test_format.py

@@ -8,7 +8,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import BaseVar, Var
-from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState
+from tests.test_state import (
+    ChildState,
+    ChildState2,
+    DateTimeState,
+    GrandchildState,
+    TestState,
+)
 
 
 def mock_event(arg):
@@ -349,7 +355,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
             BaseVar(
                 _var_name="_",
                 _var_type=Any,
-                _var_state="",
                 _var_is_local=True,
                 _var_is_string=False,
             ),
@@ -515,40 +520,44 @@ formatted_router = {
         (
             TestState().dict(),  # type: ignore
             {
-                "array": [1, 2, 3.14],
-                "child_state": {
+                TestState.get_full_name(): {
+                    "array": [1, 2, 3.14],
+                    "complex": {
+                        1: {"prop1": 42, "prop2": "hello"},
+                        2: {"prop1": 42, "prop2": "hello"},
+                    },
+                    "dt": "1989-11-09 18:53:00+01:00",
+                    "fig": [],
+                    "is_hydrated": False,
+                    "key": "",
+                    "map_key": "a",
+                    "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
+                    "num1": 0,
+                    "num2": 3.14,
+                    "obj": {"prop1": 42, "prop2": "hello"},
+                    "sum": 3.14,
+                    "upper": "",
+                    "router": formatted_router,
+                },
+                ChildState.get_full_name(): {
                     "count": 23,
-                    "grandchild_state": {"value2": ""},
                     "value": "",
                 },
-                "child_state2": {"value": ""},
-                "complex": {
-                    1: {"prop1": 42, "prop2": "hello"},
-                    2: {"prop1": 42, "prop2": "hello"},
-                },
-                "dt": "1989-11-09 18:53:00+01:00",
-                "fig": [],
-                "is_hydrated": False,
-                "key": "",
-                "map_key": "a",
-                "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
-                "num1": 0,
-                "num2": 3.14,
-                "obj": {"prop1": 42, "prop2": "hello"},
-                "sum": 3.14,
-                "upper": "",
-                "router": formatted_router,
+                ChildState2.get_full_name(): {"value": ""},
+                GrandchildState.get_full_name(): {"value2": ""},
             },
         ),
         (
             DateTimeState().dict(),
             {
-                "d": "1989-11-09",
-                "dt": "1989-11-09 18:53:00+01:00",
-                "is_hydrated": False,
-                "t": "18:53:00+01:00",
-                "td": "11 days, 0:11:00",
-                "router": formatted_router,
+                DateTimeState.get_full_name(): {
+                    "d": "1989-11-09",
+                    "dt": "1989-11-09 18:53:00+01:00",
+                    "is_hydrated": False,
+                    "t": "18:53:00+01:00",
+                    "td": "11 days, 0:11:00",
+                    "router": formatted_router,
+                },
             },
         ),
     ],