|
@@ -1,4 +1,4 @@
|
|
|
-from typing import Any, Dict, List, Type
|
|
|
+from typing import Any, Dict, List, Type, Union
|
|
|
|
|
|
import pytest
|
|
|
|
|
@@ -1308,3 +1308,45 @@ def test_custom_component_get_imports():
|
|
|
_, _, imports_outer = compile_components(outer_comp.get_custom_components())
|
|
|
assert "inner" in imports_outer
|
|
|
assert "other" in imports_outer
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ "tags",
|
|
|
+ (
|
|
|
+ ["Component"],
|
|
|
+ ["Component", "useState"],
|
|
|
+ [ImportVar(tag="Component")],
|
|
|
+ [ImportVar(tag="Component"), ImportVar(tag="useState")],
|
|
|
+ ["Component", ImportVar(tag="useState")],
|
|
|
+ ),
|
|
|
+)
|
|
|
+def test_custom_component_add_imports(tags):
|
|
|
+ def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
|
|
|
+ return [
|
|
|
+ ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
|
|
|
+ for tag in tags
|
|
|
+ ]
|
|
|
+
|
|
|
+ class BaseComponent(Component):
|
|
|
+ def _get_imports(self) -> imports.ImportDict:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ class Reference(Component):
|
|
|
+ def _get_imports(self) -> imports.ImportDict:
|
|
|
+ return imports.merge_imports(
|
|
|
+ super()._get_imports(),
|
|
|
+ {"react": _list_to_import_vars(tags)},
|
|
|
+ )
|
|
|
+
|
|
|
+ class Test(Component):
|
|
|
+ def add_imports(
|
|
|
+ self,
|
|
|
+ ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
|
|
|
+
|
|
|
+ return {"react": (tags[0] if len(tags) == 1 else tags)}
|
|
|
+
|
|
|
+ baseline = Reference.create()
|
|
|
+ test = Test.create()
|
|
|
+
|
|
|
+ assert baseline.get_imports() == {"react": _list_to_import_vars(tags)}
|
|
|
+ assert test.get_imports() == baseline.get_imports()
|