Pārlūkot izejas kodu

Track all `rx.memo` components (#5172)

* Automatically compile any `@rx.memo` decorated function

If a component is memoized anywhere in the app, include the component in the
generated output.

Avoid extra component tree walk, since we can know immediately what all of the
custom components are. Any perf optimization gained by not compiling unused
memo functions is handily saved by avoiding the tree walk.

* dynamic: bundle local $/utils/components (rx.memo) module

allow `@rx.memo` decorated functions to be referenced by dynamic components
(which allows working around most limitations with dynamic components).

move special cases for $/ prefix modules to `_normalize_library_path` and
include them in `bundled_libraries` so they can be checked at runtime.

* fixup memo registry

* Pass dummy EventSpec to rx.memo function
Masen Furer 3 nedēļas atpakaļ
vecāks
revīzija
f3df5262d8

+ 3 - 9
reflex/app.py

@@ -45,6 +45,7 @@ from reflex.components.base.error_boundary import ErrorBoundary
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.strict_mode import StrictMode
 from reflex.components.base.strict_mode import StrictMode
 from reflex.components.component import (
 from reflex.components.component import (
+    CUSTOM_COMPONENTS,
     Component,
     Component,
     ComponentStyle,
     ComponentStyle,
     evaluate_style_namespaces,
     evaluate_style_namespaces,
@@ -1222,9 +1223,8 @@ class App(MiddlewareMixin, LifespanMixin):
 
 
         progress.advance(task)
         progress.advance(task)
 
 
-        # Track imports and custom components found.
+        # Track imports found.
         all_imports = {}
         all_imports = {}
-        custom_components = set()
 
 
         # This has to happen before compiling stateful components as that
         # This has to happen before compiling stateful components as that
         # prevents recursive functions from reaching all components.
         # prevents recursive functions from reaching all components.
@@ -1235,9 +1235,6 @@ class App(MiddlewareMixin, LifespanMixin):
             # Add the app wrappers from this component.
             # Add the app wrappers from this component.
             app_wrappers.update(component._get_all_app_wrap_components())
             app_wrappers.update(component._get_all_app_wrap_components())
 
 
-            # Add the custom components from the page to the set.
-            custom_components |= component._get_all_custom_components()
-
         if (toaster := self.toaster) is not None:
         if (toaster := self.toaster) is not None:
             from reflex.components.component import memo
             from reflex.components.component import memo
 
 
@@ -1255,9 +1252,6 @@ class App(MiddlewareMixin, LifespanMixin):
             if component is not None:
             if component is not None:
                 app_wrappers[key] = component
                 app_wrappers[key] = component
 
 
-        for component in app_wrappers.values():
-            custom_components |= component._get_all_custom_components()
-
         if self.error_boundary:
         if self.error_boundary:
             from reflex.compiler.compiler import into_component
             from reflex.compiler.compiler import into_component
 
 
@@ -1382,7 +1376,7 @@ class App(MiddlewareMixin, LifespanMixin):
             custom_components_output,
             custom_components_output,
             custom_components_result,
             custom_components_result,
             custom_components_imports,
             custom_components_imports,
-        ) = compiler.compile_components(custom_components)
+        ) = compiler.compile_components(set(CUSTOM_COMPONENTS.values()))
         compile_results.append((custom_components_output, custom_components_result))
         compile_results.append((custom_components_output, custom_components_result))
         all_imports.update(custom_components_imports)
         all_imports.update(custom_components_imports)
 
 

+ 1 - 4
reflex/compiler/compiler.py

@@ -56,7 +56,7 @@ def _normalize_library_name(lib: str) -> str:
     """
     """
     if lib == "react":
     if lib == "react":
         return "React"
         return "React"
-    return lib.replace("@", "").replace("/", "_").replace("-", "_")
+    return lib.replace("$/", "").replace("@", "").replace("/", "_").replace("-", "_")
 
 
 
 
 def _compile_app(app_root: Component) -> str:
 def _compile_app(app_root: Component) -> str:
@@ -72,9 +72,6 @@ def _compile_app(app_root: Component) -> str:
 
 
     window_libraries = [
     window_libraries = [
         (_normalize_library_name(name), name) for name in bundled_libraries
         (_normalize_library_name(name), name) for name in bundled_libraries
-    ] + [
-        ("utils_context", f"$/{constants.Dirs.UTILS}/context"),
-        ("utils_state", f"$/{constants.Dirs.UTILS}/state"),
     ]
     ]
 
 
     return templates.APP_ROOT.render(
     return templates.APP_ROOT.render(

+ 39 - 57
reflex/components/component.py

@@ -1647,32 +1647,6 @@ class Component(BaseComponent, ABC):
 
 
         return refs
         return refs
 
 
-    def _get_all_custom_components(
-        self, seen: set[str] | None = None
-    ) -> set[CustomComponent]:
-        """Get all the custom components used by the component.
-
-        Args:
-            seen: The tags of the components that have already been seen.
-
-        Returns:
-            The set of custom components.
-        """
-        custom_components = set()
-
-        # Store the seen components in a set to avoid infinite recursion.
-        if seen is None:
-            seen = set()
-        for child in self.children:
-            # Skip BaseComponent and StatefulComponent children.
-            if not isinstance(child, Component):
-                continue
-            custom_components |= child._get_all_custom_components(seen=seen)
-        for component in self._get_components_in_props():
-            if isinstance(component, Component) and component.tag is not None:
-                custom_components |= component._get_all_custom_components(seen=seen)
-        return custom_components
-
     @property
     @property
     def import_var(self):
     def import_var(self):
         """The tag to import.
         """The tag to import.
@@ -1857,37 +1831,6 @@ class CustomComponent(Component):
         """
         """
         return set()
         return set()
 
 
-    def _get_all_custom_components(
-        self, seen: set[str] | None = None
-    ) -> set[CustomComponent]:
-        """Get all the custom components used by the component.
-
-        Args:
-            seen: The tags of the components that have already been seen.
-
-        Raises:
-            ValueError: If the tag is not set.
-
-        Returns:
-            The set of custom components.
-        """
-        if self.tag is None:
-            raise ValueError("The tag must be set.")
-
-        # Store the seen components in a set to avoid infinite recursion.
-        if seen is None:
-            seen = set()
-        custom_components = {self} | super()._get_all_custom_components(seen=seen)
-
-        # Avoid adding the same component twice.
-        if self.tag not in seen:
-            seen.add(self.tag)
-            custom_components |= self.get_component(self)._get_all_custom_components(
-                seen=seen
-            )
-
-        return custom_components
-
     @staticmethod
     @staticmethod
     def _get_event_spec_from_args_spec(name: str, event: EventChain) -> Callable:
     def _get_event_spec_from_args_spec(name: str, event: EventChain) -> Callable:
         """Get the event spec from the args spec.
         """Get the event spec from the args spec.
@@ -1951,6 +1894,42 @@ class CustomComponent(Component):
         return self.component_fn(*self.get_prop_vars())
         return self.component_fn(*self.get_prop_vars())
 
 
 
 
+CUSTOM_COMPONENTS: dict[str, CustomComponent] = {}
+
+
+def _register_custom_component(
+    component_fn: Callable[..., Component],
+):
+    """Register a custom component to be compiled.
+
+    Args:
+        component_fn: The function that creates the component.
+
+    Raises:
+        TypeError: If the tag name cannot be determined.
+    """
+    dummy_props = {
+        prop: (
+            Var(
+                "",
+                _var_type=annotation,
+            )
+            if not types.safe_issubclass(annotation, EventHandler)
+            else EventSpec(handler=EventHandler(fn=lambda: []))
+        )
+        for prop, annotation in typing.get_type_hints(component_fn).items()
+        if prop != "return"
+    }
+    dummy_component = CustomComponent._create(
+        children=[],
+        component_fn=component_fn,
+        **dummy_props,
+    )
+    if dummy_component.tag is None:
+        raise TypeError(f"Could not determine the tag name for {component_fn!r}")
+    CUSTOM_COMPONENTS[dummy_component.tag] = dummy_component
+
+
 def custom_component(
 def custom_component(
     component_fn: Callable[..., Component],
     component_fn: Callable[..., Component],
 ) -> Callable[..., CustomComponent]:
 ) -> Callable[..., CustomComponent]:
@@ -1971,6 +1950,9 @@ def custom_component(
             children=list(children), component_fn=component_fn, **props
             children=list(children), component_fn=component_fn, **props
         )
         )
 
 
+    # Register this component so it can be compiled.
+    _register_custom_component(component_fn)
+
     return wrapper
     return wrapper
 
 
 
 

+ 9 - 1
reflex/components/dynamic.py

@@ -26,7 +26,15 @@ def get_cdn_url(lib: str) -> str:
     return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm"
     return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm"
 
 
 
 
-bundled_libraries = {"react", "@radix-ui/themes", "@emotion/react", "next/link"}
+bundled_libraries = {
+    "react",
+    "@radix-ui/themes",
+    "@emotion/react",
+    "next/link",
+    f"$/{constants.Dirs.UTILS}/context",
+    f"$/{constants.Dirs.UTILS}/state",
+    f"$/{constants.Dirs.UTILS}/components",
+}
 
 
 
 
 def bundle_library(component: Union["Component", str]):
 def bundle_library(component: Union["Component", str]):

+ 0 - 21
reflex/components/markdown/markdown.py

@@ -192,27 +192,6 @@ class Markdown(Component):
             **props,
             **props,
         )
         )
 
 
-    def _get_all_custom_components(
-        self, seen: set[str] | None = None
-    ) -> set[CustomComponent]:
-        """Get all the custom components used by the component.
-
-        Args:
-            seen: The tags of the components that have already been seen.
-
-        Returns:
-            The set of custom components.
-        """
-        custom_components = super()._get_all_custom_components(seen=seen)
-
-        # Get the custom components for each tag.
-        for component in self.component_map.values():
-            custom_components |= component(_MOCK_ARG)._get_all_custom_components(
-                seen=seen
-            )
-
-        return custom_components
-
     def add_imports(self) -> ImportDict | list[ImportDict]:
     def add_imports(self) -> ImportDict | list[ImportDict]:
         """Add imports for the markdown component.
         """Add imports for the markdown component.
 
 

+ 9 - 5
tests/units/components/test_component.py

@@ -5,10 +5,11 @@ import pytest
 
 
 import reflex as rx
 import reflex as rx
 from reflex.base import Base
 from reflex.base import Base
-from reflex.compiler.compiler import compile_components
+from reflex.compiler.utils import compile_custom_component
 from reflex.components.base.bare import Bare
 from reflex.components.base.bare import Bare
 from reflex.components.base.fragment import Fragment
 from reflex.components.base.fragment import Fragment
 from reflex.components.component import (
 from reflex.components.component import (
+    CUSTOM_COMPONENTS,
     Component,
     Component,
     CustomComponent,
     CustomComponent,
     StatefulComponent,
     StatefulComponent,
@@ -877,7 +878,7 @@ def test_create_custom_component(my_component):
     component = rx.memo(my_component)(prop1="test", prop2=1)
     component = rx.memo(my_component)(prop1="test", prop2=1)
     assert component.tag == "MyComponent"
     assert component.tag == "MyComponent"
     assert component.get_props() == {"prop1", "prop2"}
     assert component.get_props() == {"prop1", "prop2"}
-    assert component._get_all_custom_components() == {component}
+    assert component.tag in CUSTOM_COMPONENTS
 
 
 
 
 def test_custom_component_hash(my_component):
 def test_custom_component_hash(my_component):
@@ -1801,10 +1802,13 @@ def test_custom_component_get_imports():
 
 
     # Inner is not imported directly, but it is imported by the custom component.
     # Inner is not imported directly, but it is imported by the custom component.
     assert "inner" not in custom_comp._get_all_imports()
     assert "inner" not in custom_comp._get_all_imports()
+    assert "outer" not in custom_comp._get_all_imports()
 
 
     # The imports are only resolved during compilation.
     # The imports are only resolved during compilation.
-    _, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
+    custom_comp.get_component(custom_comp)
+    _, imports_inner = compile_custom_component(custom_comp)
     assert "inner" in imports_inner
     assert "inner" in imports_inner
+    assert "outer" not in imports_inner
 
 
     outer_comp = outer(c=wrapper())
     outer_comp = outer(c=wrapper())
 
 
@@ -1813,8 +1817,8 @@ def test_custom_component_get_imports():
     assert "other" not in outer_comp._get_all_imports()
     assert "other" not in outer_comp._get_all_imports()
 
 
     # The imports are only resolved during compilation.
     # The imports are only resolved during compilation.
-    _, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
-    assert "inner" in imports_outer
+    _, imports_outer = compile_custom_component(outer_comp)
+    assert "inner" not in imports_outer
     assert "other" in imports_outer
     assert "other" in imports_outer