1
0
Эх сурвалжийг харах

[REF-2269] Add `add_imports` API for component class (#2937)

Martin Xu 1 жил өмнө
parent
commit
8edd1dfdc9

+ 27 - 0
reflex/components/component.py

@@ -1017,6 +1017,16 @@ class Component(BaseComponent, ABC):
             )
             )
         return _imports
         return _imports
 
 
+    def add_imports(
+        self,
+    ) -> Dict[str, Union[str, ImportVar, List[str | ImportVar]]]:
+        """User defined imports for the component. Need to be overriden in subclass.
+
+        Returns:
+            The user defined imports as a dict.
+        """
+        return {}
+
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.
         """Get all the libraries and fields that are used by the component.
 
 
@@ -1037,12 +1047,29 @@ class Component(BaseComponent, ABC):
             var._var_data.imports for var in self._get_vars() if var._var_data
             var._var_data.imports for var in self._get_vars() if var._var_data
         ]
         ]
 
 
+        # If the subclass implements add_imports, merge the imports.
+        def _make_list(
+            value: str | ImportVar | list[str | ImportVar],
+        ) -> list[str | ImportVar]:
+            if isinstance(value, (str, ImportVar)):
+                return [value]
+            return value
+
+        added_imports = {
+            package: [
+                ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
+                for tag in _make_list(maybe_tags)
+            ]
+            for package, maybe_tags in self.add_imports().items()
+        }
+
         return imports.merge_imports(
         return imports.merge_imports(
             *self._get_props_imports(),
             *self._get_props_imports(),
             self._get_dependencies_imports(),
             self._get_dependencies_imports(),
             self._get_hooks_imports(),
             self._get_hooks_imports(),
             _imports,
             _imports,
             event_imports,
             event_imports,
+            added_imports,
             *var_imports,
             *var_imports,
         )
         )
 
 

+ 43 - 1
tests/components/test_component.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Type, Union
 
 
 import pytest
 import pytest
 
 
@@ -1308,3 +1308,45 @@ def test_custom_component_get_imports():
     _, _, imports_outer = compile_components(outer_comp.get_custom_components())
     _, _, imports_outer = compile_components(outer_comp.get_custom_components())
     assert "inner" in imports_outer
     assert "inner" in imports_outer
     assert "other" in imports_outer
     assert "other" in imports_outer
+
+
+@pytest.mark.parametrize(
+    "tags",
+    (
+        ["Component"],
+        ["Component", "useState"],
+        [ImportVar(tag="Component")],
+        [ImportVar(tag="Component"), ImportVar(tag="useState")],
+        ["Component", ImportVar(tag="useState")],
+    ),
+)
+def test_custom_component_add_imports(tags):
+    def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
+        return [
+            ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
+            for tag in tags
+        ]
+
+    class BaseComponent(Component):
+        def _get_imports(self) -> imports.ImportDict:
+            return {}
+
+    class Reference(Component):
+        def _get_imports(self) -> imports.ImportDict:
+            return imports.merge_imports(
+                super()._get_imports(),
+                {"react": _list_to_import_vars(tags)},
+            )
+
+    class Test(Component):
+        def add_imports(
+            self,
+        ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
+
+            return {"react": (tags[0] if len(tags) == 1 else tags)}
+
+    baseline = Reference.create()
+    test = Test.create()
+
+    assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
+    assert test.get_imports() == baseline.get_imports()