Ver código fonte

[REF-2523] Implement new public Component API (#3203)

Masen Furer 1 ano atrás
pai
commit
19d8f6c752

+ 1 - 0
reflex/__init__.py

@@ -167,6 +167,7 @@ _MAPPING = {
     "reflex.style": ["style", "toggle_color_mode"],
     "reflex.testing": ["testing"],
     "reflex.utils": ["utils"],
+    "reflex.utils.imports": ["ImportVar"],
     "reflex.vars": ["vars", "cached_var", "Var"],
 }
 

+ 1 - 0
reflex/__init__.pyi

@@ -150,6 +150,7 @@ from reflex import style as style
 from reflex.style import toggle_color_mode as toggle_color_mode
 from reflex import testing as testing
 from reflex import utils as utils
+from reflex.utils.imports import ImportVar as ImportVar
 from reflex import vars as vars
 from reflex.vars import cached_var as cached_var
 from reflex.vars import Var as Var

+ 141 - 0
reflex/components/component.py

@@ -213,6 +213,91 @@ class Component(BaseComponent, ABC):
     # State class associated with this component instance
     State: Optional[Type[reflex.state.State]] = None
 
+    def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
+        """Add imports for the component.
+
+        This method should be implemented by subclasses to add new imports for the component.
+
+        Implementations do NOT need to call super(). The result of calling
+        add_imports in each parent class will be merged internally.
+
+        Returns:
+            The additional imports for this component subclass.
+
+        The format of the return value is a dictionary where the keys are the
+        library names (with optional npm-style version specifications) mapping
+        to a single name to be imported, or a list names to be imported.
+
+        For advanced use cases, the values can be ImportVar instances (for
+        example, to provide an alias or mark that an import is the default
+        export from the given library).
+
+        ```python
+        return {
+            "react": "useEffect",
+            "react-draggable": ["DraggableCore", rx.ImportVar(tag="Draggable", is_default=True)],
+        }
+        ```
+        """
+        return {}
+
+    def add_hooks(self) -> list[str]:
+        """Add hooks inside the component function.
+
+        Hooks are pieces of literal Javascript code that is inserted inside the
+        React component function.
+
+        Each logical hook should be a separate string in the list.
+
+        Common strings will be deduplicated and inserted into the component
+        function only once, so define const variables and other identical code
+        in their own strings to avoid defining the same const or hook multiple
+        times.
+
+        If a hook depends on specific data from the component instance, be sure
+        to use unique values inside the string to _avoid_ deduplication.
+
+        Implementations do NOT need to call super(). The result of calling
+        add_hooks in each parent class will be merged and deduplicated internally.
+
+        Returns:
+            The additional hooks for this component subclass.
+
+        ```python
+        return [
+            "const [count, setCount] = useState(0);",
+            "useEffect(() => { setCount((prev) => prev + 1); console.log(`mounted ${count} times`); }, []);",
+        ]
+        ```
+        """
+        return []
+
+    def add_custom_code(self) -> list[str]:
+        """Add custom Javascript code into the page that contains this component.
+
+        Custom code is inserted at module level, after any imports.
+
+        Each string of custom code is deduplicated per-page, so take care to
+        avoid defining the same const or function differently from different
+        component instances.
+
+        Custom code is useful for defining global functions or constants which
+        can then be referenced inside hooks or used by component vars.
+
+        Implementations do NOT need to call super(). The result of calling
+        add_custom_code in each parent class will be merged and deduplicated internally.
+
+        Returns:
+            The additional custom code for this component subclass.
+
+        ```python
+        return [
+            "const translatePoints = (event) => { return { x: event.clientX, y: event.clientY }; };",
+        ]
+        ```
+        """
+        return []
+
     @classmethod
     def __init_subclass__(cls, **kwargs):
         """Set default properties.
@@ -949,6 +1034,30 @@ class Component(BaseComponent, ABC):
                     return True
         return False
 
+    @classmethod
+    def _iter_parent_classes_with_method(cls, method: str) -> Iterator[Type[Component]]:
+        """Iterate through parent classes that define a given method.
+
+        Used for handling the `add_*` API functions that internally simulate a super() call chain.
+
+        Args:
+            method: The method to look for.
+
+        Yields:
+            The parent classes that define the method (differently than the base).
+        """
+        seen_methods = set([getattr(Component, method)])
+        for clz in cls.mro():
+            if clz is Component:
+                break
+            if not issubclass(clz, Component):
+                continue
+            method_func = getattr(clz, method, None)
+            if not callable(method_func) or method_func in seen_methods:
+                continue
+            seen_methods.add(method_func)
+            yield clz
+
     def _get_custom_code(self) -> str | None:
         """Get custom code for the component.
 
@@ -971,6 +1080,11 @@ class Component(BaseComponent, ABC):
         if custom_code is not None:
             code.add(custom_code)
 
+        # Add the custom code from add_custom_code method.
+        for clz in self._iter_parent_classes_with_method("add_custom_code"):
+            for item in clz.add_custom_code(self):
+                code.add(item)
+
         # Add the custom code for the children.
         for child in self.children:
             code |= child._get_all_custom_code()
@@ -1106,6 +1220,26 @@ class Component(BaseComponent, ABC):
             var._var_data.imports for var in self._get_vars() if var._var_data
         ]
 
+        # If any subclass implements add_imports, merge the imports.
+        def _make_list(
+            value: str | ImportVar | list[str | ImportVar],
+        ) -> list[str | ImportVar]:
+            if isinstance(value, (str, ImportVar)):
+                return [value]
+            return value
+
+        _added_import_dicts = []
+        for clz in self._iter_parent_classes_with_method("add_imports"):
+            _added_import_dicts.append(
+                {
+                    package: [
+                        ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
+                        for tag in _make_list(maybe_tags)
+                    ]
+                    for package, maybe_tags in clz.add_imports(self).items()
+                }
+            )
+
         return imports.merge_imports(
             *self._get_props_imports(),
             self._get_dependencies_imports(),
@@ -1113,6 +1247,7 @@ class Component(BaseComponent, ABC):
             _imports,
             event_imports,
             *var_imports,
+            *_added_import_dicts,
         )
 
     def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
@@ -1248,6 +1383,12 @@ class Component(BaseComponent, ABC):
         if hooks is not None:
             code[hooks] = None
 
+        # Add the hook code from add_hooks for each parent class (this is reversed to preserve
+        # the order of the hooks in the final output)
+        for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
+            for hook in clz.add_hooks(self):
+                code[hook] = None
+
         # Add the hook code for the children.
         for child in self.children:
             code = {**code, **child._get_all_hooks()}

+ 205 - 0
tests/components/test_component.py

@@ -1746,3 +1746,208 @@ def test_invalid_event_trigger():
 
     with pytest.raises(ValueError):
         trigger_comp(on_b=rx.console_log("log"))
+
+
+@pytest.mark.parametrize(
+    "tags",
+    (
+        ["Component"],
+        ["Component", "useState"],
+        [ImportVar(tag="Component")],
+        [ImportVar(tag="Component"), ImportVar(tag="useState")],
+        ["Component", ImportVar(tag="useState")],
+    ),
+)
+def test_component_add_imports(tags):
+    def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
+        return [
+            ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
+            for tag in tags
+        ]
+
+    class BaseComponent(Component):
+        def _get_imports(self) -> imports.ImportDict:
+            return {}
+
+    class Reference(Component):
+        def _get_imports(self) -> imports.ImportDict:
+            return imports.merge_imports(
+                super()._get_imports(),
+                {"react": _list_to_import_vars(tags)},
+                {"foo": [ImportVar(tag="bar")]},
+            )
+
+    class TestBase(Component):
+        def add_imports(
+            self,
+        ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
+            return {"foo": "bar"}
+
+    class Test(TestBase):
+        def add_imports(
+            self,
+        ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
+            return {"react": (tags[0] if len(tags) == 1 else tags)}
+
+    baseline = Reference.create()
+    test = Test.create()
+
+    assert baseline._get_all_imports() == {
+        "react": _list_to_import_vars(tags),
+        "foo": [ImportVar(tag="bar")],
+    }
+    assert test._get_all_imports() == baseline._get_all_imports()
+
+
+def test_component_add_hooks():
+    class BaseComponent(Component):
+        def _get_hooks(self):
+            return "const hook1 = 42"
+
+    class ChildComponent1(BaseComponent):
+        pass
+
+    class GrandchildComponent1(ChildComponent1):
+        def add_hooks(self):
+            return [
+                "const hook2 = 43",
+                "const hook3 = 44",
+            ]
+
+    class GreatGrandchildComponent1(GrandchildComponent1):
+        def add_hooks(self):
+            return [
+                "const hook4 = 45",
+            ]
+
+    class GrandchildComponent2(ChildComponent1):
+        def _get_hooks(self):
+            return "const hook5 = 46"
+
+    class GreatGrandchildComponent2(GrandchildComponent2):
+        def add_hooks(self):
+            return [
+                "const hook2 = 43",
+                "const hook6 = 47",
+            ]
+
+    assert list(BaseComponent()._get_all_hooks()) == ["const hook1 = 42"]
+    assert list(ChildComponent1()._get_all_hooks()) == ["const hook1 = 42"]
+    assert list(GrandchildComponent1()._get_all_hooks()) == [
+        "const hook1 = 42",
+        "const hook2 = 43",
+        "const hook3 = 44",
+    ]
+    assert list(GreatGrandchildComponent1()._get_all_hooks()) == [
+        "const hook1 = 42",
+        "const hook2 = 43",
+        "const hook3 = 44",
+        "const hook4 = 45",
+    ]
+    assert list(GrandchildComponent2()._get_all_hooks()) == ["const hook5 = 46"]
+    assert list(GreatGrandchildComponent2()._get_all_hooks()) == [
+        "const hook5 = 46",
+        "const hook2 = 43",
+        "const hook6 = 47",
+    ]
+    assert list(
+        BaseComponent.create(
+            GrandchildComponent1.create(GreatGrandchildComponent2()),
+            GreatGrandchildComponent1(),
+        )._get_all_hooks(),
+    ) == [
+        "const hook1 = 42",
+        "const hook2 = 43",
+        "const hook3 = 44",
+        "const hook5 = 46",
+        "const hook6 = 47",
+        "const hook4 = 45",
+    ]
+    assert list(
+        Fragment.create(
+            GreatGrandchildComponent2(),
+            GreatGrandchildComponent1(),
+        )._get_all_hooks()
+    ) == [
+        "const hook5 = 46",
+        "const hook2 = 43",
+        "const hook6 = 47",
+        "const hook1 = 42",
+        "const hook3 = 44",
+        "const hook4 = 45",
+    ]
+
+
+def test_component_add_custom_code():
+    class BaseComponent(Component):
+        def _get_custom_code(self):
+            return "const custom_code1 = 42"
+
+    class ChildComponent1(BaseComponent):
+        pass
+
+    class GrandchildComponent1(ChildComponent1):
+        def add_custom_code(self):
+            return [
+                "const custom_code2 = 43",
+                "const custom_code3 = 44",
+            ]
+
+    class GreatGrandchildComponent1(GrandchildComponent1):
+        def add_custom_code(self):
+            return [
+                "const custom_code4 = 45",
+            ]
+
+    class GrandchildComponent2(ChildComponent1):
+        def _get_custom_code(self):
+            return "const custom_code5 = 46"
+
+    class GreatGrandchildComponent2(GrandchildComponent2):
+        def add_custom_code(self):
+            return [
+                "const custom_code2 = 43",
+                "const custom_code6 = 47",
+            ]
+
+    assert BaseComponent()._get_all_custom_code() == {"const custom_code1 = 42"}
+    assert ChildComponent1()._get_all_custom_code() == {"const custom_code1 = 42"}
+    assert GrandchildComponent1()._get_all_custom_code() == {
+        "const custom_code1 = 42",
+        "const custom_code2 = 43",
+        "const custom_code3 = 44",
+    }
+    assert GreatGrandchildComponent1()._get_all_custom_code() == {
+        "const custom_code1 = 42",
+        "const custom_code2 = 43",
+        "const custom_code3 = 44",
+        "const custom_code4 = 45",
+    }
+    assert GrandchildComponent2()._get_all_custom_code() == {"const custom_code5 = 46"}
+    assert GreatGrandchildComponent2()._get_all_custom_code() == {
+        "const custom_code2 = 43",
+        "const custom_code5 = 46",
+        "const custom_code6 = 47",
+    }
+    assert BaseComponent.create(
+        GrandchildComponent1.create(GreatGrandchildComponent2()),
+        GreatGrandchildComponent1(),
+    )._get_all_custom_code() == {
+        "const custom_code1 = 42",
+        "const custom_code2 = 43",
+        "const custom_code3 = 44",
+        "const custom_code4 = 45",
+        "const custom_code5 = 46",
+        "const custom_code6 = 47",
+    }
+    assert Fragment.create(
+        GreatGrandchildComponent2(),
+        GreatGrandchildComponent1(),
+    )._get_all_custom_code() == {
+        "const custom_code1 = 42",
+        "const custom_code2 = 43",
+        "const custom_code3 = 44",
+        "const custom_code4 = 45",
+        "const custom_code5 = 46",
+        "const custom_code6 = 47",
+    }