Browse Source

use add_imports everywhere (#3448)

Thomas Brandého 11 months ago
parent
commit
462b023019
40 changed files with 469 additions and 304 deletions
  1. 7 6
      reflex/compiler/utils.py
  2. 14 13
      reflex/components/chakra/base.py
  3. 2 1
      reflex/components/chakra/base.pyi
  4. 8 6
      reflex/components/chakra/forms/input.py
  5. 2 1
      reflex/components/chakra/forms/input.pyi
  6. 8 3
      reflex/components/chakra/navigation/link.py
  7. 2 1
      reflex/components/chakra/navigation/link.pyi
  8. 20 30
      reflex/components/component.py
  9. 27 24
      reflex/components/core/banner.py
  10. 5 2
      reflex/components/core/banner.pyi
  11. 13 8
      reflex/components/core/cond.py
  12. 9 6
      reflex/components/core/match.py
  13. 5 7
      reflex/components/core/upload.py
  14. 1 1
      reflex/components/core/upload.pyi
  15. 26 23
      reflex/components/datadisplay/code.py
  16. 3 2
      reflex/components/datadisplay/code.pyi
  17. 59 43
      reflex/components/datadisplay/dataeditor.py
  18. 50 4
      reflex/components/datadisplay/dataeditor.pyi
  19. 11 12
      reflex/components/el/elements/forms.py
  20. 2 1
      reflex/components/el/elements/forms.pyi
  21. 9 6
      reflex/components/gridjs/datatable.py
  22. 3 1
      reflex/components/gridjs/datatable.pyi
  23. 33 40
      reflex/components/markdown/markdown.py
  24. 3 3
      reflex/components/markdown/markdown.pyi
  25. 9 8
      reflex/components/moment/moment.py
  26. 2 1
      reflex/components/moment/moment.pyi
  27. 2 3
      reflex/components/radix/primitives/accordion.py
  28. 1 2
      reflex/components/radix/primitives/accordion.pyi
  29. 3 3
      reflex/components/radix/themes/base.py
  30. 2 2
      reflex/components/radix/themes/base.pyi
  31. 8 3
      reflex/components/radix/themes/typography/link.py
  32. 2 1
      reflex/components/radix/themes/typography/link.pyi
  33. 10 7
      reflex/components/suneditor/editor.py
  34. 2 1
      reflex/components/suneditor/editor.pyi
  35. 34 5
      reflex/utils/imports.py
  36. 20 2
      reflex/vars.py
  37. 2 2
      reflex/vars.pyi
  38. 2 3
      tests/compiler/test_compiler.py
  39. 12 16
      tests/components/test_component.py
  40. 36 1
      tests/utils/test_imports.py

+ 7 - 6
reflex/compiler/utils.py

@@ -28,13 +28,14 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
 from reflex.state import BaseState, Cookie, LocalStorage
 from reflex.style import Style
 from reflex.utils import console, format, imports, path_ops
+from reflex.utils.imports import ImportVar, ParsedImportDict
 from reflex.vars import Var
 
 # To re-export this function.
 merge_imports = imports.merge_imports
 
 
-def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
+def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
     """Compile an import statement.
 
     Args:
@@ -59,7 +60,7 @@ def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list
     return default, list(rest)
 
 
-def validate_imports(import_dict: imports.ImportDict):
+def validate_imports(import_dict: ParsedImportDict):
     """Verify that the same Tag is not used in multiple import.
 
     Args:
@@ -82,7 +83,7 @@ def validate_imports(import_dict: imports.ImportDict):
                 used_tags[import_name] = lib
 
 
-def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
+def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
     """Compile an import dict.
 
     Args:
@@ -91,7 +92,7 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
     Returns:
         The list of import dict.
     """
-    collapsed_import_dict = imports.collapse_imports(import_dict)
+    collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
     validate_imports(collapsed_import_dict)
     import_dicts = []
     for lib, fields in collapsed_import_dict.items():
@@ -231,7 +232,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
 
 def compile_custom_component(
     component: CustomComponent,
-) -> tuple[dict, imports.ImportDict]:
+) -> tuple[dict, ParsedImportDict]:
     """Compile a custom component.
 
     Args:
@@ -244,7 +245,7 @@ def compile_custom_component(
     render = component.get_component(component)
 
     # Get the imports.
-    imports = {
+    imports: ParsedImportDict = {
         lib: fields
         for lib, fields in render._get_all_imports().items()
         if lib != component.library

+ 14 - 13
reflex/components/chakra/base.py

@@ -5,14 +5,14 @@ from functools import lru_cache
 from typing import List, Literal
 
 from reflex.components.component import Component
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 
 class ChakraComponent(Component):
     """A component that wraps a Chakra component."""
 
-    library = "@chakra-ui/react@2.6.1"
+    library: str = "@chakra-ui/react@2.6.1"  # type: ignore
     lib_dependencies: List[str] = [
         "@chakra-ui/system@2.5.7",
         "framer-motion@10.16.4",
@@ -35,14 +35,14 @@ class ChakraComponent(Component):
 
     @classmethod
     @lru_cache(maxsize=None)
-    def _get_dependencies_imports(cls) -> imports.ImportDict:
+    def _get_dependencies_imports(cls) -> ImportDict:
         """Get the imports from lib_dependencies for installing.
 
         Returns:
             The dependencies imports of the component.
         """
         return {
-            dep: [imports.ImportVar(tag=None, render=False)]
+            dep: [ImportVar(tag=None, render=False)]
             for dep in [
                 "@chakra-ui/system@2.5.7",
                 "framer-motion@10.16.4",
@@ -70,15 +70,16 @@ class ChakraProvider(ChakraComponent):
             ),
         )
 
-    def _get_imports(self) -> imports.ImportDict:
-        _imports = super()._get_imports()
-        _imports.setdefault(self.__fields__["library"].default, []).append(
-            imports.ImportVar(tag="extendTheme", is_default=False),
-        )
-        _imports.setdefault("/utils/theme.js", []).append(
-            imports.ImportVar(tag="theme", is_default=True),
-        )
-        return _imports
+    def add_imports(self) -> ImportDict:
+        """Add imports for the ChakraProvider component.
+
+        Returns:
+            The import dict for the component.
+        """
+        return {
+            self.library: ImportVar(tag="extendTheme", is_default=False),
+            "/utils/theme.js": ImportVar(tag="theme", is_default=True),
+        }
 
     @staticmethod
     @lru_cache(maxsize=None)

+ 2 - 1
reflex/components/chakra/base.pyi

@@ -10,7 +10,7 @@ from reflex.style import Style
 from functools import lru_cache
 from typing import List, Literal
 from reflex.components.component import Component
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 class ChakraComponent(Component):
@@ -155,6 +155,7 @@ class ChakraProvider(ChakraComponent):
             A new ChakraProvider component.
         """
         ...
+    def add_imports(self) -> ImportDict: ...
 
 chakra_provider = ChakraProvider.create()
 

+ 8 - 6
reflex/components/chakra/forms/input.py

@@ -11,7 +11,7 @@ from reflex.components.component import Component
 from reflex.components.core.debounce import DebounceInput
 from reflex.components.literals import LiteralInputType
 from reflex.constants import EventTriggers, MemoizationMode
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
 
@@ -59,11 +59,13 @@ class Input(ChakraComponent):
     # The name of the form field
     name: Var[str]
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            {"/utils/state": {imports.ImportVar(tag="set_val")}},
-        )
+    def add_imports(self) -> ImportDict:
+        """Add imports for the Input component.
+
+        Returns:
+            The import dict.
+        """
+        return {"/utils/state": "set_val"}
 
     def get_event_triggers(self) -> Dict[str, Any]:
         """Get the event triggers that pass the component's value to the handler.

+ 2 - 1
reflex/components/chakra/forms/input.pyi

@@ -17,10 +17,11 @@ from reflex.components.component import Component
 from reflex.components.core.debounce import DebounceInput
 from reflex.components.literals import LiteralInputType
 from reflex.constants import EventTriggers, MemoizationMode
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
 class Input(ChakraComponent):
+    def add_imports(self) -> ImportDict: ...
     def get_event_triggers(self) -> Dict[str, Any]: ...
     @overload
     @classmethod

+ 8 - 3
reflex/components/chakra/navigation/link.py

@@ -4,7 +4,7 @@
 from reflex.components.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.next.link import NextLink
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import BaseVar, Var
 
 next_link = NextLink.create()
@@ -32,8 +32,13 @@ class Link(ChakraComponent):
     # If true, the link will open in new tab.
     is_external: Var[bool]
 
-    def _get_imports(self) -> imports.ImportDict:
-        return {**super()._get_imports(), **next_link._get_imports()}
+    def add_imports(self) -> ImportDict:
+        """Add imports for the link component.
+
+        Returns:
+            The import dict.
+        """
+        return next_link._get_imports()  # type: ignore
 
     @classmethod
     def create(cls, *children, **props) -> Component:

+ 2 - 1
reflex/components/chakra/navigation/link.pyi

@@ -10,12 +10,13 @@ from reflex.style import Style
 from reflex.components.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.next.link import NextLink
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import BaseVar, Var
 
 next_link = NextLink.create()
 
 class Link(ChakraComponent):
+    def add_imports(self) -> ImportDict: ...
     @overload
     @classmethod
     def create(  # type: ignore

+ 20 - 30
reflex/components/component.py

@@ -44,7 +44,7 @@ from reflex.event import (
 )
 from reflex.style import Style, format_as_emotion
 from reflex.utils import console, format, imports, types
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.utils.serializers import serializer
 from reflex.vars import BaseVar, Var, VarData
 
@@ -95,7 +95,7 @@ class BaseComponent(Base, ABC):
         """
 
     @abstractmethod
-    def _get_all_imports(self) -> imports.ImportDict:
+    def _get_all_imports(self) -> ParsedImportDict:
         """Get all the libraries and fields that are used by the component.
 
         Returns:
@@ -213,7 +213,7 @@ class Component(BaseComponent, ABC):
     # State class associated with this component instance
     State: Optional[Type[reflex.state.State]] = None
 
-    def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
+    def add_imports(self) -> ImportDict | list[ImportDict]:
         """Add imports for the component.
 
         This method should be implemented by subclasses to add new imports for the component.
@@ -1224,7 +1224,7 @@ class Component(BaseComponent, ABC):
         # Return the dynamic imports
         return dynamic_imports
 
-    def _get_props_imports(self) -> List[str]:
+    def _get_props_imports(self) -> List[ParsedImportDict]:
         """Get the imports needed for components props.
 
         Returns:
@@ -1250,7 +1250,7 @@ class Component(BaseComponent, ABC):
             or format.format_library_name(dep or "") in self.transpile_packages
         )
 
-    def _get_dependencies_imports(self) -> imports.ImportDict:
+    def _get_dependencies_imports(self) -> ParsedImportDict:
         """Get the imports from lib_dependencies for installing.
 
         Returns:
@@ -1267,7 +1267,7 @@ class Component(BaseComponent, ABC):
             for dep in self.lib_dependencies
         }
 
-    def _get_hooks_imports(self) -> imports.ImportDict:
+    def _get_hooks_imports(self) -> ParsedImportDict:
         """Get the imports required by certain hooks.
 
         Returns:
@@ -1308,7 +1308,7 @@ class Component(BaseComponent, ABC):
 
         return imports.merge_imports(_imports, *other_imports)
 
-    def _get_imports(self) -> imports.ImportDict:
+    def _get_imports(self) -> ParsedImportDict:
         """Get all the libraries and fields that are used by the component.
 
         Returns:
@@ -1328,25 +1328,15 @@ class Component(BaseComponent, ABC):
             var._var_data.imports for var in self._get_vars() if var._var_data
         ]
 
-        # If any subclass implements add_imports, merge the imports.
-        def _make_list(
-            value: str | ImportVar | list[str | ImportVar],
-        ) -> list[str | ImportVar]:
-            if isinstance(value, (str, ImportVar)):
-                return [value]
-            return value
-
-        _added_import_dicts = []
+        added_import_dicts: list[ParsedImportDict] = []
         for clz in self._iter_parent_classes_with_method("add_imports"):
-            _added_import_dicts.append(
-                {
-                    package: [
-                        ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
-                        for tag in _make_list(maybe_tags)
-                    ]
-                    for package, maybe_tags in clz.add_imports(self).items()
-                }
-            )
+            list_of_import_dict = clz.add_imports(self)
+
+            if not isinstance(list_of_import_dict, list):
+                list_of_import_dict = [list_of_import_dict]
+
+            for import_dict in list_of_import_dict:
+                added_import_dicts.append(parse_imports(import_dict))
 
         return imports.merge_imports(
             *self._get_props_imports(),
@@ -1355,10 +1345,10 @@ class Component(BaseComponent, ABC):
             _imports,
             event_imports,
             *var_imports,
-            *_added_import_dicts,
+            *added_import_dicts,
         )
 
-    def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
+    def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
         """Get all the libraries and fields that are used by the component and its children.
 
         Args:
@@ -1453,7 +1443,7 @@ class Component(BaseComponent, ABC):
             **self._get_special_hooks(),
         }
 
-    def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
+    def _get_added_hooks(self) -> dict[str, ImportDict]:
         """Get the hooks added via `add_hooks` method.
 
         Returns:
@@ -1842,7 +1832,7 @@ memo = custom_component
 class NoSSRComponent(Component):
     """A dynamic component that is not rendered on the server."""
 
-    def _get_imports(self) -> imports.ImportDict:
+    def _get_imports(self) -> ParsedImportDict:
         """Get the imports for the component.
 
         Returns:
@@ -2185,7 +2175,7 @@ class StatefulComponent(BaseComponent):
         """
         return {}
 
-    def _get_all_imports(self) -> imports.ImportDict:
+    def _get_all_imports(self) -> ParsedImportDict:
         """Get all the libraries and fields that are used by the component.
 
         Returns:

+ 27 - 24
reflex/components/core/banner.py

@@ -19,7 +19,7 @@ from reflex.components.radix.themes.typography.text import Text
 from reflex.components.sonner.toast import Toaster, ToastProps
 from reflex.constants import Dirs, Hooks, Imports
 from reflex.constants.compiler import CompileVars
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.serializers import serialize
 from reflex.vars import Var, VarData
 
@@ -65,10 +65,15 @@ has_too_many_connection_errors: Var = Var.create_safe(
 class WebsocketTargetURL(Bare):
     """A component that renders the websocket target URL."""
 
-    def _get_imports(self) -> imports.ImportDict:
+    def add_imports(self) -> ImportDict:
+        """Add imports for the websocket target URL component.
+
+        Returns:
+            The import dict.
+        """
         return {
-            f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
-            "/env.json": [imports.ImportVar(tag="env", is_default=True)],
+            f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")],
+            "/env.json": [ImportVar(tag="env", is_default=True)],
         }
 
     @classmethod
@@ -98,7 +103,7 @@ def default_connection_error() -> list[str | Var | Component]:
 class ConnectionToaster(Toaster):
     """A connection toaster component."""
 
-    def add_hooks(self) -> list[str]:
+    def add_hooks(self) -> list[str | Var]:
         """Add the hooks for the connection toaster.
 
         Returns:
@@ -116,7 +121,7 @@ class ConnectionToaster(Toaster):
             duration=120000,
             id=toast_id,
         )
-        hook = Var.create(
+        hook = Var.create_safe(
             f"""
 const toast_props = {serialize(props)};
 const [userDismissed, setUserDismissed] = useState(false);
@@ -135,22 +140,17 @@ useEffect(() => {{
 }}, [{connect_errors}]);""",
             _var_is_string=False,
         )
-
-        hook._var_data = VarData.merge(  # type: ignore
+        imports: ImportDict = {
+            "react": ["useEffect", "useState"],
+            **target_url._get_imports(),  # type: ignore
+        }
+        hook._var_data = VarData.merge(
             connect_errors._var_data,
-            VarData(
-                imports={
-                    "react": [
-                        imports.ImportVar(tag="useEffect"),
-                        imports.ImportVar(tag="useState"),
-                    ],
-                    **target_url._get_imports(),
-                }
-            ),
+            VarData(imports=imports),
         )
         return [
             Hooks.EVENTS,
-            hook,  # type: ignore
+            hook,
         ]
 
 
@@ -216,10 +216,11 @@ class WifiOffPulse(Icon):
     """A wifi_off icon with an animated opacity pulse."""
 
     @classmethod
-    def create(cls, **props) -> Component:
+    def create(cls, *children, **props) -> Icon:
         """Create a wifi_off icon with an animated opacity pulse.
 
         Args:
+            *children: The children of the component.
             **props: The properties of the component.
 
         Returns:
@@ -237,11 +238,13 @@ class WifiOffPulse(Icon):
             **props,
         )
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            {"@emotion/react": [imports.ImportVar(tag="keyframes")]},
-        )
+    def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
+        """Add imports for the WifiOffPulse component.
+
+        Returns:
+            The import dict.
+        """
+        return {"@emotion/react": [ImportVar(tag="keyframes")]}
 
     def _get_custom_code(self) -> str | None:
         return """

+ 5 - 2
reflex/components/core/banner.pyi

@@ -23,7 +23,7 @@ from reflex.components.radix.themes.typography.text import Text
 from reflex.components.sonner.toast import Toaster, ToastProps
 from reflex.constants import Dirs, Hooks, Imports
 from reflex.constants.compiler import CompileVars
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.serializers import serialize
 from reflex.vars import Var, VarData
 
@@ -35,6 +35,7 @@ has_connection_errors: Var
 has_too_many_connection_errors: Var
 
 class WebsocketTargetURL(Bare):
+    def add_imports(self) -> ImportDict: ...
     @overload
     @classmethod
     def create(  # type: ignore
@@ -104,7 +105,7 @@ class WebsocketTargetURL(Bare):
 def default_connection_error() -> list[str | Var | Component]: ...
 
 class ConnectionToaster(Toaster):
-    def add_hooks(self) -> list[str]: ...
+    def add_hooks(self) -> list[str | Var]: ...
     @overload
     @classmethod
     def create(  # type: ignore
@@ -430,6 +431,7 @@ class WifiOffPulse(Icon):
         """Create a wifi_off icon with an animated opacity pulse.
 
         Args:
+            *children: The children of the component.
             size: The size of the icon in pixels.
             style: The style of the component.
             key: A unique key for the component.
@@ -443,6 +445,7 @@ class WifiOffPulse(Icon):
             The icon component with default props applied.
         """
         ...
+    def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: ...
 
 class ConnectionPulser(Div):
     @overload

+ 13 - 8
reflex/components/core/cond.py

@@ -10,11 +10,12 @@ from reflex.components.tags import CondTag, Tag
 from reflex.constants import Dirs
 from reflex.constants.colors import Color
 from reflex.style import LIGHT_COLOR_MODE, color_mode
-from reflex.utils import format, imports
+from reflex.utils import format
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var, VarData
 
-_IS_TRUE_IMPORT = {
-    f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
+_IS_TRUE_IMPORT: ImportDict = {
+    f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
 }
 
 
@@ -96,12 +97,16 @@ class Cond(MemoizationLeaf):
             cond_state=f"isTrue({self.cond._var_full_name})",
         )
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            getattr(self.cond._var_data, "imports", {}),
-            _IS_TRUE_IMPORT,
+    def add_imports(self) -> ImportDict:
+        """Add imports for the Cond component.
+
+        Returns:
+            The import dict for the component.
+        """
+        cond_imports: dict[str, str | ImportVar | list[str | ImportVar]] = getattr(
+            self.cond._var_data, "imports", {}
         )
+        return {**cond_imports, **_IS_TRUE_IMPORT}
 
 
 @overload

+ 9 - 6
reflex/components/core/match.py

@@ -8,8 +8,9 @@ from reflex.components.component import BaseComponent, Component, MemoizationLea
 from reflex.components.core.colors import Color
 from reflex.components.tags import MatchTag, Tag
 from reflex.style import Style
-from reflex.utils import format, imports, types
+from reflex.utils import format, types
 from reflex.utils.exceptions import MatchTypeError
+from reflex.utils.imports import ImportDict
 from reflex.vars import BaseVar, Var, VarData
 
 
@@ -268,11 +269,13 @@ class Match(MemoizationLeaf):
         tag.name = "match"
         return dict(tag)
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            getattr(self.cond._var_data, "imports", {}),
-        )
+    def add_imports(self) -> ImportDict:
+        """Add imports for the Match component.
+
+        Returns:
+            The import dict.
+        """
+        return getattr(self.cond._var_data, "imports", {})
 
 
 match = Match.create

+ 5 - 7
reflex/components/core/upload.py

@@ -19,17 +19,15 @@ from reflex.event import (
     call_script,
     parse_args_spec,
 )
-from reflex.utils import imports
+from reflex.utils.imports import ImportVar
 from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 DEFAULT_UPLOAD_ID: str = "default"
 
 upload_files_context_var_data: VarData = VarData(
     imports={
-        "react": [imports.ImportVar(tag="useContext")],
-        f"/{Dirs.CONTEXTS_PATH}": [
-            imports.ImportVar(tag="UploadFilesContext"),
-        ],
+        "react": "useContext",
+        f"/{Dirs.CONTEXTS_PATH}": "UploadFilesContext",
     },
     hooks={
         "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@@ -133,8 +131,8 @@ uploaded_files_url_prefix: Var = Var.create_safe(
     _var_is_string=False,
     _var_data=VarData(
         imports={
-            f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
-            "/env.json": [imports.ImportVar(tag="env", is_default=True)],
+            f"/{Dirs.STATE_PATH}": "getBackendURL",
+            "/env.json": ImportVar(tag="env", is_default=True),
         }
     ),
 )

+ 1 - 1
reflex/components/core/upload.pyi

@@ -23,7 +23,7 @@ from reflex.event import (
     call_script,
     parse_args_spec,
 )
-from reflex.utils import imports
+from reflex.utils.imports import ImportVar
 from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 DEFAULT_UPLOAD_ID: str

+ 26 - 23
reflex/components/datadisplay/code.py

@@ -12,8 +12,8 @@ from reflex.components.core.cond import color_mode_cond
 from reflex.constants.colors import Color
 from reflex.event import set_clipboard
 from reflex.style import Style
-from reflex.utils import format, imports
-from reflex.utils.imports import ImportVar
+from reflex.utils import format
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 LiteralCodeBlockTheme = Literal[
@@ -381,42 +381,45 @@ class CodeBlock(Component):
     # Props passed down to the code tag.
     code_tag_props: Var[Dict[str, str]]
 
-    def _get_imports(self) -> imports.ImportDict:
-        merged_imports = super()._get_imports()
-        # Get all themes from a cond literal
+    def add_imports(self) -> ImportDict:
+        """Add imports for the CodeBlock component.
+
+        Returns:
+            The import dict.
+        """
+        imports_: ImportDict = {}
         themes = re.findall(r"`(.*?)`", self.theme._var_name)
         if not themes:
             themes = [self.theme._var_name]
-        merged_imports = imports.merge_imports(
-            merged_imports,
+
+        imports_.update(
             {
-                f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": {
+                f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": [
                     ImportVar(
                         tag=format.to_camel_case(self.convert_theme_name(theme)),
                         is_default=True,
                         install=False,
                     )
-                }
+                ]
                 for theme in themes
-            },
+            }
         )
+
         if (
             self.language is not None
             and self.language._var_name in LiteralCodeLanguage.__args__  # type: ignore
         ):
-            merged_imports = imports.merge_imports(
-                merged_imports,
-                {
-                    f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}": {
-                        ImportVar(
-                            tag=format.to_camel_case(self.language._var_name),
-                            is_default=True,
-                            install=False,
-                        )
-                    }
-                },
-            )
-        return merged_imports
+            imports_[
+                f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}"
+            ] = [
+                ImportVar(
+                    tag=format.to_camel_case(self.language._var_name),
+                    is_default=True,
+                    install=False,
+                )
+            ]
+
+        return imports_
 
     def _get_custom_code(self) -> Optional[str]:
         if (

+ 3 - 2
reflex/components/datadisplay/code.pyi

@@ -17,8 +17,8 @@ from reflex.components.core.cond import color_mode_cond
 from reflex.constants.colors import Color
 from reflex.event import set_clipboard
 from reflex.style import Style
-from reflex.utils import format, imports
-from reflex.utils.imports import ImportVar
+from reflex.utils import format
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 LiteralCodeBlockTheme = Literal[
@@ -351,6 +351,7 @@ LiteralCodeLanguage = Literal[
 ]
 
 class CodeBlock(Component):
+    def add_imports(self) -> ImportDict: ...
     @overload
     @classmethod
     def create(  # type: ignore

+ 59 - 43
reflex/components/datadisplay/dataeditor.py

@@ -2,13 +2,14 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.literals import LiteralRowMarker
-from reflex.utils import console, format, imports, types
-from reflex.utils.imports import ImportVar
+from reflex.event import EventHandler
+from reflex.utils import console, format, types
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.serializers import serializer
 from reflex.vars import Var, get_unique_variable_name
 
@@ -205,51 +206,66 @@ class DataEditor(NoSSRComponent):
     # global theme
     theme: Var[Union[DataEditorTheme, Dict]]
 
-    def _get_imports(self):
-        return imports.merge_imports(
-            super()._get_imports(),
-            {
-                "": {
-                    ImportVar(
-                        tag=f"{format.format_library_name(self.library)}/dist/index.css"
-                    )
-                },
-                self.library: {ImportVar(tag="GridCellKind")},
-                "/utils/helpers/dataeditor.js": {
-                    ImportVar(
-                        tag=f"formatDataEditorCells", is_default=False, install=False
-                    ),
-                },
-            },
-        )
+    # Triggered when a cell is activated.
+    on_cell_activated: EventHandler[lambda pos: [pos]]
 
-    def get_event_triggers(self) -> Dict[str, Callable]:
-        """The event triggers of the component.
+    # Triggered when a cell is clicked.
+    on_cell_clicked: EventHandler[lambda pos: [pos]]
 
-        Returns:
-            The dict describing the event triggers.
-        """
+    # Triggered when a cell is right-clicked.
+    on_cell_context_menu: EventHandler[lambda pos: [pos]]
+
+    # Triggered when a cell is edited.
+    on_cell_edited: EventHandler[lambda pos, data: [pos, data]]
+
+    # Triggered when a group header is clicked.
+    on_group_header_clicked: EventHandler[lambda pos, data: [pos, data]]
+
+    # Triggered when a group header is right-clicked.
+    on_group_header_context_menu: EventHandler[lambda grp_idx, data: [grp_idx, data]]
+
+    # Triggered when a group header is renamed.
+    on_group_header_renamed: EventHandler[lambda idx, val: [idx, val]]
+
+    # Triggered when a header is clicked.
+    on_header_clicked: EventHandler[lambda pos: [pos]]
+
+    # Triggered when a header is right-clicked.
+    on_header_context_menu: EventHandler[lambda pos: [pos]]
 
-        def edit_sig(pos, data: dict[str, Any]):
-            return [pos, data]
+    # Triggered when a header menu is clicked.
+    on_header_menu_click: EventHandler[lambda col, pos: [col, pos]]
 
+    # Triggered when an item is hovered.
+    on_item_hovered: EventHandler[lambda pos: [pos]]
+
+    # Triggered when a selection is deleted.
+    on_delete: EventHandler[lambda selection: [selection]]
+
+    # Triggered when editing is finished.
+    on_finished_editing: EventHandler[lambda new_value, movement: [new_value, movement]]
+
+    # Triggered when a row is appended.
+    on_row_appended: EventHandler[lambda: []]
+
+    # Triggered when the selection is cleared.
+    on_selection_cleared: EventHandler[lambda: []]
+
+    # Triggered when a column is resized.
+    on_column_resize: EventHandler[lambda col, width: [col, width]]
+
+    def add_imports(self) -> ImportDict:
+        """Add imports for the component.
+
+        Returns:
+            The import dict.
+        """
         return {
-            "on_cell_activated": lambda pos: [pos],
-            "on_cell_clicked": lambda pos: [pos],
-            "on_cell_context_menu": lambda pos: [pos],
-            "on_cell_edited": edit_sig,
-            "on_group_header_clicked": edit_sig,
-            "on_group_header_context_menu": lambda grp_idx, data: [grp_idx, data],
-            "on_group_header_renamed": lambda idx, val: [idx, val],
-            "on_header_clicked": lambda pos: [pos],
-            "on_header_context_menu": lambda pos: [pos],
-            "on_header_menu_click": lambda col, pos: [col, pos],
-            "on_item_hovered": lambda pos: [pos],
-            "on_delete": lambda selection: [selection],
-            "on_finished_editing": lambda new_value, movement: [new_value, movement],
-            "on_row_appended": lambda: [],
-            "on_selection_cleared": lambda: [],
-            "on_column_resize": lambda col, width: [col, width],
+            "": f"{format.format_library_name(self.library)}/dist/index.css",
+            self.library: "GridCellKind",
+            "/utils/helpers/dataeditor.js": ImportVar(
+                tag="formatDataEditorCells", is_default=False, install=False
+            ),
         }
 
     def add_hooks(self) -> list[str]:

+ 50 - 4
reflex/components/datadisplay/dataeditor.pyi

@@ -8,12 +8,13 @@ from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.components.literals import LiteralRowMarker
-from reflex.utils import console, format, imports, types
-from reflex.utils.imports import ImportVar
+from reflex.event import EventHandler
+from reflex.utils import console, format, types
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.serializers import serializer
 from reflex.vars import Var, get_unique_variable_name
 
@@ -80,7 +81,7 @@ class DataEditorTheme(Base):
     text_medium: Optional[str]
 
 class DataEditor(NoSSRComponent):
-    def get_event_triggers(self) -> Dict[str, Callable]: ...
+    def add_imports(self) -> ImportDict: ...
     def add_hooks(self) -> list[str]: ...
     @overload
     @classmethod
@@ -136,6 +137,9 @@ class DataEditor(NoSSRComponent):
         class_name: Optional[Any] = None,
         autofocus: Optional[bool] = None,
         custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
+        on_blur: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_cell_activated: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
@@ -148,15 +152,27 @@ class DataEditor(NoSSRComponent):
         on_cell_edited: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_click: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_column_resize: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_context_menu: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_delete: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_double_click: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_finished_editing: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_focus: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_group_header_clicked: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
@@ -178,12 +194,42 @@ class DataEditor(NoSSRComponent):
         on_item_hovered: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_mount: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_down: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_enter: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_leave: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_move: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_out: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_over: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
+        on_mouse_up: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_row_appended: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_scroll: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         on_selection_cleared: Optional[
             Union[EventHandler, EventSpec, list, function, BaseVar]
         ] = None,
+        on_unmount: Optional[
+            Union[EventHandler, EventSpec, list, function, BaseVar]
+        ] = None,
         **props
     ) -> "DataEditor":
         """Create the DataEditor component.

+ 11 - 12
reflex/components/el/elements/forms.py

@@ -11,8 +11,8 @@ from reflex.components.el.element import Element
 from reflex.components.tags.tag import Tag
 from reflex.constants import Dirs, EventTriggers
 from reflex.event import EventChain
-from reflex.utils import imports
 from reflex.utils.format import format_event_chain
+from reflex.utils.imports import ImportDict
 from reflex.vars import BaseVar, Var
 
 from .base import BaseHTML
@@ -169,17 +169,16 @@ class Form(BaseHTML):
         ).hexdigest()
         return form
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            {
-                "react": {imports.ImportVar(tag="useCallback")},
-                f"/{Dirs.STATE_PATH}": {
-                    imports.ImportVar(tag="getRefValue"),
-                    imports.ImportVar(tag="getRefValues"),
-                },
-            },
-        )
+    def add_imports(self) -> ImportDict:
+        """Add imports needed by the form component.
+
+        Returns:
+            The imports for the form component.
+        """
+        return {
+            "react": "useCallback",
+            f"/{Dirs.STATE_PATH}": ["getRefValue", "getRefValues"],
+        }
 
     def add_hooks(self) -> list[str]:
         """Add hooks for the form.

+ 2 - 1
reflex/components/el/elements/forms.pyi

@@ -14,8 +14,8 @@ from reflex.components.el.element import Element
 from reflex.components.tags.tag import Tag
 from reflex.constants import Dirs, EventTriggers
 from reflex.event import EventChain
-from reflex.utils import imports
 from reflex.utils.format import format_event_chain
+from reflex.utils.imports import ImportDict
 from reflex.vars import BaseVar, Var
 from .base import BaseHTML
 
@@ -581,6 +581,7 @@ class Form(BaseHTML):
             The form component.
         """
         ...
+    def add_imports(self) -> ImportDict: ...
     def add_hooks(self) -> list[str]: ...
 
 class Input(BaseHTML):

+ 9 - 6
reflex/components/gridjs/datatable.py

@@ -6,7 +6,8 @@ from typing import Any, Dict, List, Union
 
 from reflex.components.component import Component
 from reflex.components.tags import Tag
-from reflex.utils import imports, types
+from reflex.utils import types
+from reflex.utils.imports import ImportDict
 from reflex.utils.serializers import serialize
 from reflex.vars import BaseVar, ComputedVar, Var
 
@@ -102,11 +103,13 @@ class DataTable(Gridjs):
             **props,
         )
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
-        )
+    def add_imports(self) -> ImportDict:
+        """Add the imports for the datatable component.
+
+        Returns:
+            The import dict for the component.
+        """
+        return {"": "gridjs/dist/theme/mermaid.css"}
 
     def _render(self) -> Tag:
         if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type):

+ 3 - 1
reflex/components/gridjs/datatable.pyi

@@ -10,7 +10,8 @@ from reflex.style import Style
 from typing import Any, Dict, List, Union
 from reflex.components.component import Component
 from reflex.components.tags import Tag
-from reflex.utils import imports, types
+from reflex.utils import types
+from reflex.utils.imports import ImportDict
 from reflex.utils.serializers import serialize
 from reflex.vars import BaseVar, ComputedVar, Var
 
@@ -180,3 +181,4 @@ class DataTable(Gridjs):
             ValueError: If a pandas dataframe is passed in and columns are also provided.
         """
         ...
+    def add_imports(self) -> ImportDict: ...

+ 33 - 40
reflex/components/markdown/markdown.py

@@ -7,7 +7,6 @@ from functools import lru_cache
 from hashlib import md5
 from typing import Any, Callable, Dict, Union
 
-from reflex.compiler import utils
 from reflex.components.component import Component, CustomComponent
 from reflex.components.radix.themes.layout.list import (
     ListItem,
@@ -18,8 +17,8 @@ from reflex.components.radix.themes.typography.heading import Heading
 from reflex.components.radix.themes.typography.link import Link
 from reflex.components.radix.themes.typography.text import Text
 from reflex.components.tags.tag import Tag
-from reflex.utils import imports, types
-from reflex.utils.imports import ImportVar
+from reflex.utils import types
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 # Special vars used in the component map.
@@ -145,47 +144,41 @@ class Markdown(Component):
 
         return custom_components
 
-    def _get_imports(self) -> imports.ImportDict:
-        # Import here to avoid circular imports.
+    def add_imports(self) -> ImportDict | list[ImportDict]:
+        """Add imports for the markdown component.
+
+        Returns:
+            The imports for the markdown component.
+        """
         from reflex.components.datadisplay.code import CodeBlock
         from reflex.components.radix.themes.typography.code import Code
 
-        imports = super()._get_imports()
-
-        # Special markdown imports.
-        imports.update(
+        return [
             {
-                "": [ImportVar(tag="katex/dist/katex.min.css")],
-                "remark-math@5.1.1": [
-                    ImportVar(tag=_REMARK_MATH._var_name, is_default=True)
-                ],
-                "remark-gfm@3.0.1": [
-                    ImportVar(tag=_REMARK_GFM._var_name, is_default=True)
-                ],
-                "remark-unwrap-images@4.0.0": [
-                    ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True)
-                ],
-                "rehype-katex@6.0.3": [
-                    ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True)
-                ],
-                "rehype-raw@6.1.1": [
-                    ImportVar(tag=_REHYPE_RAW._var_name, is_default=True)
-                ],
-            }
-        )
-
-        # Get the imports for each component.
-        for component in self.component_map.values():
-            imports = utils.merge_imports(
-                imports, component(_MOCK_ARG)._get_all_imports()
-            )
-
-        # Get the imports for the code components.
-        imports = utils.merge_imports(
-            imports, CodeBlock.create(theme="light")._get_imports()
-        )
-        imports = utils.merge_imports(imports, Code.create()._get_imports())
-        return imports
+                "": "katex/dist/katex.min.css",
+                "remark-math@5.1.1": ImportVar(
+                    tag=_REMARK_MATH._var_name, is_default=True
+                ),
+                "remark-gfm@3.0.1": ImportVar(
+                    tag=_REMARK_GFM._var_name, is_default=True
+                ),
+                "remark-unwrap-images@4.0.0": ImportVar(
+                    tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True
+                ),
+                "rehype-katex@6.0.3": ImportVar(
+                    tag=_REHYPE_KATEX._var_name, is_default=True
+                ),
+                "rehype-raw@6.1.1": ImportVar(
+                    tag=_REHYPE_RAW._var_name, is_default=True
+                ),
+            },
+            *[
+                component(_MOCK_ARG)._get_imports()  # type: ignore
+                for component in self.component_map.values()
+            ],
+            CodeBlock.create(theme="light")._get_imports(),  # type: ignore,
+            Code.create()._get_imports(),  # type: ignore,
+        ]
 
     def get_component(self, tag: str, **props) -> Component:
         """Get the component for a tag and props.

+ 3 - 3
reflex/components/markdown/markdown.pyi

@@ -11,7 +11,6 @@ import textwrap
 from functools import lru_cache
 from hashlib import md5
 from typing import Any, Callable, Dict, Union
-from reflex.compiler import utils
 from reflex.components.component import Component, CustomComponent
 from reflex.components.radix.themes.layout.list import (
     ListItem,
@@ -22,8 +21,8 @@ from reflex.components.radix.themes.typography.heading import Heading
 from reflex.components.radix.themes.typography.link import Link
 from reflex.components.radix.themes.typography.text import Text
 from reflex.components.tags.tag import Tag
-from reflex.utils import imports, types
-from reflex.utils.imports import ImportVar
+from reflex.utils import types
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 _CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False)
@@ -124,6 +123,7 @@ class Markdown(Component):
             The markdown component.
         """
         ...
+    def add_imports(self) -> ImportDict | list[ImportDict]: ...
     def get_component(self, tag: str, **props) -> Component: ...
     def format_component(self, tag: str, **props) -> str: ...
     def format_component_map(self) -> dict[str, str]: ...

+ 9 - 8
reflex/components/moment/moment.py

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
 
 from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
 
@@ -90,14 +90,15 @@ class Moment(NoSSRComponent):
     # Display the date in the given timezone.
     tz: Var[str]
 
-    def _get_imports(self) -> imports.ImportDict:
-        merged_imports = super()._get_imports()
+    def add_imports(self) -> ImportDict:
+        """Add the imports for the Moment component.
+
+        Returns:
+            The import dict for the component.
+        """
         if self.tz is not None:
-            merged_imports = imports.merge_imports(
-                merged_imports,
-                {"moment-timezone": {imports.ImportVar(tag="")}},
-            )
-        return merged_imports
+            return {"moment-timezone": ""}
+        return {}
 
     def get_event_triggers(self) -> Dict[str, Any]:
         """Get the events triggers signatures for the component.

+ 2 - 1
reflex/components/moment/moment.pyi

@@ -10,7 +10,7 @@ from reflex.style import Style
 from typing import Any, Dict, List, Optional
 from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
 class MomentDelta(Base):
@@ -25,6 +25,7 @@ class MomentDelta(Base):
     milliseconds: Optional[int]
 
 class Moment(NoSSRComponent):
+    def add_imports(self) -> ImportDict: ...
     def get_event_triggers(self) -> Dict[str, Any]: ...
     @overload
     @classmethod

+ 2 - 3
reflex/components/radix/primitives/accordion.py

@@ -11,7 +11,6 @@ from reflex.components.lucide.icon import Icon
 from reflex.components.radix.primitives.base import RadixPrimitiveComponent
 from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius
 from reflex.style import Style
-from reflex.utils import imports
 from reflex.vars import Var, get_uuid_string_var
 
 LiteralAccordionType = Literal["single", "multiple"]
@@ -413,13 +412,13 @@ class AccordionContent(AccordionComponent):
 
     alias = "RadixAccordionContent"
 
-    def add_imports(self) -> imports.ImportDict:
+    def add_imports(self) -> dict:
         """Add imports to the component.
 
         Returns:
             The imports of the component.
         """
-        return {"@emotion/react": [imports.ImportVar(tag="keyframes")]}
+        return {"@emotion/react": "keyframes"}
 
     @classmethod
     def create(cls, *children, **props) -> Component:

+ 1 - 2
reflex/components/radix/primitives/accordion.pyi

@@ -15,7 +15,6 @@ from reflex.components.lucide.icon import Icon
 from reflex.components.radix.primitives.base import RadixPrimitiveComponent
 from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius
 from reflex.style import Style
-from reflex.utils import imports
 from reflex.vars import Var, get_uuid_string_var
 
 LiteralAccordionType = Literal["single", "multiple"]
@@ -899,7 +898,7 @@ class AccordionIcon(Icon):
         ...
 
 class AccordionContent(AccordionComponent):
-    def add_imports(self) -> imports.ImportDict: ...
+    def add_imports(self) -> dict: ...
     @overload
     @classmethod
     def create(  # type: ignore

+ 3 - 3
reflex/components/radix/themes/base.py

@@ -7,7 +7,7 @@ from typing import Any, Dict, Literal
 from reflex.components import Component
 from reflex.components.tags import Tag
 from reflex.config import get_config
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
@@ -209,13 +209,13 @@ class Theme(RadixThemesComponent):
             children = [ThemePanel.create(), *children]
         return super().create(*children, **props)
 
-    def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]:
+    def add_imports(self) -> ImportDict | list[ImportDict]:
         """Add imports for the Theme component.
 
         Returns:
             The import dict.
         """
-        _imports: dict[str, list[ImportVar] | ImportVar] = {
+        _imports: ImportDict = {
             "/utils/theme.js": [ImportVar(tag="theme", is_default=True)],
         }
         if get_config().tailwind is None:

+ 2 - 2
reflex/components/radix/themes/base.pyi

@@ -11,7 +11,7 @@ from typing import Any, Dict, Literal
 from reflex.components import Component
 from reflex.components.tags import Tag
 from reflex.config import get_config
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
@@ -580,7 +580,7 @@ class Theme(RadixThemesComponent):
             A new component instance.
         """
         ...
-    def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]: ...
+    def add_imports(self) -> ImportDict | list[ImportDict]: ...
 
 class ThemePanel(RadixThemesComponent):
     def add_imports(self) -> dict[str, str]: ...

+ 8 - 3
reflex/components/radix/themes/typography/link.py

@@ -12,7 +12,7 @@ from reflex.components.core.colors import color
 from reflex.components.core.cond import cond
 from reflex.components.el.elements.inline import A
 from reflex.components.next.link import NextLink
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
 from ..base import (
@@ -59,8 +59,13 @@ class Link(RadixThemesComponent, A, MemoizationLeaf):
     # If True, the link will open in a new tab
     is_external: Var[bool]
 
-    def _get_imports(self) -> imports.ImportDict:
-        return {**super()._get_imports(), **next_link._get_imports()}
+    def add_imports(self) -> ImportDict:
+        """Add imports for the Link component.
+
+        Returns:
+            The import dict.
+        """
+        return next_link._get_imports()  # type: ignore
 
     @classmethod
     def create(cls, *children, **props) -> Component:

+ 2 - 1
reflex/components/radix/themes/typography/link.pyi

@@ -13,7 +13,7 @@ from reflex.components.core.colors import color
 from reflex.components.core.cond import cond
 from reflex.components.el.elements.inline import A
 from reflex.components.next.link import NextLink
-from reflex.utils import imports
+from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 from ..base import LiteralAccentColor, RadixThemesComponent
 from .base import LiteralTextSize, LiteralTextTrim, LiteralTextWeight
@@ -22,6 +22,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
 next_link = NextLink.create()
 
 class Link(RadixThemesComponent, A, MemoizationLeaf):
+    def add_imports(self) -> ImportDict: ...
     @overload
     @classmethod
     def create(  # type: ignore

+ 10 - 7
reflex/components/suneditor/editor.py

@@ -8,7 +8,7 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.constants import EventTriggers
 from reflex.utils.format import to_camel_case
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 
@@ -176,12 +176,15 @@ class Editor(NoSSRComponent):
     # default: False
     disable_toolbar: Var[bool]
 
-    def _get_imports(self):
-        imports = super()._get_imports()
-        imports[""] = [
-            ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
-        ]
-        return imports
+    def add_imports(self) -> ImportDict:
+        """Add imports for the Editor component.
+
+        Returns:
+            The import dict.
+        """
+        return {
+            "": ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False)
+        }
 
     def get_event_triggers(self) -> Dict[str, Any]:
         """Get the event triggers that pass the component's value to the handler.

+ 2 - 1
reflex/components/suneditor/editor.pyi

@@ -13,7 +13,7 @@ from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.constants import EventTriggers
 from reflex.utils.format import to_camel_case
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
 from reflex.vars import Var
 
 class EditorButtonList(list, enum.Enum):
@@ -48,6 +48,7 @@ class EditorOptions(Base):
     button_list: Optional[List[Union[List[str], str]]]
 
 class Editor(NoSSRComponent):
+    def add_imports(self) -> ImportDict: ...
     def get_event_triggers(self) -> Dict[str, Any]: ...
     @overload
     @classmethod

+ 34 - 5
reflex/utils/imports.py

@@ -3,12 +3,12 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 from reflex.base import Base
 
 
-def merge_imports(*imports) -> ImportDict:
+def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
     """Merge multiple import dicts together.
 
     Args:
@@ -24,7 +24,31 @@ def merge_imports(*imports) -> ImportDict:
     return all_imports
 
 
-def collapse_imports(imports: ImportDict) -> ImportDict:
+def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
+    """Parse the import dict into a standard format.
+
+    Args:
+        imports: The import dict to parse.
+
+    Returns:
+        The parsed import dict.
+    """
+
+    def _make_list(value: ImportTypes) -> list[str | ImportVar] | list[ImportVar]:
+        if isinstance(value, (str, ImportVar)):
+            return [value]
+        return value
+
+    return {
+        package: [
+            ImportVar(tag=tag) if isinstance(tag, str) else tag
+            for tag in _make_list(maybe_tags)
+        ]
+        for package, maybe_tags in imports.items()
+    }
+
+
+def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
     """Remove all duplicate ImportVar within an ImportDict.
 
     Args:
@@ -33,7 +57,10 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
     Returns:
         The collapsed import dict.
     """
-    return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
+    return {
+        lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars
+        for lib, import_vars in imports.items()
+    }
 
 
 class ImportVar(Base):
@@ -90,4 +117,6 @@ class ImportVar(Base):
         )
 
 
-ImportDict = Dict[str, List[ImportVar]]
+ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
+ImportDict = Dict[str, ImportTypes]
+ParsedImportDict = Dict[str, List[ImportVar]]

+ 20 - 2
reflex/vars.py

@@ -39,7 +39,12 @@ from reflex.utils import console, imports, serializers, types
 from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueError
 
 # This module used to export ImportVar itself, so we still import it for export here
-from reflex.utils.imports import ImportDict, ImportVar
+from reflex.utils.imports import (
+    ImportDict,
+    ImportVar,
+    ParsedImportDict,
+    parse_imports,
+)
 from reflex.utils.types import override
 
 if TYPE_CHECKING:
@@ -120,7 +125,7 @@ class VarData(Base):
     state: str = ""
 
     # Imports needed to render this var
-    imports: ImportDict = {}
+    imports: ParsedImportDict = {}
 
     # Hooks that need to be present in the component to render this var
     hooks: Dict[str, None] = {}
@@ -130,6 +135,19 @@ class VarData(Base):
     # segments.
     interpolations: List[Tuple[int, int]] = []
 
+    def __init__(
+        self, imports: Union[ImportDict, ParsedImportDict] | None = None, **kwargs: Any
+    ):
+        """Initialize the var data.
+
+        Args:
+            imports: The imports needed to render this var.
+            **kwargs: The var data fields.
+        """
+        if imports:
+            kwargs["imports"] = parse_imports(imports)
+        super().__init__(**kwargs)
+
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None:
         """Merge multiple var data objects.

+ 2 - 2
reflex/vars.pyi

@@ -10,7 +10,7 @@ from reflex.base import Base as Base
 from reflex.state import State as State
 from reflex.state import BaseState as BaseState
 from reflex.utils import console as console, format as format, types as types
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportVar, ImportDict, ParsedImportDict
 from types import FunctionType
 from typing import (
     Any,
@@ -36,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 
 class VarData(Base):
     state: str = ""
-    imports: dict[str, List[ImportVar]] = {}
+    imports: Union[ImportDict, ParsedImportDict] = {}
     hooks: Dict[str, None] = {}
     interpolations: List[Tuple[int, int]] = []
     @classmethod

+ 2 - 3
tests/compiler/test_compiler.py

@@ -4,8 +4,7 @@ from typing import List
 import pytest
 
 from reflex.compiler import compiler, utils
-from reflex.utils import imports
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportVar, ParsedImportDict
 
 
 @pytest.mark.parametrize(
@@ -93,7 +92,7 @@ def test_compile_import_statement(
         ),
     ],
 )
-def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]):
+def test_compile_imports(import_dict: ParsedImportDict, test_dicts: List[dict]):
     """Test the compile_imports function.
 
     Args:

+ 12 - 16
tests/components/test_component.py

@@ -20,7 +20,7 @@ from reflex.event import EventChain, EventHandler, parse_args_spec
 from reflex.state import BaseState
 from reflex.style import Style
 from reflex.utils import imports
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
 from reflex.vars import BaseVar, Var, VarData
 
 
@@ -56,7 +56,7 @@ def component1() -> Type[Component]:
         # A test string/number prop.
         text_or_number: Var[Union[int, str]]
 
-        def _get_imports(self) -> imports.ImportDict:
+        def _get_imports(self) -> ParsedImportDict:
             return {"react": [ImportVar(tag="Component")]}
 
         def _get_custom_code(self) -> str:
@@ -89,7 +89,7 @@ def component2() -> Type[Component]:
                 "on_close": lambda e0: [e0],
             }
 
-        def _get_imports(self) -> imports.ImportDict:
+        def _get_imports(self) -> ParsedImportDict:
             return {"react-redux": [ImportVar(tag="connect")]}
 
         def _get_custom_code(self) -> str:
@@ -1773,21 +1773,15 @@ def test_invalid_event_trigger():
     ),
 )
 def test_component_add_imports(tags):
-    def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
-        return [
-            ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
-            for tag in tags
-        ]
-
     class BaseComponent(Component):
-        def _get_imports(self) -> imports.ImportDict:
+        def _get_imports(self) -> ImportDict:
             return {}
 
     class Reference(Component):
-        def _get_imports(self) -> imports.ImportDict:
+        def _get_imports(self) -> ParsedImportDict:
             return imports.merge_imports(
                 super()._get_imports(),
-                {"react": _list_to_import_vars(tags)},
+                parse_imports({"react": tags}),
                 {"foo": [ImportVar(tag="bar")]},
             )
 
@@ -1806,10 +1800,12 @@ def test_component_add_imports(tags):
     baseline = Reference.create()
     test = Test.create()
 
-    assert baseline._get_all_imports() == {
-        "react": _list_to_import_vars(tags),
-        "foo": [ImportVar(tag="bar")],
-    }
+    assert baseline._get_all_imports() == parse_imports(
+        {
+            "react": tags,
+            "foo": [ImportVar(tag="bar")],
+        }
+    )
     assert test._get_all_imports() == baseline._get_all_imports()
 
 

+ 36 - 1
tests/utils/test_imports.py

@@ -1,6 +1,12 @@
 import pytest
 
-from reflex.utils.imports import ImportVar, merge_imports
+from reflex.utils.imports import (
+    ImportDict,
+    ImportVar,
+    ParsedImportDict,
+    merge_imports,
+    parse_imports,
+)
 
 
 @pytest.mark.parametrize(
@@ -76,3 +82,32 @@ def test_merge_imports(input_1, input_2, output):
 
     for key in output:
         assert set(res[key]) == set(output[key])
+
+
+@pytest.mark.parametrize(
+    "input, output",
+    [
+        ({}, {}),
+        (
+            {"react": "Component"},
+            {"react": [ImportVar(tag="Component")]},
+        ),
+        (
+            {"react": ["Component"]},
+            {"react": [ImportVar(tag="Component")]},
+        ),
+        (
+            {"react": ["Component", ImportVar(tag="useState")]},
+            {"react": [ImportVar(tag="Component"), ImportVar(tag="useState")]},
+        ),
+        (
+            {"react": ["Component"], "foo": "anotherFunction"},
+            {
+                "react": [ImportVar(tag="Component")],
+                "foo": [ImportVar(tag="anotherFunction")],
+            },
+        ),
+    ],
+)
+def test_parse_imports(input: ImportDict, output: ParsedImportDict):
+    assert parse_imports(input) == output