Преглед изворни кода

Merge branch 'masenf/custom_component_respect_value_not_annotation' into masenf/multiprocess-compile-fx-conflict

Masen Furer пре 1 година
родитељ
комит
04211a3234
4 измењених фајлова са 79 додато и 21 уклоњено
  1. 10 7
      reflex/app.py
  2. 14 7
      reflex/compiler/compiler.py
  3. 16 7
      reflex/components/component.py
  4. 39 0
      tests/components/test_component.py

+ 10 - 7
reflex/app.py

@@ -875,6 +875,7 @@ class App(Base):
 
         with executor:
             result_futures = []
+            custom_components_future = None
 
             def _mark_complete(_=None):
                 progress.advance(task)
@@ -892,7 +893,10 @@ class App(Base):
             _submit_work(ExecutorSafeFunctions.compile_app)
 
             # Compile the custom components.
-            _submit_work(ExecutorSafeFunctions.compile_custom_components)
+            custom_components_future = executor.submit(
+                ExecutorSafeFunctions.compile_custom_components,
+            )
+            custom_components_future.add_done_callback(_mark_complete)
 
             # Compile the root stylesheet with base styles.
             _submit_work(compiler.compile_root_stylesheet, self.stylesheets)
@@ -913,12 +917,11 @@ class App(Base):
             for future in concurrent.futures.as_completed(result_futures):
                 compile_results.append(future.result())
 
-        # Get imports from AppWrap components.
-        all_imports.update(app_root.get_imports())
-
-        # Iterate through all the custom components and add their imports to the all_imports.
-        for component in custom_components:
-            all_imports.update(component.get_imports())
+            # Special case for custom_components, since we need the compiled imports
+            # to install proper frontend packages.
+            *custom_components_result, custom_components_imports = custom_components_future.result()
+            compile_results.append(custom_components_result)
+            all_imports.update(custom_components_imports)
 
         progress.advance(task)
 

+ 14 - 7
reflex/compiler/compiler.py

@@ -186,7 +186,9 @@ def _compile_component(component: Component) -> str:
     return templates.COMPONENT.render(component=component)
 
 
-def _compile_components(components: set[CustomComponent]) -> str:
+def _compile_components(
+    components: set[CustomComponent],
+) -> tuple[str, Dict[str, list[ImportVar]]]:
     """Compile the components.
 
     Args:
@@ -208,9 +210,12 @@ def _compile_components(components: set[CustomComponent]) -> str:
         imports = utils.merge_imports(imports, component_imports)
 
     # Compile the components page.
-    return templates.COMPONENTS.render(
-        imports=utils.compile_imports(imports),
-        components=component_renders,
+    return (
+        templates.COMPONENTS.render(
+            imports=utils.compile_imports(imports),
+            components=component_renders,
+        ),
+        imports,
     )
 
 
@@ -401,7 +406,9 @@ def compile_page(
     return output_path, code
 
 
-def compile_components(components: set[CustomComponent]):
+def compile_components(
+    components: set[CustomComponent],
+) -> tuple[str, str, Dict[str, list[ImportVar]]]:
     """Compile the custom components.
 
     Args:
@@ -414,8 +421,8 @@ def compile_components(components: set[CustomComponent]):
     output_path = utils.get_components_path()
 
     # Compile the components.
-    code = _compile_components(components)
-    return output_path, code
+    code, imports = _compile_components(components)
+    return output_path, code, imports
 
 
 def compile_stateful_components(

+ 16 - 7
reflex/components/component.py

@@ -1265,6 +1265,9 @@ class CustomComponent(Component):
     # The props of the component.
     props: Dict[str, Any] = {}
 
+    # Props that reference other components.
+    component_props: Dict[str, Component] = {}
+
     def __init__(self, *args, **kwargs):
         """Initialize the custom component.
 
@@ -1296,17 +1299,13 @@ class CustomComponent(Component):
                 self.props[format.to_camel_case(key)] = value
                 continue
 
-            # Convert the type to a Var, then get the type of the var.
-            if not types._issubclass(type_, Var):
-                type_ = Var[type_]
-            type_ = types.get_args(type_)[0]
-
             # Handle subclasses of Base.
-            if types._issubclass(type_, Base):
+            if isinstance(value, Base):
                 base_value = Var.create(value)
 
                 # Track hooks and imports associated with Component instances.
-                if base_value is not None and types._issubclass(type_, Component):
+                if base_value is not None and isinstance(value, Component):
+                    self.component_props[key] = value
                     value = base_value._replace(
                         merge_var_data=VarData(  # type: ignore
                             imports=value.get_imports(),
@@ -1373,6 +1372,16 @@ class CustomComponent(Component):
             custom_components |= self.get_component(self).get_custom_components(
                 seen=seen
             )
+
+        # Fetch custom components from props as well.
+        for child_component in self.component_props.values():
+            if child_component.tag is None:
+                continue
+            if child_component.tag not in seen:
+                seen.add(child_component.tag)
+                if isinstance(child_component, CustomComponent):
+                    custom_components |= {child_component}
+                custom_components |= child_component.get_custom_components(seen=seen)
         return custom_components
 
     def _render(self) -> Tag:

+ 39 - 0
tests/components/test_component.py

@@ -4,6 +4,7 @@ import pytest
 
 import reflex as rx
 from reflex.base import Base
+from reflex.compiler.compiler import compile_components
 from reflex.components.base.bare import Bare
 from reflex.components.chakra.layout.box import Box
 from reflex.components.component import (
@@ -1269,3 +1270,41 @@ def test_deprecated_props(capsys):
     assert "type={`type1`}" in c2_1_render["props"]
     assert "min={`min1`}" in c2_1_render["props"]
     assert "max={`max1`}" in c2_1_render["props"]
+
+
+def test_custom_component_get_imports():
+    class Inner(Component):
+        tag = "Inner"
+        library = "inner"
+
+    class Other(Component):
+        tag = "Other"
+        library = "other"
+
+    @rx.memo
+    def wrapper():
+        return Inner.create()
+
+    @rx.memo
+    def outer(c: Component):
+        return Other.create(c)
+
+    custom_comp = wrapper()
+
+    # Inner is not imported directly, but it is imported by the custom component.
+    assert "inner" not in custom_comp.get_imports()
+
+    # The imports are only resolved during compilation.
+    _, _, imports_inner = compile_components(custom_comp.get_custom_components())
+    assert "inner" in imports_inner
+
+    outer_comp = outer(c=wrapper())
+
+    # Libraries are not imported directly, but are imported by the custom component.
+    assert "inner" not in outer_comp.get_imports()
+    assert "other" not in outer_comp.get_imports()
+
+    # The imports are only resolved during compilation.
+    _, _, imports_outer = compile_components(outer_comp.get_custom_components())
+    assert "inner" in imports_outer
+    assert "other" in imports_outer