소스 검색

Account for imports of @rx.memo components for frontend package installation (#2863)

* CustomComponent ignores the annotation type in favor of the passed value

Do not require rx.memo wrapped functions to have 100% correct annotations.

* test_custom_component_get_imports: test imports from wrapped custom components

* Account for imports of custom components for frontend installation

Create new ImportVar that only install the package to avoid rendering the
imports in the page that contains the custom component itself.

* Handle Imports that should never be installed

* Fix up double negative logic

* Perf Optimization: use the imports we already calculate during compile

Instead of augmenting _get_imports with a weird, slow, recursive crawl and
dictionary reconstruction, just use the imports that we compile into the
components.js file to install frontend packages needed by the custom
components.

Same effect, but adds essentially zero overhead to the compilation.
Masen Furer 1 년 전
부모
커밋
432fcd4a5b
4개의 변경된 파일79개의 추가작업 그리고 19개의 파일을 삭제
  1. 10 5
      reflex/app.py
  2. 14 7
      reflex/compiler/compiler.py
  3. 16 7
      reflex/components/component.py
  4. 39 0
      tests/components/test_component.py

+ 10 - 5
reflex/app.py

@@ -818,6 +818,7 @@ class App(Base):
             compile_results.append((stateful_components_path, stateful_components_code))
 
             result_futures = []
+            custom_components_future = None
 
             def submit_work(fn, *args, **kwargs):
                 """Submit work to the thread pool and add a callback to mark the task as complete.
@@ -847,7 +848,10 @@ class App(Base):
             submit_work(compiler.compile_app, app_root)
 
             # Compile the custom components.
-            submit_work(compiler.compile_components, custom_components)
+            custom_components_future = thread_pool.submit(
+                compiler.compile_components, custom_components
+            )
+            custom_components_future.add_done_callback(mark_complete)
 
             # Compile the root stylesheet with base styles.
             submit_work(compiler.compile_root_stylesheet, self.stylesheets)
@@ -878,14 +882,15 @@ class App(Base):
             # Get imports from AppWrap components.
             all_imports.update(app_root.get_imports())
 
-            # Iterate through all the custom components and add their imports to the all_imports.
-            for component in custom_components:
-                all_imports.update(component.get_imports())
-
             # Wait for all compilation tasks to complete.
             for future in concurrent.futures.as_completed(result_futures):
                 compile_results.append(future.result())
 
+            # Iterate through all the custom components and add their imports to the all_imports.
+            custom_components_result = custom_components_future.result()
+            compile_results.append(custom_components_result[:2])
+            all_imports.update(custom_components_result[2])
+
             # Empty the .web pages directory.
             compiler.purge_web_pages_dir()
 

+ 14 - 7
reflex/compiler/compiler.py

@@ -186,7 +186,9 @@ def _compile_component(component: Component) -> str:
     return templates.COMPONENT.render(component=component)
 
 
-def _compile_components(components: set[CustomComponent]) -> str:
+def _compile_components(
+    components: set[CustomComponent],
+) -> tuple[str, Dict[str, list[ImportVar]]]:
     """Compile the components.
 
     Args:
@@ -208,9 +210,12 @@ def _compile_components(components: set[CustomComponent]) -> str:
         imports = utils.merge_imports(imports, component_imports)
 
     # Compile the components page.
-    return templates.COMPONENTS.render(
-        imports=utils.compile_imports(imports),
-        components=component_renders,
+    return (
+        templates.COMPONENTS.render(
+            imports=utils.compile_imports(imports),
+            components=component_renders,
+        ),
+        imports,
     )
 
 
@@ -401,7 +406,9 @@ def compile_page(
     return output_path, code
 
 
-def compile_components(components: set[CustomComponent]):
+def compile_components(
+    components: set[CustomComponent],
+) -> tuple[str, str, Dict[str, list[ImportVar]]]:
     """Compile the custom components.
 
     Args:
@@ -414,8 +421,8 @@ def compile_components(components: set[CustomComponent]):
     output_path = utils.get_components_path()
 
     # Compile the components.
-    code = _compile_components(components)
-    return output_path, code
+    code, imports = _compile_components(components)
+    return output_path, code, imports
 
 
 def compile_stateful_components(

+ 16 - 7
reflex/components/component.py

@@ -1265,6 +1265,9 @@ class CustomComponent(Component):
     # The props of the component.
     props: Dict[str, Any] = {}
 
+    # Props that reference other components.
+    component_props: Dict[str, Component] = {}
+
     def __init__(self, *args, **kwargs):
         """Initialize the custom component.
 
@@ -1296,17 +1299,13 @@ class CustomComponent(Component):
                 self.props[format.to_camel_case(key)] = value
                 continue
 
-            # Convert the type to a Var, then get the type of the var.
-            if not types._issubclass(type_, Var):
-                type_ = Var[type_]
-            type_ = types.get_args(type_)[0]
-
             # Handle subclasses of Base.
-            if types._issubclass(type_, Base):
+            if isinstance(value, Base):
                 base_value = Var.create(value)
 
                 # Track hooks and imports associated with Component instances.
-                if base_value is not None and types._issubclass(type_, Component):
+                if base_value is not None and isinstance(value, Component):
+                    self.component_props[key] = value
                     value = base_value._replace(
                         merge_var_data=VarData(  # type: ignore
                             imports=value.get_imports(),
@@ -1373,6 +1372,16 @@ class CustomComponent(Component):
             custom_components |= self.get_component(self).get_custom_components(
                 seen=seen
             )
+
+        # Fetch custom components from props as well.
+        for child_component in self.component_props.values():
+            if child_component.tag is None:
+                continue
+            if child_component.tag not in seen:
+                seen.add(child_component.tag)
+                if isinstance(child_component, CustomComponent):
+                    custom_components |= {child_component}
+                custom_components |= child_component.get_custom_components(seen=seen)
         return custom_components
 
     def _render(self) -> Tag:

+ 39 - 0
tests/components/test_component.py

@@ -4,6 +4,7 @@ import pytest
 
 import reflex as rx
 from reflex.base import Base
+from reflex.compiler.compiler import compile_components
 from reflex.components.base.bare import Bare
 from reflex.components.chakra.layout.box import Box
 from reflex.components.component import (
@@ -1269,3 +1270,41 @@ def test_deprecated_props(capsys):
     assert "type={`type1`}" in c2_1_render["props"]
     assert "min={`min1`}" in c2_1_render["props"]
     assert "max={`max1`}" in c2_1_render["props"]
+
+
+def test_custom_component_get_imports():
+    class Inner(Component):
+        tag = "Inner"
+        library = "inner"
+
+    class Other(Component):
+        tag = "Other"
+        library = "other"
+
+    @rx.memo
+    def wrapper():
+        return Inner.create()
+
+    @rx.memo
+    def outer(c: Component):
+        return Other.create(c)
+
+    custom_comp = wrapper()
+
+    # Inner is not imported directly, but it is imported by the custom component.
+    assert "inner" not in custom_comp.get_imports()
+
+    # The imports are only resolved during compilation.
+    _, _, imports_inner = compile_components(custom_comp.get_custom_components())
+    assert "inner" in imports_inner
+
+    outer_comp = outer(c=wrapper())
+
+    # Libraries are not imported directly, but are imported by the custom component.
+    assert "inner" not in outer_comp.get_imports()
+    assert "other" not in outer_comp.get_imports()
+
+    # The imports are only resolved during compilation.
+    _, _, imports_outer = compile_components(outer_comp.get_custom_components())
+    assert "inner" in imports_outer
+    assert "other" in imports_outer