Masen Furer 1 год назад
Родитель
Сommit
3423fec2a6

+ 2 - 2
reflex/app.py

@@ -627,7 +627,7 @@ class App(Base):
         Example:
             >>> get_frontend_packages({"react": "16.14.0", "react-dom": "16.14.0"})
         """
-        page_imports = [i.package for i in imports.collapse().values() if i.install]
+        page_imports = [i.package for i in imports if i.install and i.package]
         frontend_packages = get_config().frontend_packages
         _frontend_packages = []
         for package in frontend_packages:
@@ -643,7 +643,7 @@ class App(Base):
                 continue
             _frontend_packages.append(package)
         page_imports.extend(_frontend_packages)
-        prerequisites.install_frontend_packages(page_imports, get_config())
+        prerequisites.install_frontend_packages(set(page_imports), get_config())
 
     def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
         for component in tuple(app_wrappers.values()):

+ 1 - 1
reflex/compiler/compiler.py

@@ -188,7 +188,7 @@ def _compile_component(component: Component) -> str:
 
 def _compile_components(
     components: set[CustomComponent],
-) -> tuple[str, Dict[str, list[ImportVar]]]:
+) -> tuple[str, ImportList]:
     """Compile the components.
 
     Args:

+ 1 - 1
reflex/compiler/utils.py

@@ -247,7 +247,7 @@ def compile_custom_component(
     render = component.get_component(component)
 
     # Get the imports.
-    component_library_name = format.format_library_name(component.library)
+    component_library_name = format.format_library_name(component.library or "")
     _imports = imports.ImportList(
         imp
         for imp in render._get_all_imports()

+ 16 - 8
reflex/components/chakra/base.py

@@ -35,7 +35,7 @@ class ChakraComponent(Component):
 
     @classmethod
     @lru_cache(maxsize=None)
-    def _get_dependencies_imports(cls) -> imports.ImportList:
+    def _get_dependencies_imports(cls) -> List[imports.ImportVar]:
         """Get the imports from lib_dependencies for installing.
 
         Returns:
@@ -67,13 +67,21 @@ class ChakraProvider(ChakraComponent):
             theme=Var.create("extendTheme(theme)", _var_is_local=False),
         )
 
-    def _get_imports(self) -> imports.ImportDict:
-        _imports = super()._get_imports()
-        _imports.setdefault(self.__fields__["library"].default, []).append(
-            imports.ImportVar(tag="extendTheme", is_default=False),
-        )
-        _imports.setdefault("/utils/theme.js", []).append(
-            imports.ImportVar(tag="theme", is_default=True),
+    def _get_imports_list(self) -> List[imports.ImportVar]:
+        _imports = super()._get_imports_list()
+        _imports.extend(
+            [
+                imports.ImportVar(
+                    package=self.__fields__["library"].default,
+                    tag="extendTheme",
+                    is_default=False,
+                ),
+                imports.ImportVar(
+                    package="/utils/theme.js",
+                    tag="theme",
+                    is_default=True,
+                ),
+            ],
         )
         return _imports
 

+ 8 - 4
reflex/components/component.py

@@ -1026,7 +1026,7 @@ class Component(BaseComponent, ABC):
             or format.format_library_name(dep or "") in self.transpile_packages
         )
 
-    def _get_dependencies_imports(self) -> imports.ImportList:
+    def _get_dependencies_imports(self) -> List[ImportVar]:
         """Get the imports from lib_dependencies for installing.
 
         Returns:
@@ -1073,7 +1073,11 @@ class Component(BaseComponent, ABC):
             )
 
         user_hooks = self._get_hooks()
-        if user_hooks is not None and isinstance(user_hooks, Var):
+        if (
+            user_hooks is not None
+            and isinstance(user_hooks, Var)
+            and user_hooks._var_data is not None
+        ):
             _imports.extend(user_hooks._var_data.imports)
 
         return _imports
@@ -1086,7 +1090,7 @@ class Component(BaseComponent, ABC):
         """
         return {}
 
-    def _get_imports_list(self) -> imports.ImportList:
+    def _get_imports_list(self) -> List[ImportVar]:
         """Internal method to get the imports as a list.
 
         Returns:
@@ -1117,7 +1121,7 @@ class Component(BaseComponent, ABC):
 
         # Get static imports required for event processing.
         if self.event_triggers:
-            _imports.append(Imports.EVENTS)
+            _imports.extend(Imports.EVENTS)
 
         # Collect imports from Vars used directly by this component.
         for var in self._get_vars():

+ 8 - 5
reflex/components/core/banner.py

@@ -51,11 +51,14 @@ has_too_many_connection_errors: Var = Var.create_safe(
 class WebsocketTargetURL(Bare):
     """A component that renders the websocket target URL."""
 
-    def _get_imports(self) -> imports.ImportDict:
-        return {
-            f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
-            "/env.json": [imports.ImportVar(tag="env", is_default=True)],
-        }
+    def _get_imports_list(self) -> list[imports.ImportVar]:
+        return [
+            imports.ImportVar(
+                library=f"/{Dirs.STATE_PATH}",
+                tag="getBackendURL",
+            ),
+            imports.ImportVar(library="/env.json", tag="env", is_default=True),
+        ]
 
     @classmethod
     def create(cls) -> Component:

+ 35 - 30
reflex/components/markdown/markdown.py

@@ -7,7 +7,6 @@ from functools import lru_cache
 from hashlib import md5
 from typing import Any, Callable, Dict, Union
 
-from reflex.compiler import utils
 from reflex.components.component import Component, CustomComponent
 from reflex.components.radix.themes.layout.list import (
     ListItem,
@@ -154,47 +153,53 @@ class Markdown(Component):
 
         return custom_components
 
-    def _get_imports(self) -> imports.ImportDict:
+    def _get_imports_list(self) -> list[imports.ImportVar]:
         # Import here to avoid circular imports.
         from reflex.components.datadisplay.code import CodeBlock
         from reflex.components.radix.themes.typography.code import Code
 
-        imports = super()._get_imports()
+        _imports = super()._get_imports_list()
 
         # Special markdown imports.
-        imports.update(
-            {
-                "": [ImportVar(tag="katex/dist/katex.min.css")],
-                "remark-math@5.1.1": [
-                    ImportVar(tag=_REMARK_MATH._var_name, is_default=True)
-                ],
-                "remark-gfm@3.0.1": [
-                    ImportVar(tag=_REMARK_GFM._var_name, is_default=True)
-                ],
-                "remark-unwrap-images@4.0.0": [
-                    ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True)
-                ],
-                "rehype-katex@6.0.3": [
-                    ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True)
-                ],
-                "rehype-raw@6.1.1": [
-                    ImportVar(tag=_REHYPE_RAW._var_name, is_default=True)
-                ],
-            }
+        _imports.extend(
+            [
+                ImportVar(library="", tag="katex/dist/katex.min.css"),
+                ImportVar(
+                    package="remark-math@5.1.1",
+                    tag=_REMARK_MATH._var_name,
+                    is_default=True,
+                ),
+                ImportVar(
+                    package="remark-gfm@3.0.1",
+                    tag=_REMARK_GFM._var_name,
+                    is_default=True,
+                ),
+                ImportVar(
+                    package="remark-unwrap-images@4.0.0",
+                    tag=_REMARK_UNWRAP_IMAGES._var_name,
+                    is_default=True,
+                ),
+                ImportVar(
+                    package="remark-katex@6.0.3",
+                    tag=_REHYPE_KATEX._var_name,
+                    is_default=True,
+                ),
+                ImportVar(
+                    package="rehype-raw@6.1.1",
+                    tag=_REHYPE_RAW._var_name,
+                    is_default=True,
+                ),
+            ]
         )
 
         # Get the imports for each component.
         for component in self.component_map.values():
-            imports = utils.merge_imports(
-                imports, component(_MOCK_ARG)._get_all_imports()
-            )
+            _imports.extend(component(_MOCK_ARG)._get_all_imports())
 
         # Get the imports for the code components.
-        imports = utils.merge_imports(
-            imports, CodeBlock.create(theme="light")._get_imports()
-        )
-        imports = utils.merge_imports(imports, Code.create()._get_imports())
-        return imports
+        _imports.extend(CodeBlock.create(theme="light")._get_all_imports())
+        _imports.extend(Code.create()._get_all_imports())
+        return _imports
 
     def get_component(self, tag: str, **props) -> Component:
         """Get the component for a tag and props.

+ 5 - 7
reflex/components/radix/themes/base.py

@@ -243,13 +243,11 @@ class ThemePanel(RadixThemesComponent):
     # Whether the panel is open. Defaults to False.
     default_open: Var[bool]
 
-    def _get_imports(self) -> dict[str, list[imports.ImportVar]]:
-        return imports.merge_imports(
-            super()._get_imports(),
-            {
-                "react": [imports.ImportVar(tag="useEffect")],
-            },
-        )
+    def _get_imports_list(self) -> list[imports.ImportVar]:
+        return [
+            *super()._get_imports_list(),
+            imports.ImportVar(package="react", tag="useEffect"),
+        ]
 
     def _get_hooks(self) -> str | None:
         # The panel freezes the tab if the user color preference differs from the

+ 7 - 5
reflex/constants/compiler.py

@@ -102,11 +102,13 @@ class ComponentName(Enum):
 class Imports(SimpleNamespace):
     """Common sets of import vars."""
 
-    EVENTS: ImportList = [
-        ImportVar(package="react", tag="useContext"),
-        ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
-        ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
-    ]
+    EVENTS: ImportList = ImportList(
+        [
+            ImportVar(package="react", tag="useContext"),
+            ImportVar(package=f"/{Dirs.CONTEXTS_PATH}", tag="EventLoopContext"),
+            ImportVar(package=f"/{Dirs.STATE_PATH}", tag=CompileVars.TO_EVENT),
+        ]
+    )
 
 
 class Hooks(SimpleNamespace):

+ 60 - 25
reflex/utils/imports.py

@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional
 
 from reflex.base import Base
 from reflex.constants.installer import PackageJson
@@ -91,6 +91,15 @@ class ImportVar(Base):
         package: Optional[str] = None,
         **kwargs,
     ):
+        """Create a new ImportVar.
+
+        Args:
+            package: The package to install for this import.
+            **kwargs: The import var fields.
+
+        Raises:
+            ValueError: If the package is provided with library or version.
+        """
         if package is not None:
             if (
                 kwargs.get("library", None) is not None
@@ -128,8 +137,8 @@ class ImportVar(Base):
             return self.tag or ""
 
     @property
-    def package(self) -> str:
-        """The package to install for this import
+    def package(self) -> str | None:
+        """The package to install for this import.
 
         Returns:
             The library name and (optional) version to be installed by npm/bun.
@@ -150,10 +159,6 @@ class ImportVar(Base):
                 self.tag,
                 self.is_default,
                 self.alias,
-                # These do not fundamentally change the import in any way
-                # self.install,
-                # self.render,
-                # self.transpile,
             )
         )
 
@@ -183,16 +188,22 @@ class ImportVar(Base):
 
         Returns:
             The collapsed import var with sticky props perserved.
+
+        Raises:
+            ValueError: If the two import vars have conflicting properties.
         """
         if self != other_import_var:
             raise ValueError("Cannot collapse two import vars with different hashes")
 
-        if self.version is not None and other_import_var.version is not None:
-            if self.version != other_import_var.version:
-                raise ValueError(
-                    "Cannot collapse two import vars with conflicting version specifiers: "
-                    f"{self} {other_import_var}"
-                )
+        if (
+            self.version is not None
+            and other_import_var.version is not None
+            and self.version != other_import_var.version
+        ):
+            raise ValueError(
+                "Cannot collapse two import vars with conflicting version specifiers: "
+                f"{self} {other_import_var}"
+            )
 
         return type(self)(
             library=self.library,
@@ -210,6 +221,15 @@ class ImportList(List[ImportVar]):
     """A list of import vars."""
 
     def __init__(self, *args, **kwargs):
+        """Create a new ImportList (wrapper over `list`).
+
+        Any items that are not already `ImportVar` will be assumed as dicts to convert
+        into an ImportVar.
+
+        Args:
+            *args: The args to pass to list.__init__
+            **kwargs: The kwargs to pass to list.__init__
+        """
         super().__init__(*args, **kwargs)
         for ix, value in enumerate(self):
             if not isinstance(value, ImportVar):
@@ -217,26 +237,41 @@ class ImportList(List[ImportVar]):
                 self[ix] = ImportVar(**value)
 
     @classmethod
-    def from_import_dict(cls, import_dict: ImportDict) -> ImportList:
-        return [
+    def from_import_dict(
+        cls, import_dict: ImportDict | Dict[str, set[ImportVar]]
+    ) -> ImportList:
+        """Create an import list from an import dict.
+
+        Args:
+            import_dict: The import dict to convert.
+
+        Returns:
+            The import list.
+        """
+        return cls(
             ImportVar(package=lib, **imp.dict())
             for lib, imps in import_dict.items()
             for imp in imps
-        ]
+        )
 
     def collapse(self) -> ImportDict:
-        """When collapsing an import list, prefer packages with version specifiers."""
-        collapsed = {}
+        """When collapsing an import list, prefer packages with version specifiers.
+
+        Returns:
+            The collapsed import dict ({library_name: [import_var1, ...]}).
+        """
+        collapsed: dict[str, dict[ImportVar, ImportVar]] = {}
         for imp in self:
-            collapsed.setdefault(imp.library, {})
-            if imp in collapsed[imp.library]:
+            lib = imp.library or ""
+            collapsed.setdefault(lib, {})
+            if imp in collapsed[lib]:
                 # Need to check if the current import has any special properties that need to
                 # be preserved, like the version specifier, install, or transpile.
-                existing_imp = collapsed[imp.library][imp]
-                collapsed[imp.library][imp] = existing_imp.collapse(imp)
+                existing_imp = collapsed[lib][imp]
+                collapsed[lib][imp] = existing_imp.collapse(imp)
             else:
-                collapsed[imp.library][imp] = imp
-        return {lib: set(imps) for lib, imps in collapsed.items()}
+                collapsed[lib][imp] = imp
+        return {lib: list(set(imps)) for lib, imps in collapsed.items()}
 
 
-ImportDict = Dict[str, Set[ImportVar]]
+ImportDict = Dict[str, List[ImportVar]]

+ 49 - 24
reflex/vars.py

@@ -34,7 +34,7 @@ from typing import (
 
 from reflex import constants
 from reflex.base import Base
-from reflex.utils import console, format, imports, serializers, types
+from reflex.utils import console, format, serializers, types
 
 # This module used to export ImportVar itself, so we still import it for export here
 from reflex.utils.imports import ImportDict, ImportList, ImportVar
@@ -116,7 +116,7 @@ class VarData(Base):
     state: str = ""
 
     # Imports needed to render this var
-    imports: ImportList = []
+    imports: ImportList = ImportList()
 
     # Hooks that need to be present in the component to render this var
     hooks: Dict[str, None] = {}
@@ -126,7 +126,24 @@ class VarData(Base):
     # segments.
     interpolations: List[Tuple[int, int]] = []
 
-    def __init__(self, imports: ImportDict | ImportList = None, **kwargs):
+    def __init__(
+        self,
+        imports: ImportList
+        | List[ImportVar | Dict[str, Optional[Union[str, bool]]]]
+        | ImportDict
+        | Dict[str, set[ImportVar]]
+        | None = None,
+        **kwargs,
+    ):
+        """Initialize the VarData.
+
+        If imports is an ImportDict it will be converted to an ImportList and a
+        deprecation warning will be displayed.
+
+        Args:
+            imports: The imports needed to render this var.
+            **kwargs: Additional fields to set.
+        """
         if isinstance(imports, dict):
             imports = ImportList.from_import_dict(imports)
             console.deprecate(
@@ -135,9 +152,12 @@ class VarData(Base):
                 deprecation_version="0.5.0",
                 removal_version="0.6.0",
             )
-        elif imports is None:
-            imports = []
-        super().__init__(imports=imports, **kwargs)
+        else:
+            imports = ImportList(imports or [])
+        super().__init__(
+            imports=imports,  # type: ignore
+            **kwargs,
+        )
 
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None:
@@ -150,7 +170,7 @@ class VarData(Base):
             The merged var data object.
         """
         state = ""
-        _imports = []
+        _imports = ImportList()
         hooks = {}
         interpolations = []
         for var_data in others:
@@ -1059,11 +1079,12 @@ class Var:
                 ",", other, fn="spreadArraysOrObjects", flip=flip
             )._replace(
                 merge_var_data=VarData(
-                    imports={
-                        f"/{constants.Dirs.STATE_PATH}": [
-                            ImportVar(tag="spreadArraysOrObjects")
-                        ]
-                    },
+                    imports=[
+                        ImportVar(
+                            package=f"/{constants.Dirs.STATE_PATH}",
+                            tag="spreadArraysOrObjects",
+                        ),
+                    ],
                 ),
             )
         return self.operation("+", other, flip=flip)
@@ -1612,11 +1633,11 @@ class Var:
                 v2._var_data,
                 step._var_data,
                 VarData(
-                    imports={
-                        "/utils/helpers/range.js": [
-                            ImportVar(tag="range", is_default=True),
-                        ],
-                    },
+                    imports=[
+                        ImportVar(
+                            package="/utils/helpers/range", tag="range", is_default=True
+                        ),
+                    ]
                 ),
             ),
         )
@@ -1644,9 +1665,9 @@ class Var:
             _var_is_string=False,
             _var_full_name_needs_state_prefix=False,
             merge_var_data=VarData(
-                imports={
-                    f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")],
-                },
+                imports=[
+                    ImportVar(package=f"/{constants.Dirs.STATE_PATH}", tag="refs")
+                ],
             ),
         )
 
@@ -1684,10 +1705,14 @@ class Var:
                     format.format_state_name(state_name)
                 ): None
             },
-            imports={
-                f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
-                "react": [ImportVar(tag="useContext")],
-            },
+            imports=ImportList(
+                [
+                    ImportVar(
+                        package=f"/{constants.Dirs.CONTEXTS_PATH}", tag="StateContexts"
+                    ),
+                    ImportVar(package="react", tag="useContext"),
+                ]
+            ),
         )
         self._var_data = VarData.merge(self._var_data, new_var_data)
         self._var_full_name_needs_state_prefix = True

+ 7 - 7
tests/components/core/test_banner.py

@@ -9,14 +9,14 @@ from reflex.components.radix.themes.typography.text import Text
 
 def test_websocket_target_url():
     url = WebsocketTargetURL.create()
-    _imports = url._get_all_imports(collapse=True)
-    assert list(_imports.keys()) == ["/utils/state", "/env.json"]
+    _imports = url._get_all_imports()
+    assert [i.library for i in _imports] == ["/utils/state", "/env.json"]
 
 
 def test_connection_banner():
     banner = ConnectionBanner.create()
-    _imports = banner._get_all_imports(collapse=True)
-    assert list(_imports.keys()) == [
+    _imports = banner._get_all_imports()
+    assert [i.library for i in _imports] == [
         "react",
         "/utils/context",
         "/utils/state",
@@ -31,8 +31,8 @@ def test_connection_banner():
 
 def test_connection_modal():
     modal = ConnectionModal.create()
-    _imports = modal._get_all_imports(collapse=True)
-    assert list(_imports.keys()) == [
+    _imports = modal._get_all_imports()
+    assert [i.library for i in _imports] == [
         "react",
         "/utils/context",
         "/utils/state",
@@ -48,4 +48,4 @@ def test_connection_modal():
 def test_connection_pulser():
     pulser = ConnectionPulser.create()
     _custom_code = pulser._get_all_custom_code()
-    _imports = pulser._get_all_imports(collapse=True)
+    _imports = pulser._get_all_imports()