Преглед изворни кода

WiP: use ImportList internally instead of ImportDict

* deprecate `_get_imports` in favor of new `_get_imports_list`
* `_get_all_imports` now returns an `ImportList`
* Compiler uses `ImportList.collapse` to get an `ImportDict`
Masen Furer пре 1 година
родитељ
комит
35252464a0

+ 10 - 21
reflex/app.py

@@ -79,7 +79,7 @@ from reflex.state import (
 )
 from reflex.utils import console, exceptions, format, prerequisites, types
 from reflex.utils.exec import is_testing_env, should_skip_compile
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList
 
 # Define custom types.
 ComponentCallable = Callable[[], Component]
@@ -618,27 +618,16 @@ class App(Base):
 
             admin.mount_to(self.api)
 
-    def get_frontend_packages(self, imports: Dict[str, set[ImportVar]]):
+    def get_frontend_packages(self, imports: ImportList):
         """Gets the frontend packages to be installed and filters out the unnecessary ones.
 
         Args:
-            imports: A dictionary containing the imports used in the current page.
+            imports: A list containing the imports used in the current page.
 
         Example:
             >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
         """
-        page_imports = {
-            i
-            for i, tags in imports.items()
-            if i
-            not in [
-                *constants.PackageJson.DEPENDENCIES.keys(),
-                *constants.PackageJson.DEV_DEPENDENCIES.keys(),
-            ]
-            and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
-            and i != ""
-            and any(tag.install for tag in tags)
-        }
+        page_imports = [i.package for i in imports.collapse().values() if i.install]
         frontend_packages = get_config().frontend_packages
         _frontend_packages = []
         for package in frontend_packages:
@@ -653,7 +642,7 @@ class App(Base):
                 )
                 continue
             _frontend_packages.append(package)
-        page_imports.update(_frontend_packages)
+        page_imports.extend(_frontend_packages)
         prerequisites.install_frontend_packages(page_imports, get_config())
 
     def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
@@ -794,7 +783,7 @@ class App(Base):
         self.style = evaluate_style_namespaces(self.style)
 
         # Track imports and custom components found.
-        all_imports = {}
+        all_imports = ImportList()
         custom_components = set()
 
         for _route, component in self.pages.items():
@@ -804,7 +793,7 @@ class App(Base):
             component.apply_theme(self.theme)
 
             # Add component._get_all_imports() to all_imports.
-            all_imports.update(component._get_all_imports())
+            all_imports.extend(component._get_all_imports())
 
             # Add the app wrappers from this component.
             app_wrappers.update(component._get_all_app_wrap_components())
@@ -932,10 +921,10 @@ class App(Base):
                 custom_components_imports,
             ) = custom_components_future.result()
             compile_results.append(custom_components_result)
-            all_imports.update(custom_components_imports)
+            all_imports.extend(custom_components_imports)
 
         # Get imports from AppWrap components.
-        all_imports.update(app_root._get_all_imports())
+        all_imports.extend(app_root._get_all_imports())
 
         progress.advance(task)
 
@@ -951,7 +940,7 @@ class App(Base):
         # Setup the next.config.js
         transpile_packages = [
             package
-            for package, import_vars in all_imports.items()
+            for package, import_vars in all_imports.collapse().items()
             if any(import_var.transpile for import_var in import_vars)
         ]
         prerequisites.update_next_config(

+ 25 - 14
reflex/compiler/compiler.py

@@ -19,7 +19,7 @@ from reflex.config import get_config
 from reflex.state import BaseState
 from reflex.style import LIGHT_COLOR_MODE
 from reflex.utils.exec import is_prod_mode
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
 from reflex.vars import Var
 
 
@@ -197,25 +197,34 @@ def _compile_components(
     Returns:
         The compiled components.
     """
-    imports = {
-        "react": [ImportVar(tag="memo")],
-        f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
-    }
+    _imports = ImportList(
+        [
+            ImportVar(package="react", tag="memo"),
+            ImportVar(
+                package=f"/{constants.Dirs.STATE_PATH}",
+                tag="E",
+            ),
+            ImportVar(
+                package=f"/{constants.Dirs.STATE_PATH}",
+                tag="isTrue",
+            ),
+        ]
+    )
     component_renders = []
 
     # Compile each component.
     for component in components:
         component_render, component_imports = utils.compile_custom_component(component)
         component_renders.append(component_render)
-        imports = utils.merge_imports(imports, component_imports)
+        _imports.extend(component_imports)
 
     # Compile the components page.
     return (
         templates.COMPONENTS.render(
-            imports=utils.compile_imports(imports),
+            imports=utils.compile_imports(_imports),
             components=component_renders,
         ),
-        imports,
+        _imports,
     )
 
 
@@ -235,7 +244,7 @@ def _compile_stateful_components(
     Returns:
         The rendered stateful components code.
     """
-    all_import_dicts = []
+    all_imports = []
     rendered_components = {}
 
     def get_shared_components_recursive(component: BaseComponent):
@@ -266,7 +275,7 @@ def _compile_stateful_components(
             rendered_components.update(
                 {code: None for code in component._get_all_custom_code()},
             )
-            all_import_dicts.append(component._get_all_imports())
+            all_imports.extend(component._get_all_imports())
 
             # Indicate that this component now imports from the shared file.
             component.rendered_as_shared = True
@@ -275,9 +284,11 @@ def _compile_stateful_components(
         get_shared_components_recursive(page_component)
 
     # Don't import from the file that we're about to create.
-    all_imports = utils.merge_imports(*all_import_dicts)
-    all_imports.pop(
-        f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
+    all_imports = ImportList(
+        imp
+        for imp in all_imports
+        if imp.library
+        != f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}"
     )
 
     return templates.STATEFUL_COMPONENTS.render(
@@ -408,7 +419,7 @@ def compile_page(
 
 def compile_components(
     components: set[CustomComponent],
-) -> tuple[str, str, Dict[str, list[ImportVar]]]:
+) -> tuple[str, str, ImportList]:
     """Compile the custom components.
 
     Args:

+ 13 - 15
reflex/compiler/utils.py

@@ -88,16 +88,16 @@ def validate_imports(import_dict: imports.ImportDict):
                 used_tags[import_name] = lib
 
 
-def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
-    """Compile an import dict.
+def compile_imports(import_list: imports.ImportList) -> list[dict]:
+    """Compile an import list.
 
     Args:
-        import_dict: The import dict to compile.
+        import_list: The import list to compile.
 
     Returns:
-        The list of import dict.
+        The list of template import dict.
     """
-    collapsed_import_dict = imports.collapse_imports(import_dict)
+    collapsed_import_dict = import_list.collapse()
     validate_imports(collapsed_import_dict)
     import_dicts = []
     for lib, fields in collapsed_import_dict.items():
@@ -114,9 +114,6 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
                 import_dicts.append(get_import_dict(module))
             continue
 
-        # remove the version before rendering the package imports
-        lib = format.format_library_name(lib)
-
         import_dicts.append(get_import_dict(lib, default, rest))
     return import_dicts
 
@@ -237,7 +234,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
 
 def compile_custom_component(
     component: CustomComponent,
-) -> tuple[dict, imports.ImportDict]:
+) -> tuple[dict, imports.ImportList]:
     """Compile a custom component.
 
     Args:
@@ -250,11 +247,12 @@ def compile_custom_component(
     render = component.get_component(component)
 
     # Get the imports.
-    imports = {
-        lib: fields
-        for lib, fields in render._get_all_imports().items()
-        if lib != component.library
-    }
+    component_library_name = format.format_library_name(component.library)
+    _imports = imports.ImportList(
+        imp
+        for imp in render._get_all_imports()
+        if imp.library != component_library_name
+    )
 
     # Concatenate the props.
     props = [prop._var_name for prop in component.get_prop_vars()]
@@ -268,7 +266,7 @@ def compile_custom_component(
             "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
             "custom_code": render._get_all_custom_code(),
         },
-        imports,
+        _imports,
     )
 
 

+ 7 - 8
reflex/components/chakra/base.py

@@ -35,19 +35,18 @@ class ChakraComponent(Component):
 
     @classmethod
     @lru_cache(maxsize=None)
-    def _get_dependencies_imports(cls) -> imports.ImportDict:
+    def _get_dependencies_imports(cls) -> imports.ImportList:
         """Get the imports from lib_dependencies for installing.
 
         Returns:
             The dependencies imports of the component.
         """
-        return {
-            dep: [imports.ImportVar(tag=None, render=False)]
-            for dep in [
-                "@chakra-ui/system@2.5.7",
-                "framer-motion@10.16.4",
-            ]
-        }
+        return [
+            imports.ImportVar(
+                package="@chakra-ui/system@2.5.7", tag=None, render=False
+            ),
+            imports.ImportVar(package="framer-motion@10.16.4", tag=None, render=False),
+        ]
 
 
 class ChakraProvider(ChakraComponent):

+ 100 - 59
reflex/components/component.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import copy
+import itertools
 import typing
 from abc import ABC, abstractmethod
 from functools import lru_cache, wraps
@@ -95,11 +96,11 @@ class BaseComponent(Base, ABC):
         """
 
     @abstractmethod
-    def _get_all_imports(self) -> imports.ImportDict:
+    def _get_all_imports(self) -> imports.ImportList:
         """Get all the libraries and fields that are used by the component.
 
         Returns:
-            The import dict with the required imports.
+            The list of all required ImportVar.
         """
 
     @abstractmethod
@@ -994,17 +995,22 @@ class Component(BaseComponent, ABC):
         # Return the dynamic imports
         return dynamic_imports
 
-    def _get_props_imports(self) -> List[str]:
+    def _get_props_imports(self) -> imports.ImportList:
         """Get the imports needed for components props.
 
         Returns:
-            The  imports for the components props of the component.
+            The imports for the components props of the component.
         """
-        return [
-            getattr(self, prop)._get_all_imports()
-            for prop in self.get_component_props()
-            if getattr(self, prop) is not None
-        ]
+        return imports.ImportList(
+            sum(
+                (
+                    getattr(self, prop)._get_all_imports()
+                    for prop in self.get_component_props()
+                    if getattr(self, prop) is not None
+                ),
+                [],
+            )
+        )
 
     def _should_transpile(self, dep: str | None) -> bool:
         """Check if a dependency should be transpiled.
@@ -1020,97 +1026,129 @@ class Component(BaseComponent, ABC):
             or format.format_library_name(dep or "") in self.transpile_packages
         )
 
-    def _get_dependencies_imports(self) -> imports.ImportDict:
+    def _get_dependencies_imports(self) -> imports.ImportList:
         """Get the imports from lib_dependencies for installing.
 
         Returns:
             The dependencies imports of the component.
         """
-        return {
-            dep: [
-                ImportVar(
-                    tag=None,
-                    render=False,
-                    transpile=self._should_transpile(dep),
-                )
-            ]
+        return imports.ImportList(
+            ImportVar(
+                package=dep,
+                tag=None,
+                render=False,
+                transpile=self._should_transpile(dep),
+            )
             for dep in self.lib_dependencies
-        }
+        )
 
-    def _get_hooks_imports(self) -> imports.ImportDict:
+    def _get_hooks_imports(self) -> imports.ImportList:
         """Get the imports required by certain hooks.
 
         Returns:
             The imports required for all selected hooks.
         """
-        _imports = {}
+        _imports = imports.ImportList()
 
         if self._get_ref_hook():
             # Handle hooks needed for attaching react refs to DOM nodes.
-            _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
-            _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
+            _imports.extend(
+                [
+                    ImportVar(package="react", tag="useRef"),
+                    ImportVar(package=f"/{Dirs.STATE_PATH}", tag="refs"),
+                ]
+            )
 
         if self._get_mount_lifecycle_hook():
             # Handle hooks for `on_mount` / `on_unmount`.
-            _imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
+            _imports.append(ImportVar(package="react", tag="useEffect"))
 
         if self._get_special_hooks():
             # Handle additional internal hooks (autofocus, etc).
-            _imports.setdefault("react", set()).update(
-                {
-                    ImportVar(tag="useRef"),
-                    ImportVar(tag="useEffect"),
-                },
+            _imports.extend(
+                [
+                    ImportVar(package="react", tag="useEffect"),
+                    ImportVar(package="react", tag="useRef"),
+                ]
             )
 
         user_hooks = self._get_hooks()
         if user_hooks is not None and isinstance(user_hooks, Var):
-            _imports = imports.merge_imports(_imports, user_hooks._var_data.imports)  # type: ignore
+            _imports.extend(user_hooks._var_data.imports)
 
         return _imports
 
     def _get_imports(self) -> imports.ImportDict:
-        """Get all the libraries and fields that are used by the component.
+        """Deprecated method to get all the libraries and fields used by the component.
 
         Returns:
             The imports needed by the component.
         """
-        _imports = {}
+        return {}
+
+    def _get_imports_list(self) -> imports.ImportList:
+        """Internal method to get the imports as a list.
+
+        Returns:
+            The imports as a list.
+        """
+        _imports = imports.ImportList(
+            itertools.chain(
+                self._get_props_imports(),
+                self._get_dependencies_imports(),
+                self._get_hooks_imports(),
+            )
+        )
+
+        # Handle deprecated _get_imports
+        import_dict = self._get_imports()
+        if import_dict:
+            console.deprecate(
+                feature_name="_get_imports",
+                reason="use add_imports instead",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+            _imports.extend(imports.ImportList.from_import_dict(import_dict))
 
         # Import this component's tag from the main library.
         if self.library is not None and self.tag is not None:
-            _imports[self.library] = {self.import_var}
+            _imports.append(self.import_var)
 
         # Get static imports required for event processing.
-        event_imports = Imports.EVENTS if self.event_triggers else {}
+        if self.event_triggers:
+            _imports.append(Imports.EVENTS)
 
         # Collect imports from Vars used directly by this component.
-        var_imports = [
-            var._var_data.imports for var in self._get_vars() if var._var_data
-        ]
-
-        return imports.merge_imports(
-            *self._get_props_imports(),
-            self._get_dependencies_imports(),
-            self._get_hooks_imports(),
-            _imports,
-            event_imports,
-            *var_imports,
-        )
+        for var in self._get_vars():
+            if var._var_data:
+                _imports.extend(var._var_data.imports)
+        return _imports
 
-    def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
+    def _get_all_imports(self, collapse: bool = False) -> imports.ImportList:
         """Get all the libraries and fields that are used by the component and its children.
 
         Args:
-            collapse: Whether to collapse the imports by removing duplicates.
+            collapse: Whether to collapse the imports into a dict (deprecated).
 
         Returns:
-            The import dict with the required imports.
+            The list of all required imports.
         """
-        _imports = imports.merge_imports(
-            self._get_imports(), *[child._get_all_imports() for child in self.children]
+        _imports = imports.ImportList(
+            self._get_imports_list()
+            + sum((child._get_all_imports() for child in self.children), [])
         )
-        return imports.collapse_imports(_imports) if collapse else _imports
+
+        if collapse:
+            console.deprecate(
+                feature_name="collapse kwarg to _get_all_imports",
+                reason="use ImportList.collapse instead",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+            return _imports.collapse()  # type: ignore
+
+        return _imports
 
     def _get_mount_lifecycle_hook(self) -> str | None:
         """Generate the component lifecycle hook.
@@ -1296,6 +1334,7 @@ class Component(BaseComponent, ABC):
         tag = self.tag.partition(".")[0] if self.tag else None
         alias = self.alias.partition(".")[0] if self.alias else None
         return ImportVar(
+            package=self.library,
             tag=tag,
             is_default=self.is_default,
             alias=alias,
@@ -1575,7 +1614,6 @@ class NoSSRComponent(Component):
         return imports.merge_imports(
             dynamic_import,
             _imports,
-            self._get_dependencies_imports(),
         )
 
     def _get_dynamic_imports(self) -> str:
@@ -1893,18 +1931,21 @@ class StatefulComponent(BaseComponent):
         """
         return {}
 
-    def _get_all_imports(self) -> imports.ImportDict:
+    def _get_all_imports(self) -> imports.ImportList:
         """Get all the libraries and fields that are used by the component.
 
         Returns:
-            The import dict with the required imports.
+            The list of all required imports.
         """
         if self.rendered_as_shared:
-            return {
-                f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
-                    ImportVar(tag=self.tag)
+            return imports.ImportList(
+                [
+                    imports.ImportVar(
+                        package=f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}",
+                        tag=self.tag,
+                    )
                 ]
-            }
+            )
         return self.component._get_all_imports()
 
     def _get_all_dynamic_imports(self) -> set[str]:

+ 10 - 8
reflex/components/core/cond.py

@@ -12,9 +12,9 @@ from reflex.style import LIGHT_COLOR_MODE, color_mode
 from reflex.utils import format, imports
 from reflex.vars import BaseVar, Var, VarData
 
-_IS_TRUE_IMPORT = {
-    f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
-}
+_IS_TRUE_IMPORT = imports.ImportList(
+    [imports.ImportVar(library=f"/{Dirs.STATE_PATH}", tag="isTrue")]
+)
 
 
 class Cond(MemoizationLeaf):
@@ -95,11 +95,13 @@ class Cond(MemoizationLeaf):
             cond_state=f"isTrue({self.cond._var_full_name})",
         )
 
-    def _get_imports(self) -> imports.ImportDict:
-        return imports.merge_imports(
-            super()._get_imports(),
-            getattr(self.cond._var_data, "imports", {}),
-            _IS_TRUE_IMPORT,
+    def _get_imports_list(self) -> imports.ImportList:
+        return imports.ImportList(
+            [
+                *super()._get_imports_list(),
+                *getattr(self.cond._var_data, "imports", []),
+                *_IS_TRUE_IMPORT,
+            ]
         )
 
     def _apply_theme(self, theme: Component):

+ 6 - 6
reflex/constants/compiler.py

@@ -6,7 +6,7 @@ from types import SimpleNamespace
 
 from reflex.base import Base
 from reflex.constants import Dirs
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
 
 # The prefix used to create setters for state vars.
 SETTER_PREFIX = "set_"
@@ -102,11 +102,11 @@ class ComponentName(Enum):
 class Imports(SimpleNamespace):
     """Common sets of import vars."""
 
-    EVENTS = {
-        "react": {ImportVar(tag="useContext")},
-        f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
-        f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
-    }
+    EVENTS: ImportList = [
+        ImportVar(package="react", tag="useContext"),
+        ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
+        ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
+    ]
 
 
 class Hooks(SimpleNamespace):

+ 2 - 5
reflex/utils/format.py

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 from reflex import constants
 from reflex.utils import exceptions, serializers, types
+from reflex.utils.imports import split_library_name_version
 from reflex.utils.serializers import serialize
 from reflex.vars import BaseVar, Var
 
@@ -716,11 +717,7 @@ def format_library_name(library_fullname: str):
     Returns:
         The name without the @version if it was part of the name
     """
-    lib, at, version = library_fullname.rpartition("@")
-    if not lib:
-        lib = at + version
-
-    return lib
+    return split_library_name_version(library_fullname)[0]
 
 
 def json_dumps(obj: Any) -> str:

+ 154 - 5
reflex/utils/imports.py

@@ -3,9 +3,10 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set
 
 from reflex.base import Base
+from reflex.constants.installer import PackageJson
 
 
 def merge_imports(*imports) -> ImportDict:
@@ -36,9 +37,29 @@ def collapse_imports(imports: ImportDict) -> ImportDict:
     return {lib: list(set(import_vars)) for lib, import_vars in imports.items()}
 
 
+def split_library_name_version(library_fullname: str):
+    """Split the name of a library from its version.
+
+    Args:
+        library_fullname: The fullname of the library.
+
+    Returns:
+        A tuple of the library name and version.
+    """
+    lib, at, version = library_fullname.rpartition("@")
+    if not lib:
+        lib = at + version
+        version = None
+
+    return lib, version
+
+
 class ImportVar(Base):
     """An import var."""
 
+    # The package name associated with the tag
+    library: Optional[str]
+
     # The name of the import tag.
     tag: Optional[str]
 
@@ -48,6 +69,12 @@ class ImportVar(Base):
     # The tag alias.
     alias: Optional[str] = None
 
+    # The following fields provide extra information about the import,
+    # but are not factored in when considering hash or equality
+
+    # The version of the package
+    version: Optional[str]
+
     # Whether this import need to install the associated lib
     install: Optional[bool] = True
 
@@ -58,6 +85,34 @@ class ImportVar(Base):
     # https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages
     transpile: Optional[bool] = False
 
+    def __init__(
+        self,
+        *,
+        package: Optional[str] = None,
+        **kwargs,
+    ):
+        if package is not None:
+            if (
+                kwargs.get("library", None) is not None
+                or kwargs.get("version", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot provide 'library' or 'version' as keyword arguments when "
+                    "specifying 'package' as an argument"
+                )
+            kwargs["library"], kwargs["version"] = split_library_name_version(package)
+
+        install = (
+            package is not None
+            # TODO: handle version conflicts
+            and package not in PackageJson.DEPENDENCIES
+            and package not in PackageJson.DEV_DEPENDENCIES
+            and not any(package.startswith(prefix) for prefix in ["/", ".", "next/"])
+            and package != ""
+        )
+        kwargs.setdefault("install", install)
+        super().__init__(**kwargs)
+
     @property
     def name(self) -> str:
         """The name of the import.
@@ -72,6 +127,17 @@ class ImportVar(Base):
         else:
             return self.tag or ""
 
+    @property
+    def package(self) -> str:
+        """The package to install for this import
+
+        Returns:
+            The library name and (optional) version to be installed by npm/bun.
+        """
+        if self.version:
+            return f"{self.library}@{self.version}"
+        return self.library
+
     def __hash__(self) -> int:
         """Define a hash function for the import var.
 
@@ -80,14 +146,97 @@ class ImportVar(Base):
         """
         return hash(
             (
+                self.library,
                 self.tag,
                 self.is_default,
                 self.alias,
-                self.install,
-                self.render,
-                self.transpile,
+                # These do not fundamentally change the import in any way
+                # self.install,
+                # self.render,
+                # self.transpile,
             )
         )
 
+    def __eq__(self, other: ImportVar) -> bool:
+        """Define equality for the import var.
+
+        Args:
+            other: The other import var to compare.
+
+        Returns:
+            Whether the two import vars are equal.
+        """
+        if type(self) != type(other):
+            return NotImplemented
+        return (self.library, self.tag, self.is_default, self.alias) == (
+            other.library,
+            other.tag,
+            other.is_default,
+            other.alias,
+        )
+
+    def collapse(self, other_import_var: ImportVar) -> ImportVar:
+        """Collapse two import vars together.
+
+        Args:
+            other_import_var: The other import var to collapse with.
+
+        Returns:
+            The collapsed import var with sticky props perserved.
+        """
+        if self != other_import_var:
+            raise ValueError("Cannot collapse two import vars with different hashes")
+
+        if self.version is not None and other_import_var.version is not None:
+            if self.version != other_import_var.version:
+                raise ValueError(
+                    "Cannot collapse two import vars with conflicting version specifiers: "
+                    f"{self} {other_import_var}"
+                )
+
+        return type(self)(
+            library=self.library,
+            version=self.version or other_import_var.version,
+            tag=self.tag,
+            is_default=self.is_default,
+            alias=self.alias,
+            install=self.install or other_import_var.install,
+            render=self.render or other_import_var.render,
+            transpile=self.transpile or other_import_var.transpile,
+        )
+
 
-ImportDict = Dict[str, List[ImportVar]]
+class ImportList(List[ImportVar]):
+    """A list of import vars."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        for ix, value in enumerate(self):
+            if not isinstance(value, ImportVar):
+                # convert dicts to ImportVar
+                self[ix] = ImportVar(**value)
+
+    @classmethod
+    def from_import_dict(cls, import_dict: ImportDict) -> ImportList:
+        return [
+            ImportVar(package=lib, **imp.dict())
+            for lib, imps in import_dict.items()
+            for imp in imps
+        ]
+
+    def collapse(self) -> ImportDict:
+        """When collapsing an import list, prefer packages with version specifiers."""
+        collapsed = {}
+        for imp in self:
+            collapsed.setdefault(imp.library, {})
+            if imp in collapsed[imp.library]:
+                # Need to check if the current import has any special properties that need to
+                # be preserved, like the version specifier, install, or transpile.
+                existing_imp = collapsed[imp.library][imp]
+                collapsed[imp.library][imp] = existing_imp.collapse(imp)
+            else:
+                collapsed[imp.library][imp] = imp
+        return {lib: set(imps) for lib, imps in collapsed.items()}
+
+
+ImportDict = Dict[str, Set[ImportVar]]

+ 27 - 10
reflex/vars.py

@@ -37,7 +37,7 @@ from reflex.base import Base
 from reflex.utils import console, format, imports, serializers, types
 
 # This module used to export ImportVar itself, so we still import it for export here
-from reflex.utils.imports import ImportDict, ImportVar
+from reflex.utils.imports import ImportDict, ImportList, ImportVar
 
 if TYPE_CHECKING:
     from reflex.state import BaseState
@@ -116,7 +116,7 @@ class VarData(Base):
     state: str = ""
 
     # Imports needed to render this var
-    imports: ImportDict = {}
+    imports: ImportList = []
 
     # Hooks that need to be present in the component to render this var
     hooks: Dict[str, None] = {}
@@ -126,6 +126,19 @@ class VarData(Base):
     # segments.
     interpolations: List[Tuple[int, int]] = []
 
+    def __init__(self, imports: ImportDict | ImportList = None, **kwargs):
+        if isinstance(imports, dict):
+            imports = ImportList.from_import_dict(imports)
+            console.deprecate(
+                feature_name="Passing ImportDict for VarData",
+                reason="use ImportList instead",
+                deprecation_version="0.5.0",
+                removal_version="0.6.0",
+            )
+        elif imports is None:
+            imports = []
+        super().__init__(imports=imports, **kwargs)
+
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None:
         """Merge multiple var data objects.
@@ -137,14 +150,14 @@ class VarData(Base):
             The merged var data object.
         """
         state = ""
-        _imports = {}
+        _imports = []
         hooks = {}
         interpolations = []
         for var_data in others:
             if var_data is None:
                 continue
             state = state or var_data.state
-            _imports = imports.merge_imports(_imports, var_data.imports)
+            _imports.extend(var_data.imports)
             hooks.update(var_data.hooks)
             interpolations += var_data.interpolations
 
@@ -180,11 +193,18 @@ class VarData(Base):
 
         # Don't compare interpolations - that's added in by the decoder, and
         # not part of the vardata itself.
+        if not isinstance(self.imports, ImportList):
+            self_imports = ImportList(self.imports).collapse()
+        else:
+            self_imports = self.imports.collapse()
+        if not isinstance(other.imports, ImportList):
+            other_imports = ImportList(other.imports).collapse()
+        else:
+            other_imports = other.imports.collapse()
         return (
             self.state == other.state
             and self.hooks.keys() == other.hooks.keys()
-            and imports.collapse_imports(self.imports)
-            == imports.collapse_imports(other.imports)
+            and self_imports == other_imports
         )
 
     def dict(self) -> dict:
@@ -196,10 +216,7 @@ class VarData(Base):
         return {
             "state": self.state,
             "interpolations": list(self.interpolations),
-            "imports": {
-                lib: [import_var.dict() for import_var in import_vars]
-                for lib, import_vars in self.imports.items()
-            },
+            "imports": [import_var.dict() for import_var in self.imports],
             "hooks": self.hooks,
         }
 

+ 34 - 22
tests/compiler/test_compiler.py

@@ -4,8 +4,7 @@ from typing import List
 import pytest
 
 from reflex.compiler import compiler, utils
-from reflex.utils import imports
-from reflex.utils.imports import ImportVar
+from reflex.utils.imports import ImportList, ImportVar
 
 
 @pytest.mark.parametrize(
@@ -48,43 +47,56 @@ def test_compile_import_statement(
 
 
 @pytest.mark.parametrize(
-    "import_dict,test_dicts",
+    "import_list,test_dicts",
     [
-        ({}, []),
+        (ImportList(), []),
         (
-            {"axios": [ImportVar(tag="axios", is_default=True)]},
+            ImportList([ImportVar(library="axios", tag="axios", is_default=True)]),
             [{"lib": "axios", "default": "axios", "rest": []}],
         ),
         (
-            {"axios": [ImportVar(tag="foo"), ImportVar(tag="bar")]},
+            ImportList(
+                [
+                    ImportVar(library="axios", tag="foo"),
+                    ImportVar(library="axios", tag="bar"),
+                ]
+            ),
             [{"lib": "axios", "default": "", "rest": ["bar", "foo"]}],
         ),
         (
-            {
-                "axios": [
-                    ImportVar(tag="axios", is_default=True),
-                    ImportVar(tag="foo"),
-                    ImportVar(tag="bar"),
-                ],
-                "react": [ImportVar(tag="react", is_default=True)],
-            },
+            ImportList(
+                [
+                    ImportVar(library="axios", tag="axios", is_default=True),
+                    ImportVar(library="axios", tag="foo"),
+                    ImportVar(library="axios", tag="bar"),
+                    ImportVar(library="react", tag="react", is_default=True),
+                ]
+            ),
             [
                 {"lib": "axios", "default": "axios", "rest": ["bar", "foo"]},
                 {"lib": "react", "default": "react", "rest": []},
             ],
         ),
         (
-            {"": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")]},
+            ImportList(
+                [
+                    ImportVar(library="", tag="lib1.js"),
+                    ImportVar(library="", tag="lib2.js"),
+                ]
+            ),
             [
                 {"lib": "lib1.js", "default": "", "rest": []},
                 {"lib": "lib2.js", "default": "", "rest": []},
             ],
         ),
         (
-            {
-                "": [ImportVar(tag="lib1.js"), ImportVar(tag="lib2.js")],
-                "axios": [ImportVar(tag="axios", is_default=True)],
-            },
+            ImportList(
+                [
+                    ImportVar(library="", tag="lib1.js"),
+                    ImportVar(library="", tag="lib2.js"),
+                    ImportVar(library="axios", tag="axios", is_default=True),
+                ]
+            ),
             [
                 {"lib": "lib1.js", "default": "", "rest": []},
                 {"lib": "lib2.js", "default": "", "rest": []},
@@ -93,14 +105,14 @@ def test_compile_import_statement(
         ),
     ],
 )
-def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]):
+def test_compile_imports(import_list: ImportList, test_dicts: List[dict]):
     """Test the compile_imports function.
 
     Args:
-        import_dict: The import dictionary.
+        import_list: The list of ImportVar.
         test_dicts: The expected output.
     """
-    imports = utils.compile_imports(import_dict)
+    imports = utils.compile_imports(import_list)
     for import_dict, test_dict in zip(imports, test_dicts):
         assert import_dict["lib"] == test_dict["lib"]
         assert import_dict["default"] == test_dict["default"]

+ 2 - 2
tests/components/core/test_banner.py

@@ -20,7 +20,7 @@ def test_connection_banner():
         "react",
         "/utils/context",
         "/utils/state",
-        "@radix-ui/themes@^3.0.0",
+        "@radix-ui/themes",
         "/env.json",
     ]
 
@@ -36,7 +36,7 @@ def test_connection_modal():
         "react",
         "/utils/context",
         "/utils/state",
-        "@radix-ui/themes@^3.0.0",
+        "@radix-ui/themes",
         "/env.json",
     ]
 

+ 13 - 11
tests/components/test_component.py

@@ -296,11 +296,11 @@ def test_get_imports(component1, component2):
     """
     c1 = component1.create()
     c2 = component2.create(c1)
-    assert c1._get_all_imports() == {"react": [ImportVar(tag="Component")]}
-    assert c2._get_all_imports() == {
-        "react-redux": [ImportVar(tag="connect")],
-        "react": [ImportVar(tag="Component")],
-    }
+    assert c1._get_all_imports() == [ImportVar(library="react", tag="Component")]
+    assert c2._get_all_imports() == [
+        ImportVar(library="react-redux", tag="connect"),
+        ImportVar(library="react", tag="Component"),
+    ]
 
 
 def test_get_custom_code(component1, component2):
@@ -1514,22 +1514,24 @@ def test_custom_component_get_imports():
     custom_comp = wrapper()
 
     # Inner is not imported directly, but it is imported by the custom component.
-    assert "inner" not in custom_comp._get_all_imports()
+    inner_import = ImportVar(library="inner", tag="Inner")
+    assert inner_import not in custom_comp._get_all_imports()
 
     # The imports are only resolved during compilation.
     _, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
-    assert "inner" in imports_inner
+    assert inner_import in imports_inner
 
     outer_comp = outer(c=wrapper())
 
     # Libraries are not imported directly, but are imported by the custom component.
-    assert "inner" not in outer_comp._get_all_imports()
-    assert "other" not in outer_comp._get_all_imports()
+    other_import = ImportVar(library="other", tag="Other")
+    assert inner_import not in outer_comp._get_all_imports()
+    assert other_import not in outer_comp._get_all_imports()
 
     # The imports are only resolved during compilation.
     _, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
-    assert "inner" in imports_outer
-    assert "other" in imports_outer
+    assert inner_import in imports_outer
+    assert other_import in imports_outer
 
 
 def test_custom_component_declare_event_handlers_in_fields():

+ 1 - 1
tests/test_var.py

@@ -837,7 +837,7 @@ def test_state_with_initial_computed_var(
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
-            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true, "transpile": false}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
+            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": [{"library": "/utils/context", "tag": "StateContexts", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}, {"library": "react", "tag": "useContext", "is_default": false, "alias": null, "version": null, "install": false, "render": true, "transpile": false}], "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
         ),
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",