Prechádzať zdrojové kódy

WiP: use ImportList internally instead of ImportDict

* deprecate `_get_imports` in favor of new `_get_imports_list`
* `_get_all_imports` now returns an `ImportList`
* Compiler uses `ImportList.collapse` to get an `ImportDict`
Masen Furer 1 rok pred
rodič
commit
35252464a0

+ 10 - 21
reflex/app.py

@@ -79,7 +79,7 @@ from reflex.state import (
 )
 )
 from reflex.utils import console, exceptions, format, prerequisites, types
 from reflex.utils import console, exceptions, format, prerequisites, types
 from reflex.utils.exec import is_testing_env, should_skip_compile
 from reflex.utils.exec import is_testing_env, should_skip_compile
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList
 
 
 # Define custom types.
 # Define custom types.
 ComponentCallable = Callable[[], Component]
 ComponentCallable = Callable[[], Component]
@@ -618,27 +618,16 @@ class App(Base):
 
 
             admin.mount_to(self.api)
             admin.mount_to(self.api)
 
 
-    def get_frontend_packages(self, imports: Dict[str, set[ImportVar]]):
+    def get_frontend_packages(self, imports: ImportList):
         """Gets the frontend packages to be installed and filters out the unnecessary ones.
         """Gets the frontend packages to be installed and filters out the unnecessary ones.
 
 
         Args:
         Args:
-            imports: A dictionary containing the imports used in the current page.
+            imports: A list containing the imports used in the current page.
 
 
         Example:
         Example:
             >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
             >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
         """
         """
-        page_imports = {
-            i
-            for i, tags in imports.items()
-            if i
-            not in [
-                *constants.PackageJson.DEPENDENCIES.keys(),
-                *constants.PackageJson.DEV_DEPENDENCIES.keys(),
-            ]
-            and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
-            and i != ""
-            and any(tag.install for tag in tags)
-        }
+        page_imports = [i.package for i in imports.collapse().values() if i.install]
         frontend_packages = get_config().frontend_packages
         frontend_packages = get_config().frontend_packages
         _frontend_packages = []
         _frontend_packages = []
         for package in frontend_packages:
         for package in frontend_packages:
@@ -653,7 +642,7 @@ class App(Base):
                 )
                 )
                 continue
                 continue
             _frontend_packages.append(package)
             _frontend_packages.append(package)
-        page_imports.update(_frontend_packages)
+        page_imports.extend(_frontend_packages)
         prerequisites.install_frontend_packages(page_imports, get_config())
         prerequisites.install_frontend_packages(page_imports, get_config())
 
 
     def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
     def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
@@ -794,7 +783,7 @@ class App(Base):
         self.style = evaluate_style_namespaces(self.style)
         self.style = evaluate_style_namespaces(self.style)
 
 
         # Track imports and custom components found.
         # Track imports and custom components found.
-        all_imports = {}
+        all_imports = ImportList()
         custom_components = set()
         custom_components = set()
 
 
         for _route, component in self.pages.items():
         for _route, component in self.pages.items():
@@ -804,7 +793,7 @@ class App(Base):
             component.apply_theme(self.theme)
             component.apply_theme(self.theme)
 
 
             # Add component._get_all_imports() to all_imports.
             # Add component._get_all_imports() to all_imports.
-            all_imports.update(component._get_all_imports())
+            all_imports.extend(component._get_all_imports())
 
 
             # Add the app wrappers from this component.
             # Add the app wrappers from this component.
             app_wrappers.update(component._get_all_app_wrap_components())
             app_wrappers.update(component._get_all_app_wrap_components())
@@ -932,10 +921,10 @@ class App(Base):
                 custom_components_imports,
                 custom_components_imports,
             ) = custom_components_future.result()
             ) = custom_components_future.result()
             compile_results.append(custom_components_result)
             compile_results.append(custom_components_result)
-            all_imports.update(custom_components_imports)
+            all_imports.extend(custom_components_imports)
 
 
         # Get imports from AppWrap components.
         # Get imports from AppWrap components.
-        all_imports.update(app_root._get_all_imports())
+        all_imports.extend(app_root._get_all_imports())
 
 
         progress.advance(task)
         progress.advance(task)
 
 
@@ -951,7 +940,7 @@ class App(Base):
         # Setup the next.config.js
         # Setup the next.config.js
         transpile_packages = [
         transpile_packages = [
             package
             package
-            for package, import_vars in all_imports.items()
+            for package, import_vars in all_imports.collapse().items()
             if any(import_var.transpile for import_var in import_vars)
             if any(import_var.transpile for import_var in import_vars)
         ]
         ]
         prerequisites.update_next_config(
         prerequisites.update_next_config(

+ 25 - 14
reflex/compiler/compiler.py

@@ -19,7 +19,7 @@ from reflex.config import get_config
 from reflex.state import BaseState
 from reflex.state import BaseState
 from reflex.style import LIGHT_COLOR_MODE
 from reflex.style import LIGHT_COLOR_MODE
 from reflex.utils.exec import is_prod_mode
 from reflex.utils.exec import is_prod_mode
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
 from reflex.vars import Var
 from reflex.vars import Var
 
 
 
 
@@ -197,25 +197,34 @@ def _compile_components(
     Returns:
     Returns:
         The compiled components.
         The compiled components.
     """
     """
-    imports = {
-        "react": [ImportVar(tag="memo")],
-        f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
-    }
+    _imports = ImportList(
+        [
+            ImportVar(package="react", tag="memo"),
+            ImportVar(
+                package=f"/{constants.Dirs.STATE_PATH}",
+                tag="E",
+            ),
+            ImportVar(
+                package=f"/{constants.Dirs.STATE_PATH}",
+                tag="isTrue",
+            ),
+        ]
+    )
     component_renders = []
     component_renders = []
 
 
     # Compile each component.
     # Compile each component.
     for component in components:
     for component in components:
         component_render, component_imports = utils.compile_custom_component(component)
         component_render, component_imports = utils.compile_custom_component(component)
         component_renders.append(component_render)
         component_renders.append(component_render)
-        imports = utils.merge_imports(imports, component_imports)
+        _imports.extend(component_imports)
 
 
     # Compile the components page.
     # Compile the components page.
     return (
     return (
         templates.COMPONENTS.render(
         templates.COMPONENTS.render(
-            imports=utils.compile_imports(imports),
+            imports=utils.compile_imports(_imports),
             components=component_renders,
             components=component_renders,
         ),
         ),
-        imports,
+        _imports,
     )
     )
 
 
 
 
@@ -235,7 +244,7 @@ def _compile_stateful_components(
     Returns:
     Returns:
         The rendered stateful components code.
         The rendered stateful components code.
     """
     """
-    all_import_dicts = []
+    all_imports = []
     rendered_components = {}
     rendered_components = {}
 
 
     def get_shared_components_recursive(component: BaseComponent):
     def get_shared_components_recursive(component: BaseComponent):
@@ -266,7 +275,7 @@ def _compile_stateful_components(
             rendered_components.update(
             rendered_components.update(
                 {code: None for code in component._get_all_custom_code()},
                 {code: None for code in component._get_all_custom_code()},
             )
             )
-            all_import_dicts.append(component._get_all_imports())
+            all_imports.extend(component._get_all_imports())
 
 
             # Indicate that this component now imports from the shared file.
             # Indicate that this component now imports from the shared file.
             component.rendered_as_shared = True
             component.rendered_as_shared = True
@@ -275,9 +284,11 @@ def _compile_stateful_components(
         get_shared_components_recursive(page_component)
         get_shared_components_recursive(page_component)
 
 
     # Don't import from the file that we're about to create.
     # Don't import from the file that we're about to create.
-    all_imports = utils.merge_imports(*all_import_dicts)
-    all_imports.pop(
-        f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
+    all_imports = ImportList(
+        imp
+        for imp in all_imports
+        if imp.library
+        != f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}"
     )
     )
 
 
     return templates.STATEFUL_COMPONENTS.render(
     return templates.STATEFUL_COMPONENTS.render(
@@ -408,7 +419,7 @@ def compile_page(
 
 
 def compile_components(
 def compile_components(
     components: set[CustomComponent],
     components: set[CustomComponent],
-) -> tuple[str, str, Dict[str, list[ImportVar]]]:
+) -> tuple[str, str, ImportList]:
     """Compile the custom components.
     """Compile the custom components.
 
 
     Args:
     Args:

+ 13 - 15
reflex/compiler/utils.py

@@ -88,16 +88,16 @@ def validate_imports(import_dict: imports.ImportDict):
                 used_tags[import_name] = lib
                 used_tags[import_name] = lib
 
 
 
 
-def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
-    """Compile an import dict.
+def compile_imports(import_list: imports.ImportList) -> list[dict]:
+    """Compile an import list.
 
 
     Args:
     Args:
-        import_dict: The import dict to compile.
+        import_list: The import list to compile.
 
 
     Returns:
     Returns:
-        The list of import dict.
+        The list of template import dict.
     """
     """
-    collapsed_import_dict = imports.collapse_imports(import_dict)
+    collapsed_import_dict = import_list.collapse()
     validate_imports(collapsed_import_dict)
     validate_imports(collapsed_import_dict)
     import_dicts = []
     import_dicts = []
     for lib, fields in collapsed_import_dict.items():
     for lib, fields in collapsed_import_dict.items():
@@ -114,9 +114,6 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
                 import_dicts.append(get_import_dict(module))
                 import_dicts.append(get_import_dict(module))
             continue
             continue
 
 
-        # remove the version before rendering the package imports
-        lib = format.format_library_name(lib)
-
         import_dicts.append(get_import_dict(lib, default, rest))
         import_dicts.append(get_import_dict(lib, default, rest))
     return import_dicts
     return import_dicts
 
 
@@ -237,7 +234,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
 
 
 def compile_custom_component(
 def compile_custom_component(
     component: CustomComponent,
     component: CustomComponent,
-) -> tuple[dict, imports.ImportDict]:
+) -> tuple[dict, imports.ImportList]:
     """Compile a custom component.
     """Compile a custom component.
 
 
     Args:
     Args:
@@ -250,11 +247,12 @@ def compile_custom_component(
     render = component.get_component(component)
     render = component.get_component(component)
 
 
     # Get the imports.
     # Get the imports.
-    imports = {
-        lib: fields
-        for lib, fields in render._get_all_imports().items()
-        if lib != component.library
-    }
+    component_library_name = format.format_library_name(component.library)
+    _imports = imports.ImportList(
+        imp
+        for imp in render._get_all_imports()
+        if imp.library != component_library_name
+    )
 
 
     # Concatenate the props.
     # Concatenate the props.
     props = [prop._var_name for prop in component.get_prop_vars()]
     props = [prop._var_name for prop in component.get_prop_vars()]
@@ -268,7 +266,7 @@ def compile_custom_component(
             "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
             "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
             "custom_code": render._get_all_custom_code(),
             "custom_code": render._get_all_custom_code(),
         },
         },
-        imports,
+        _imports,
     )
     )
 
 
 
 

+ 7 - 8
reflex/components/chakra/base.py

@@ -35,19 +35,18 @@ class ChakraComponent(Component):
 
 
     @classmethod
     @classmethod
     @lru_cache(maxsize=None)
     @lru_cache(maxsize=None)
-    def _get_dependencies_imports(cls) -> imports.ImportDict:
+    def _get_dependencies_imports(cls) -> imports.ImportList:
         """Get the imports from lib_dependencies for installing.
         """Get the imports from lib_dependencies for installing.
 
 
         Returns:
         Returns:
             The dependencies imports of the component.
             The dependencies imports of the component.
         """
         """
-        return {
-            dep: [imports.ImportVar(tag=None, render=False)]
-            for dep in [
-                "@chakra-ui/system@2.5.7",
-                "framer-motion@10.16.4",
-            ]
-        }
+        return [
+            imports.ImportVar(
+                package="@chakra-ui/system@2.5.7", tag=None, render=False
+            ),
+            imports.ImportVar(package="framer-motion@10.16.4", tag=None, render=False),
+        ]
 
 
 
 
 class ChakraProvider(ChakraComponent):
 class ChakraProvider(ChakraComponent):

+ 100 - 59
reflex/components/component.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import copy
 import copy
+import itertools
 import typing
 import typing
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from functools import lru_cache, wraps
 from functools import lru_cache, wraps
@@ -95,11 +96,11 @@ class BaseComponent(Base, ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def _get_all_imports(self) -> imports.ImportDict:
+    def _get_all_imports(self) -> imports.ImportList:
         """Get all the libraries and fields that are used by the component.
         """Get all the libraries and fields that are used by the component.
 
 
         Returns:
         Returns:
-            The import dict with the required imports.
+            The list of all required ImportVar.
         """
         """
 
 
     @abstractmethod
     @abstractmethod
@@ -994,17 +995,22 @@ class Component(BaseComponent, ABC):
         # Return the dynamic imports
         # Return the dynamic imports
         return dynamic_imports
         return dynamic_imports
 
 
-    def _get_props_imports(self) -> List[str]:
+    def _get_props_imports(self) -> imports.ImportList:
         """Get the imports needed for components props.
         """Get the imports needed for components props.
 
 
         Returns:
         Returns:
-            The  imports for the components props of the component.
+            The imports for the components props of the component.
         """
         """
-        return [
-            getattr(self, prop)._get_all_imports()
-            for prop in self.get_component_props()
-            if getattr(self, prop) is not None
-        ]
+        return imports.ImportList(
+            sum(
+                (
+                    getattr(self, prop)._get_all_imports()
+                    for prop in self.get_component_props()
+                    if getattr(self, prop) is not None
+                ),
+                [],
+            )
+        )
 
 
     def _should_transpile(self, dep: str | None) -> bool:
     def _should_transpile(self, dep: str | None) -> bool:
         """Check if a dependency should be transpiled.
         """Check if a dependency should be transpiled.
@@ -1020,97 +1026,129 @@ class Component(BaseComponent, ABC):
             or format.format_library_name(dep or "") in self.transpile_packages
             or format.format_library_name(dep or "") in self.transpile_packages
         )
         )
 
 
-    def _get_dependencies_imports(self) -> imports.ImportDict:
+    def _get_dependencies_imports(self) -> imports.ImportList:
         """Get the imports from lib_dependencies for installing.
         """Get the imports from lib_dependencies for installing.
 
 
         Returns:
         Returns:
             The dependencies imports of the component.
             The dependencies imports of the component.
         """
         """
-        return {
-            dep: [
-                ImportVar(
-                    tag=None,
-                    render=False,
-                    transpile=self._should_transpile(dep),
-                )
-            ]
+        return imports.ImportList(
+            ImportVar(
+                package=dep,
+                tag=None,
+                render=False,
+                transpile=self._should_transpile(dep),
+            )
             for dep in self.lib_dependencies
             for dep in self.lib_dependencies
-        }
+        )
 
 
-    def _get_hooks_imports(self) -> imports.ImportDict:
+    def _get_hooks_imports(self) -> imports.ImportList:
         """Get the imports required by certain hooks.
         """Get the imports required by certain hooks.
 
 
         Returns:
         Returns:
             The imports required for all selected hooks.
             The imports required for all selected hooks.
         """
         """
-        _imports = {}
+        _imports = imports.ImportList()
 
 
         if self._get_ref_hook():
         if self._get_ref_hook():
             # Handle hooks needed for attaching react refs to DOM nodes.
             # Handle hooks needed for attaching react refs to DOM nodes.
-            _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
-            _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
+            _imports.extend(
+                [
+                    ImportVar(package="react", tag="useRef"),
+                    ImportVar(package=f"/{Dirs.STATE_PATH}", tag="refs"),
+                ]
+            )
 
 
         if self._get_mount_lifecycle_hook():
         if self._get_mount_lifecycle_hook():
             # Handle hooks for `on_mount` / `on_unmount`.
             # Handle hooks for `on_mount` / `on_unmount`.
-            _imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
+            _imports.append(ImportVar(package="react", tag="useEffect"))
 
 
         if self._get_special_hooks():
         if self._get_special_hooks():
             # Handle additional internal hooks (autofocus, etc).
             # Handle additional internal hooks (autofocus, etc).
-            _imports.setdefault("react", set()).update(
-                {
-                    ImportVar(tag="useRef"),
-                    ImportVar(tag="useEffect"),
-                },
+            _imports.extend(
+                [
+                    ImportVar(package="react", tag="useEffect"),
+                    ImportVar(package="react", tag="useRef"),
+                ]
             )
             )
 
 
         user_hooks = self._get_hooks()
         user_hooks = self._get_hooks()
         if user_hooks is not None and isinstance(user_hooks, Var):
         if user_hooks is not None and isinstance(user_hooks, Var):
-            _imports = imports.merge_imports(_imports, user_hooks._var_data.imports)  # type: ignore
+            _imports.extend(user_hooks._var_data.imports)
 
 
         return _imports
         return _imports
 
 
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
-        """Get all the libraries and fields that are used by the component.
+        """Deprecated method to get all the libraries and fields used by the component.
 
 
         Returns:
         Returns:
             The imports needed by the component.
             The imports needed by the component.
         """
         """
-        _imports = {}
+        return {}
+
+    def _get_imports_list(self) -> imports.ImportList:
+        """Internal method to get the imports as a list.
+
+        Returns:
+            The imports as a list.
+        """
+        _imports = imports.ImportList(
+            itertools.chain(
+                self._get_props_imports(),
+                self._get_dependencies_imports(),
+                self._get_hooks_imports(),
+            )
+        )
+
+        # Handle deprecated _get_imports
+        import_dict = self._get_imports()
+        if import_dict:
+            console.deprecate(
+                feature_name="_get_imports",
+                reason="use add_imports instead",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+            _imports.extend(imports.ImportList.from_import_dict(import_dict))
 
 
         # Import this component's tag from the main library.
         # Import this component's tag from the main library.
         if self.library is not None and self.tag is not None:
         if self.library is not None and self.tag is not None:
-            _imports[self.library] = {self.import_var}
+            _imports.append(self.import_var)
 
 
         # Get static imports required for event processing.
         # Get static imports required for event processing.
-        event_imports = Imports.EVENTS if self.event_triggers else {}
+        if self.event_triggers:
+            _imports.append(Imports.EVENTS)
 
 
         # Collect imports from Vars used directly by this component.
         # Collect imports from Vars used directly by this component.
-        var_imports = [
-            var._var_data.imports for var in self._get_vars() if var._var_data
-        ]
-
-        return imports.merge_imports(
-            *self._get_props_imports(),
-            self._get_dependencies_imports(),
-            self._get_hooks_imports(),
-            _imports,
-            event_imports,
-            *var_imports,
-        )
+        for var in self._get_vars():
+            if var._var_data:
+                _imports.extend(var._var_data.imports)
+        return _imports
 
 
-    def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
+    def _get_all_imports(self, collapse: bool = False) -> imports.ImportList:
         """Get all the libraries and fields that are used by the component and its children.
         """Get all the libraries and fields that are used by the component and its children.
 
 
         Args:
         Args:
-            collapse: Whether to collapse the imports by removing duplicates.
+            collapse: Whether to collapse the imports into a dict (deprecated).
 
 
         Returns:
         Returns:
-            The import dict with the required imports.
+            The list of all required imports.
         """
         """
-        _imports = imports.merge_imports(
-            self._get_imports(), *[child._get_all_imports() for child in self.children]
+        _imports = imports.ImportList(
+            self._get_imports_list()
+            + sum((child._get_all_imports() for child in self.children), [])
         )
         )
-        return imports.collapse_imports(_imports) if collapse else _imports
+
+        if collapse:
+            console.deprecate(
+                feature_name="collapse kwarg to _get_all_imports",
+                reason="use ImportList.collapse instead",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+            return _imports.collapse()  # type: ignore
+
+        return _imports
 
 
     def _get_mount_lifecycle_hook(self) -> str | None:
     def _get_mount_lifecycle_hook(self) -> str | None:
         """Generate the component lifecycle hook.
         """Generate the component lifecycle hook.
@@ -1296,6 +1334,7 @@ class Component(BaseComponent, ABC):
         tag = self.tag.partition(".")[0] if self.tag else None
         tag = self.tag.partition(".")[0] if self.tag else None
         alias = self.alias.partition(".")[0] if self.alias else None
         alias = self.alias.partition(".")[0] if self.alias else None
         return ImportVar(
         return ImportVar(
+            package=self.library,
             tag=tag,
             tag=tag,
             is_default=self.is_default,
             is_default=self.is_default,
             alias=alias,
             alias=alias,
@@ -1575,7 +1614,6 @@ class NoSSRComponent(Component):
         return imports.merge_imports(
         return imports.merge_imports(
             dynamic_import,
             dynamic_import,
             _imports,
             _imports,
-            self._get_dependencies_imports(),
         )
         )
 
 
     def _get_dynamic_imports(self) -> str:
     def _get_dynamic_imports(self) -> str:
@@ -1893,18 +1931,21 @@ class StatefulComponent(BaseComponent):
         """
         """
         return {}
         return {}
 
 
-    def _get_all_imports(self) -> imports.ImportDict:
+    def _get_all_imports(self) -> imports.ImportList:
         """Get all the libraries and fields that are used by the component.
         """Get all the libraries and fields that are used by the component.
 
 
         Returns:
         Returns:
-            The import dict with the required imports.
+            The list of all required imports.
         """
         """
         if self.rendered_as_shared:
         if self.rendered_as_shared:
-            return {
-                f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
-                    ImportVar(tag=self.tag)
+            return imports.ImportList(
+                [
+                    imports.ImportVar(
+                        package=f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}",
+                        tag=self.tag,
+                    )
                 ]
                 ]
-            }
+            )
         return self.component._get_all_imports()
         return self.component._get_all_imports()
 
 
     def _get_all_dynamic_imports(self) -> set[str]:
     def _get_all_dynamic_imports(self) -> set[str]:

+ 10 - 8
reflex/components/core/cond.py

@@ -12,9 +12,9 @@ from reflex.style import LIGHT_COLOR_MODE, color_mode
 from reflex.utils import format, imports
 from reflex.utils import format, imports
 from reflex.vars import BaseVar, Var, VarData
 from reflex.vars import BaseVar, Var, VarData
 
 
-_IS_TRUE_IMPORT = {
-    f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
-}
+_IS_TRUE_IMPORT = imports.ImportList(
+    [imports.ImportVar(library=f"/{Dirs.STATE_PATH}", tag="isTrue")]
+)
 
 
 
 
 class Cond(MemoizationLeaf):
 class Cond(MemoizationLeaf):
@@ -95,11 +95,13 @@ class Cond(MemoizationLeaf):
             cond_state=f"isTrue({self.cond._var_full_name})",
             cond_state=f"isTrue({self.cond._var_full_name})",
         )
         )
 
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            getattr(self.cond._var_data, "imports", {}),
-            _IS_TRUE_IMPORT,
+    def _get_imports_list(self) -> imports.ImportList:
+        return imports.ImportList(
+            [
+                *super()._get_imports_list(),
+                *getattr(self.cond._var_data, "imports", []),
+                *_IS_TRUE_IMPORT,
+            ]
         )
         )
 
 
     def _apply_theme(self, theme: Component):
     def _apply_theme(self, theme: Component):

+ 6 - 6
reflex/constants/compiler.py

@@ -6,7 +6,7 @@ from types import SimpleNamespace
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.constants import Dirs
 from reflex.constants import Dirs
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
 
 
 # The prefix used to create setters for state vars.
 # The prefix used to create setters for state vars.
 SETTER_PREFIX = "set_"
 SETTER_PREFIX = "set_"
@@ -102,11 +102,11 @@ class ComponentName(Enum):
 class Imports(SimpleNamespace):
 class Imports(SimpleNamespace):
     """Common sets of import vars."""
     """Common sets of import vars."""
 
 
-    EVENTS = {
-        "react": {ImportVar(tag="useContext")},
-        f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
-        f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
-    }
+    EVENTS: ImportList = [
+        ImportVar(package="react", tag="useContext"),
+        ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
+        ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
+    ]
 
 
 
 
 class Hooks(SimpleNamespace):
 class Hooks(SimpleNamespace):

+ 2 - 5
reflex/utils/format.py

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 
 from reflex import constants
 from reflex import constants
 from reflex.utils import exceptions, serializers, types
 from reflex.utils import exceptions, serializers, types
+from reflex.utils.imports import split_library_name_version
 from reflex.utils.serializers import serialize
 from reflex.utils.serializers import serialize
 from reflex.vars import BaseVar, Var
 from reflex.vars import BaseVar, Var
 
 
@@ -716,11 +717,7 @@ def format_library_name(library_fullname: str):
     Returns:
     Returns:
         The name without the @version if it was part of the name
         The name without the @version if it was part of the name
     """
     """
-    lib, at, version = library_fullname.rpartition("@")
-    if not lib:
-        lib = at + version
-
-    return lib
+    return split_library_name_version(library_fullname)[0]
 
 
 
 
 def json_dumps(obj: Any) -> str:
 def json_dumps(obj: Any) -> str:

+ 154 - 5
reflex/utils/imports.py

@@ -3,9 +3,10 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 from collections import defaultdict
 from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set
 
 
 from reflex.base import Base
 from reflex.base import Base
+from reflex.constants.installer import PackageJson
 
 
 
 
 def merge_imports(*imports) -> ImportDict:
 def merge_imports(*imports) -> ImportDict:
@@ -36,9 +37,29 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
     return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
     return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
 
 
 
 
+def split_library_name_version(library_fullname: str):
+    """Split the name of a library from its version.
+
+    Args:
+        library_fullname: The fullname of the library.
+
+    Returns:
+        A tuple of the library name and version.
+    """
+    lib, at, version = library_fullname.rpartition("@")
+    if not lib:
+        lib = at + version
+        version = None
+
+    return lib, version
+
+
 class ImportVar(Base):
 class ImportVar(Base):
     """An import var."""
     """An import var."""
 
 
+    # The package name associated with the tag
+    library: Optional[str]
+
     # The name of the import tag.
     # The name of the import tag.
     tag: Optional[str]
     tag: Optional[str]
 
 
@@ -48,6 +69,12 @@ class ImportVar(Base):
     # The tag alias.
     # The tag alias.
     alias: Optional[str] = None
     alias: Optional[str] = None
 
 
+    # The following fields provide extra information about the import,
+    # but are not factored in when considering hash or equality
+
+    # The version of the package
+    version: Optional[str]
+
     # Whether this import need to install the associated lib
     # Whether this import need to install the associated lib
     install: Optional[bool] = True
     install: Optional[bool] = True
 
 
@@ -58,6 +85,34 @@ class ImportVar(Base):
     # https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages
     # https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages
     transpile: Optional[bool] = False
     transpile: Optional[bool] = False
 
 
+    def __init__(
+        self,
+        *,
+        package: Optional[str] = None,
+        **kwargs,
+    ):
+        if package is not None:
+            if (
+                kwargs.get("library", None) is not None
+                or kwargs.get("version", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot provide 'library' or 'version' as keyword arguments when "
+                    "specifying 'package' as an argument"
+                )
+            kwargs["library"], kwargs["version"] = split_library_name_version(package)
+
+        install = (
+            package is not None
+            # TODO: handle version conflicts
+            and package not in PackageJson.DEPENDENCIES
+            and package not in PackageJson.DEV_DEPENDENCIES
+            and not any(package.startswith(prefix) for prefix in ["/", ".", "next/"])
+            and package != ""
+        )
+        kwargs.setdefault("install", install)
+        super().__init__(**kwargs)
+
     @property
     @property
     def name(self) -> str:
     def name(self) -> str:
         """The name of the import.
         """The name of the import.
@@ -72,6 +127,17 @@ class ImportVar(Base):
         else:
         else:
             return self.tag or ""
             return self.tag or ""
 
 
+    @property
+    def package(self) -> str:
+        """The package to install for this import
+
+        Returns:
+            The library name and (optional) version to be installed by npm/bun.
+        """
+        if self.version:
+            return f"{self.library}@{self.version}"
+        return self.library
+
     def __hash__(self) -> int:
     def __hash__(self) -> int:
         """Define a hash function for the import var.
         """Define a hash function for the import var.
 
 
@@ -80,14 +146,97 @@ class ImportVar(Base):
         """
         """
         return hash(
         return hash(
             (
             (
+                self.library,
                 self.tag,
                 self.tag,
                 self.is_default,
                 self.is_default,
                 self.alias,
                 self.alias,
-                self.install,
-                self.render,
-                self.transpile,
+                # These do not fundamentally change the import in any way
+                # self.install,
+                # self.render,
+                # self.transpile,
             )
             )
         )
         )
 
 
+    def __eq__(self, other: ImportVar) -> bool:
+        """Define equality for the import var.
+
+        Args:
+            other: The other import var to compare.
+
+        Returns:
+            Whether the two import vars are equal.
+        """
+        if type(self) != type(other):
+            return NotImplemented
+        return (self.library, self.tag, self.is_default, self.alias) == (
+            other.library,
+            other.tag,
+            other.is_default,
+            other.alias,
+        )
+
+    def collapse(self, other_import_var: ImportVar) -> ImportVar:
+        """Collapse two import vars together.
+
+        Args:
+            other_import_var: The other import var to collapse with.
+
+        Returns:
+            The collapsed import var with sticky props perserved.
+        """
+        if self != other_import_var:
+            raise ValueError("Cannot collapse two import vars with different hashes")
+
+        if self.version is not None and other_import_var.version is not None:
+            if self.version != other_import_var.version:
+                raise ValueError(
+                    "Cannot collapse two import vars with conflicting version specifiers: "
+                    f"{self} {other_import_var}"
+                )
+
+        return type(self)(
+            library=self.library,
+            version=self.version or other_import_var.version,
+            tag=self.tag,
+            is_default=self.is_default,
+            alias=self.alias,
+            install=self.install or other_import_var.install,
+            render=self.render or other_import_var.render,
+            transpile=self.transpile or other_import_var.transpile,
+        )
+
 
 
-ImportDict = Dict[str, List[ImportVar]]
+class ImportList(List[ImportVar]):
+    """A list of import vars."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        for ix, value in enumerate(self):
+            if not isinstance(value, ImportVar):
+                # convert dicts to ImportVar
+                self[ix] = ImportVar(**value)
+
+    @classmethod
+    def from_import_dict(cls, import_dict: ImportDict) -> ImportList:
+        return [
+            ImportVar(package=lib, **imp.dict())
+            for lib, imps in import_dict.items()
+            for imp in imps
+        ]
+
+    def collapse(self) -> ImportDict:
+        """When collapsing an import list, prefer packages with version specifiers."""
+        collapsed = {}
+        for imp in self:
+            collapsed.setdefault(imp.library, {})
+            if imp in collapsed[imp.library]:
+                # Need to check if the current import has any special properties that need to
+                # be preserved, like the version specifier, install, or transpile.
+                existing_imp = collapsed[imp.library][imp]
+                collapsed[imp.library][imp] = existing_imp.collapse(imp)
+            else:
+                collapsed[imp.library][imp] = imp
+        return {lib: set(imps) for lib, imps in collapsed.items()}
+
+
+ImportDict = Dict[str, Set[ImportVar]]

+ 27 - 10
reflex/vars.py

@@ -37,7 +37,7 @@ from reflex.base import Base
 from reflex.utils import console, format, imports, serializers, types
 from reflex.utils import console, format, imports, serializers, types
 
 
 # This module used to export ImportVar itself, so we still import it for export here
 # This module used to export ImportVar itself, so we still import it for export here
-from reflex.utils.imports import ImportDict, ImportVar
+from reflex.utils.imports import ImportDict, ImportList, ImportVar
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from reflex.state import BaseState
     from reflex.state import BaseState
@@ -116,7 +116,7 @@ class VarData(Base):
     state: str = ""
     state: str = ""
 
 
     # Imports needed to render this var
     # Imports needed to render this var
-    imports: ImportDict = {}
+    imports: ImportList = []
 
 
     # Hooks that need to be present in the component to render this var
     # Hooks that need to be present in the component to render this var
     hooks: Dict[str, None] = {}
     hooks: Dict[str, None] = {}
@@ -126,6 +126,19 @@ class VarData(Base):
     # segments.
     # segments.
     interpolations: List[Tuple[int, int]] = []
     interpolations: List[Tuple[int, int]] = []
 
 
+    def __init__(self, imports: ImportDict | ImportList = None, **kwargs):
+        if isinstance(imports, dict):
+            imports = ImportList.from_import_dict(imports)
+            console.deprecate(
+                feature_name="Passing ImportDict for VarData",
+                reason="use ImportList instead",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+        elif imports is None:
+            imports = []
+        super().__init__(imports=imports, **kwargs)
+
     @classmethod
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None:
     def merge(cls, *others: VarData | None) -> VarData | None:
         """Merge multiple var data objects.
         """Merge multiple var data objects.
@@ -137,14 +150,14 @@ class VarData(Base):
             The merged var data object.
             The merged var data object.
         """
         """
         state = ""
         state = ""
-        _imports = {}
+        _imports = []
         hooks = {}
         hooks = {}
         interpolations = []
         interpolations = []
         for var_data in others:
         for var_data in others:
             if var_data is None:
             if var_data is None:
                 continue
                 continue
             state = state or var_data.state
             state = state or var_data.state
-            _imports = imports.merge_imports(_imports, var_data.imports)
+            _imports.extend(var_data.imports)
             hooks.update(var_data.hooks)
             hooks.update(var_data.hooks)
             interpolations += var_data.interpolations
             interpolations += var_data.interpolations
 
 
@@ -180,11 +193,18 @@ class VarData(Base):
 
 
         # Don't compare interpolations - that's added in by the decoder, and
         # Don't compare interpolations - that's added in by the decoder, and
         # not part of the vardata itself.
         # not part of the vardata itself.
+        if not isinstance(self.imports, ImportList):
+            self_imports = ImportList(self.imports).collapse()
+        else:
+            self_imports = self.imports.collapse()
+        if not isinstance(other.imports, ImportList):
+            other_imports = ImportList(other.imports).collapse()
+        else:
+            other_imports = other.imports.collapse()
         return (
         return (
             self.state == other.state
             self.state == other.state
             and self.hooks.keys() == other.hooks.keys()
             and self.hooks.keys() == other.hooks.keys()
-            and imports.collapse_imports(self.imports)
-            == imports.collapse_imports(other.imports)
+            and self_imports == other_imports
         )
         )
 
 
     def dict(self) -> dict:
     def dict(self) -> dict:
@@ -196,10 +216,7 @@ class VarData(Base):
         return {
         return {
             "state": self.state,
             "state": self.state,
             "interpolations": list(self.interpolations),
             "interpolations": list(self.interpolations),
-            "imports": {
-                lib: [import_var.dict() for import_var in import_vars]
-                for lib, import_vars in self.imports.items()
-            },
+            "imports": [import_var.dict() for import_var in self.imports],
             "hooks": self.hooks,
             "hooks": self.hooks,
         }
         }
 
 

+ 34 - 22
tests/compiler/test_compiler.py

@@ -4,8 +4,7 @@ from typing import List
 import pytest
 import pytest
 
 
 from reflex.compiler import compiler, utils
 from reflex.compiler import compiler, utils
-from reflex.utils import imports
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -48,43 +47,56 @@ def test_compile_import_statement(
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
-    "import_dict,test_dicts",
+    "import_list,test_dicts",
     [
     [
-        ({}, []),
+        (ImportList(), []),
         (
         (
-            {"axios": [ImportVar(tag="axios", is_default=True)]},
+            ImportList([ImportVar(library="axios", tag="axios", is_default=True)]),
             [{"lib": "axios", "default": "axios", "rest": []}],
             [{"lib": "axios", "default": "axios", "rest": []}],
         ),
         ),
         (
         (
-            {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
+            ImportList(
+                [
+                    ImportVar(library="axios", tag="foo"),
+                    ImportVar(library="axios", tag="bar"),
+                ]
+            ),
             [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
             [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
         ),
         ),
         (
         (
-            {
-                "axios": [
-                    ImportVar(tag="axios", is_default=True),
-                    ImportVar(tag="foo"),
-                    ImportVar(tag="bar"),
-                ],
-                "react": [ImportVar(tag="react", is_default=True)],
-            },
+            ImportList(
+                [
+                    ImportVar(library="axios", tag="axios", is_default=True),
+                    ImportVar(library="axios", tag="foo"),
+                    ImportVar(library="axios", tag="bar"),
+                    ImportVar(library="react", tag="react", is_default=True),
+                ]
+            ),
             [
             [
                 {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
                 {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
                 {"lib": "react", "default": "react", "rest": []},
                 {"lib": "react", "default": "react", "rest": []},
             ],
             ],
         ),
         ),
         (
         (
-            {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
+            ImportList(
+                [
+                    ImportVar(library="", tag="lib1.js"),
+                    ImportVar(library="", tag="lib2.js"),
+                ]
+            ),
             [
             [
                 {"lib": "lib1.js", "default": "", "rest": []},
                 {"lib": "lib1.js", "default": "", "rest": []},
                 {"lib": "lib2.js", "default": "", "rest": []},
                 {"lib": "lib2.js", "default": "", "rest": []},
             ],
             ],
         ),
         ),
         (
         (
-            {
-                "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
-                "axios": [ImportVar(tag="axios", is_default=True)],
-            },
+            ImportList(
+                [
+                    ImportVar(library="", tag="lib1.js"),
+                    ImportVar(library="", tag="lib2.js"),
+                    ImportVar(library="axios", tag="axios", is_default=True),
+                ]
+            ),
             [
             [
                 {"lib": "lib1.js", "default": "", "rest": []},
                 {"lib": "lib1.js", "default": "", "rest": []},
                 {"lib": "lib2.js", "default": "", "rest": []},
                 {"lib": "lib2.js", "default": "", "rest": []},
@@ -93,14 +105,14 @@ def test_compile_import_statement(
         ),
         ),
     ],
     ],
 )
 )
-def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]):
+def test_compile_imports(import_list: ImportList, test_dicts: List[dict]):
     """Test the compile_imports function.
     """Test the compile_imports function.
 
 
     Args:
     Args:
-        import_dict: The import dictionary.
+        import_list: The list of ImportVar.
         test_dicts: The expected output.
         test_dicts: The expected output.
     """
     """
-    imports = utils.compile_imports(import_dict)
+    imports = utils.compile_imports(import_list)
     for import_dict, test_dict in zip(imports, test_dicts):
     for import_dict, test_dict in zip(imports, test_dicts):
         assert import_dict["lib"] == test_dict["lib"]
         assert import_dict["lib"] == test_dict["lib"]
         assert import_dict["default"] == test_dict["default"]
         assert import_dict["default"] == test_dict["default"]

+ 2 - 2
tests/components/core/test_banner.py

@@ -20,7 +20,7 @@ def test_connection_banner():
         "react",
         "react",
         "/utils/context",
         "/utils/context",
         "/utils/state",
         "/utils/state",
-        "@radix-ui/themes@^3.0.0",
+        "@radix-ui/themes",
         "/env.json",
         "/env.json",
     ]
     ]
 
 
@@ -36,7 +36,7 @@ def test_connection_modal():
         "react",
         "react",
         "/utils/context",
         "/utils/context",
         "/utils/state",
         "/utils/state",
-        "@radix-ui/themes@^3.0.0",
+        "@radix-ui/themes",
         "/env.json",
         "/env.json",
     ]
     ]
 
 

+ 13 - 11
tests/components/test_component.py

@@ -296,11 +296,11 @@ def test_get_imports(component1, component2):
     """
     """
     c1 = component1.create()
     c1 = component1.create()
     c2 = component2.create(c1)
     c2 = component2.create(c1)
-    assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]}
-    assert c2._get_all_imports() == {
-        "react-redux": [ImportVar(tag="connect")],
-        "react": [ImportVar(tag="Component")],
-    }
+    assert c1._get_all_imports() == [ImportVar(library="react", tag="Component")]
+    assert c2._get_all_imports() == [
+        ImportVar(library="react-redux", tag="connect"),
+        ImportVar(library="react", tag="Component"),
+    ]
 
 
 
 
 def test_get_custom_code(component1, component2):
 def test_get_custom_code(component1, component2):
@@ -1514,22 +1514,24 @@ def test_custom_component_get_imports():
     custom_comp = wrapper()
     custom_comp = wrapper()
 
 
     # Inner is not imported directly, but it is imported by the custom component.
     # Inner is not imported directly, but it is imported by the custom component.
-    assert "inner" not in custom_comp._get_all_imports()
+    inner_import = ImportVar(library="inner", tag="Inner")
+    assert inner_import not in custom_comp._get_all_imports()
 
 
     # The imports are only resolved during compilation.
     # The imports are only resolved during compilation.
     _, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
     _, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
-    assert "inner" in imports_inner
+    assert inner_import in imports_inner
 
 
     outer_comp = outer(c=wrapper())
     outer_comp = outer(c=wrapper())
 
 
     # Libraries are not imported directly, but are imported by the custom component.
     # Libraries are not imported directly, but are imported by the custom component.
-    assert "inner" not in outer_comp._get_all_imports()
-    assert "other" not in outer_comp._get_all_imports()
+    other_import = ImportVar(library="other", tag="Other")
+    assert inner_import not in outer_comp._get_all_imports()
+    assert other_import not in outer_comp._get_all_imports()
 
 
     # The imports are only resolved during compilation.
     # The imports are only resolved during compilation.
     _, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
     _, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
-    assert "inner" in imports_outer
-    assert "other" in imports_outer
+    assert inner_import in imports_outer
+    assert other_import in imports_outer
 
 
 
 
 def test_custom_component_declare_event_handlers_in_fields():
 def test_custom_component_declare_event_handlers_in_fields():

+ 1 - 1
tests/test_var.py

@@ -837,7 +837,7 @@ def test_state_with_initial_computed_var(
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
         (
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
-            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
+            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": [{"library": "/utils/context", "tag": "StateContexts", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}, {"library": "react", "tag": "useContext", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}], "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
         ),
         ),
         (
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",