فهرست منبع

components deserve to be first class props (#4827)

* components deserve to be first class props

* default back to {}

* smarter yield

* how much does caching help?

* only hit the slower path on _are_fields_known

* remove the cache thingy

* cache the inner _get_component_prop_names

* oops

* dang it darglint

* refactor things a bit

* fix events
Khaleel Al-Adhami 3 ماه پیش
والد
کامیت
abab18e165
5فایلهای تغییر یافته به همراه259 افزوده شده و 137 حذف شده
  1. 69 16
      reflex/components/base/bare.py
  2. 168 102
      reflex/components/component.py
  3. 0 8
      reflex/components/core/cond.py
  4. 21 10
      reflex/vars/base.py
  5. 1 1
      tests/units/components/test_component.py

+ 69 - 16
reflex/components/base/bare.py

@@ -2,9 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Any, Iterator
+from typing import Any, Iterator, Sequence
 
-from reflex.components.component import Component, LiteralComponentVar
+from reflex.components.component import BaseComponent, Component, ComponentStyle
 from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.config import PerformanceMode, environment
@@ -12,7 +12,7 @@ from reflex.utils import console
 from reflex.utils.decorator import once
 from reflex.utils.imports import ParsedImportDict
 from reflex.vars import BooleanVar, ObjectVar, Var
-from reflex.vars.base import VarData
+from reflex.vars.base import GLOBAL_CACHE, VarData
 from reflex.vars.sequence import LiteralStringVar
 
 
@@ -47,6 +47,11 @@ def validate_str(value: str):
             )
 
 
+def _components_from_var(var: Var) -> Sequence[BaseComponent]:
+    var_data = var._get_all_var_data()
+    return var_data.components if var_data else ()
+
+
 class Bare(Component):
     """A component with no tag."""
 
@@ -80,8 +85,9 @@ class Bare(Component):
             The hooks for the component.
         """
         hooks = super()._get_all_hooks_internal()
-        if isinstance(self.contents, LiteralComponentVar):
-            hooks |= self.contents._var_value._get_all_hooks_internal()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                hooks |= component._get_all_hooks_internal()
         return hooks
 
     def _get_all_hooks(self) -> dict[str, VarData | None]:
@@ -91,18 +97,22 @@ class Bare(Component):
             The hooks for the component.
         """
         hooks = super()._get_all_hooks()
-        if isinstance(self.contents, LiteralComponentVar):
-            hooks |= self.contents._var_value._get_all_hooks()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                hooks |= component._get_all_hooks()
         return hooks
 
-    def _get_all_imports(self) -> ParsedImportDict:
+    def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
         """Include the imports for the component.
 
+        Args:
+            collapse: Whether to collapse the imports.
+
         Returns:
             The imports for the component.
         """
-        imports = super()._get_all_imports()
-        if isinstance(self.contents, LiteralComponentVar):
+        imports = super()._get_all_imports(collapse=collapse)
+        if isinstance(self.contents, Var):
             var_data = self.contents._get_all_var_data()
             if var_data:
                 imports |= {k: list(v) for k, v in var_data.imports}
@@ -115,8 +125,9 @@ class Bare(Component):
             The dynamic imports.
         """
         dynamic_imports = super()._get_all_dynamic_imports()
-        if isinstance(self.contents, LiteralComponentVar):
-            dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                dynamic_imports |= component._get_all_dynamic_imports()
         return dynamic_imports
 
     def _get_all_custom_code(self) -> set[str]:
@@ -126,10 +137,24 @@ class Bare(Component):
             The custom code.
         """
         custom_code = super()._get_all_custom_code()
-        if isinstance(self.contents, LiteralComponentVar):
-            custom_code |= self.contents._var_value._get_all_custom_code()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                custom_code |= component._get_all_custom_code()
         return custom_code
 
+    def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
+        """Get the components that should be wrapped in the app.
+
+        Returns:
+            The components that should be wrapped in the app.
+        """
+        app_wrap_components = super()._get_all_app_wrap_components()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                if isinstance(component, Component):
+                    app_wrap_components |= component._get_all_app_wrap_components()
+        return app_wrap_components
+
     def _get_all_refs(self) -> set[str]:
         """Get the refs for the children of the component.
 
@@ -137,8 +162,9 @@ class Bare(Component):
             The refs for the children.
         """
         refs = super()._get_all_refs()
-        if isinstance(self.contents, LiteralComponentVar):
-            refs |= self.contents._var_value._get_all_refs()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                refs |= component._get_all_refs()
         return refs
 
     def _render(self) -> Tag:
@@ -148,6 +174,33 @@ class Bare(Component):
             return Tagless(contents=f"{{{self.contents!s}}}")
         return Tagless(contents=str(self.contents))
 
+    def _add_style_recursive(
+        self, style: ComponentStyle, theme: Component | None = None
+    ) -> Component:
+        """Add style to the component and its children.
+
+        Args:
+            style: The style to add.
+            theme: The theme to add.
+
+        Returns:
+            The component with the style added.
+        """
+        new_self = super()._add_style_recursive(style, theme)
+
+        are_components_touched = False
+
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                if isinstance(component, Component):
+                    component._add_style_recursive(style, theme)
+                    are_components_touched = True
+
+        if are_components_touched:
+            GLOBAL_CACHE.clear()
+
+        return new_self
+
     def _get_vars(
         self, include_children: bool = False, ignore_ids: set[int] | None = None
     ) -> Iterator[Var]:

+ 168 - 102
reflex/components/component.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import copy
 import dataclasses
+import inspect
 import typing
 from abc import ABC, abstractmethod
 from functools import lru_cache, wraps
@@ -21,6 +22,8 @@ from typing import (
     Set,
     Type,
     Union,
+    get_args,
+    get_origin,
 )
 
 from typing_extensions import Self
@@ -43,6 +46,7 @@ from reflex.constants import (
 from reflex.constants.compiler import SpecialAttributes
 from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.event import (
+    EventActionsMixin,
     EventCallback,
     EventChain,
     EventHandler,
@@ -191,6 +195,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
     return types._isinstance(obj, type_hint, nested=1)
 
 
+def _components_from(
+    component_or_var: Union[BaseComponent, Var],
+) -> tuple[BaseComponent, ...]:
+    """Get the components from a component or Var.
+
+    Args:
+        component_or_var: The component or Var to get the components from.
+
+    Returns:
+        The components.
+    """
+    if isinstance(component_or_var, Var):
+        var_data = component_or_var._get_all_var_data()
+        return var_data.components if var_data else ()
+    if isinstance(component_or_var, BaseComponent):
+        return (component_or_var,)
+    return ()
+
+
 class Component(BaseComponent, ABC):
     """A component with style, event trigger and other props."""
 
@@ -489,7 +512,7 @@ class Component(BaseComponent, ABC):
 
         # Remove any keys that were added as events.
         for key in kwargs["event_triggers"]:
-            del kwargs[key]
+            kwargs.pop(key, None)
 
         # Place data_ and aria_ attributes into custom_attrs
         special_attributes = tuple(
@@ -665,13 +688,22 @@ class Component(BaseComponent, ABC):
         """
         return set()
 
+    @classmethod
+    def _are_fields_known(cls) -> bool:
+        """Check if all fields are known at compile time. True for most components.
+
+        Returns:
+            Whether all fields are known at compile time.
+        """
+        return True
+
     @classmethod
     @lru_cache(maxsize=None)
-    def get_component_props(cls) -> set[str]:
-        """Get the props that expected a component as value.
+    def _get_component_prop_names(cls) -> Set[str]:
+        """Get the names of the component props. NOTE: This assumes all fields are known.
 
         Returns:
-            The components props.
+            The names of the component props.
         """
         return {
             name
@@ -680,6 +712,26 @@ class Component(BaseComponent, ABC):
             and types._issubclass(field.outer_type_, Component)
         }
 
+    def _get_components_in_props(self) -> Sequence[BaseComponent]:
+        """Get the components in the props.
+
+        Returns:
+            The components in the props
+        """
+        if self._are_fields_known():
+            return [
+                component
+                for name in self._get_component_prop_names()
+                for component in _components_from(getattr(self, name))
+            ]
+        return [
+            component
+            for prop in self.get_props()
+            if (value := getattr(self, prop)) is not None
+            and isinstance(value, (BaseComponent, Var))
+            for component in _components_from(value)
+        ]
+
     @classmethod
     def create(cls, *children, **props) -> Self:
         """Create the component.
@@ -1136,6 +1188,9 @@ class Component(BaseComponent, ABC):
         if custom_code is not None:
             code.add(custom_code)
 
+        for component in self._get_components_in_props():
+            code |= component._get_all_custom_code()
+
         # Add the custom code from add_custom_code method.
         for clz in self._iter_parent_classes_with_method("add_custom_code"):
             for item in clz.add_custom_code(self):
@@ -1163,7 +1218,7 @@ class Component(BaseComponent, ABC):
             The dynamic imports.
         """
         # Store the import in a set to avoid duplicates.
-        dynamic_imports = set()
+        dynamic_imports: set[str] = set()
 
         # Get dynamic import for this component.
         dynamic_import = self._get_dynamic_imports()
@@ -1174,25 +1229,12 @@ class Component(BaseComponent, ABC):
         for child in self.children:
             dynamic_imports |= child._get_all_dynamic_imports()
 
-        for prop in self.get_component_props():
-            if getattr(self, prop) is not None:
-                dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports()
+        for component in self._get_components_in_props():
+            dynamic_imports |= component._get_all_dynamic_imports()
 
         # Return the dynamic imports
         return dynamic_imports
 
-    def _get_props_imports(self) -> List[ParsedImportDict]:
-        """Get the imports needed for components props.
-
-        Returns:
-            The  imports for the components props of the component.
-        """
-        return [
-            getattr(self, prop)._get_all_imports()
-            for prop in self.get_component_props()
-            if getattr(self, prop) is not None
-        ]
-
     def _should_transpile(self, dep: str | None) -> bool:
         """Check if a dependency should be transpiled.
 
@@ -1303,7 +1345,6 @@ class Component(BaseComponent, ABC):
             )
 
         return imports.merge_imports(
-            *self._get_props_imports(),
             self._get_dependencies_imports(),
             self._get_hooks_imports(),
             _imports,
@@ -1380,6 +1421,8 @@ class Component(BaseComponent, ABC):
                         for k in var_data.hooks
                     }
                 )
+                for component in var_data.components:
+                    vars_hooks.update(component._get_all_hooks())
         return vars_hooks
 
     def _get_events_hooks(self) -> dict[str, VarData | None]:
@@ -1528,6 +1571,9 @@ class Component(BaseComponent, ABC):
             refs.add(ref)
         for child in self.children:
             refs |= child._get_all_refs()
+        for component in self._get_components_in_props():
+            refs |= component._get_all_refs()
+
         return refs
 
     def _get_all_custom_components(
@@ -1551,6 +1597,9 @@ class Component(BaseComponent, ABC):
             if not isinstance(child, Component):
                 continue
             custom_components |= child._get_all_custom_components(seen=seen)
+        for component in self._get_components_in_props():
+            if isinstance(component, Component) and component.tag is not None:
+                custom_components |= component._get_all_custom_components(seen=seen)
         return custom_components
 
     @property
@@ -1614,17 +1663,65 @@ class CustomComponent(Component):
     # The props of the component.
     props: Dict[str, Any] = {}
 
-    # Props that reference other components.
-    component_props: Dict[str, Component] = {}
-
-    def __init__(self, *args, **kwargs):
+    def __init__(self, **kwargs):
         """Initialize the custom component.
 
         Args:
-            *args: The args to pass to the component.
             **kwargs: The kwargs to pass to the component.
         """
-        super().__init__(*args, **kwargs)
+        component_fn = kwargs.get("component_fn")
+
+        # Set the props.
+        props_types = typing.get_type_hints(component_fn) if component_fn else {}
+        props = {key: value for key, value in kwargs.items() if key in props_types}
+        kwargs = {key: value for key, value in kwargs.items() if key not in props_types}
+
+        event_types = {
+            key
+            for key in props
+            if (
+                (get_origin((annotation := props_types.get(key))) or annotation)
+                == EventHandler
+            )
+        }
+
+        def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]:
+            type_ = props_types[key]
+
+            return (
+                args[0]
+                if (args := get_args(type_))
+                else (
+                    annotation_args[1]
+                    if get_origin(
+                        (
+                            annotation := inspect.getfullargspec(
+                                component_fn
+                            ).annotations[key]
+                        )
+                    )
+                    is typing.Annotated
+                    and (annotation_args := get_args(annotation))
+                    else no_args_event_spec
+                )
+            )
+
+        super().__init__(
+            event_triggers={
+                key: EventChain.create(
+                    value=props[key],
+                    args_spec=get_args_spec(key),
+                    key=key,
+                )
+                for key in event_types
+            },
+            **kwargs,
+        )
+
+        to_camel_cased_props = {
+            format.to_camel_case(key) for key in props if key not in event_types
+        }
+        self.get_props = lambda: to_camel_cased_props  # pyright: ignore [reportIncompatibleVariableOverride]
 
         # Unset the style.
         self.style = Style()
@@ -1632,51 +1729,36 @@ class CustomComponent(Component):
         # Set the tag to the name of the function.
         self.tag = format.to_title_case(self.component_fn.__name__)
 
-        # Get the event triggers defined in the component declaration.
-        event_triggers_in_component_declaration = self.get_event_triggers()
-
-        # Set the props.
-        props = typing.get_type_hints(self.component_fn)
-        for key, value in kwargs.items():
+        for key, value in props.items():
             # Skip kwargs that are not props.
-            if key not in props:
+            if key not in props_types:
                 continue
 
+            camel_cased_key = format.to_camel_case(key)
+
             # Get the type based on the annotation.
-            type_ = props[key]
+            type_ = props_types[key]
 
             # Handle event chains.
-            if types._issubclass(type_, EventChain):
-                value = EventChain.create(
-                    value=value,
-                    args_spec=event_triggers_in_component_declaration.get(
-                        key, no_args_event_spec
-                    ),
-                    key=key,
+            if types._issubclass(type_, EventActionsMixin):
+                inspect.getfullargspec(component_fn).annotations[key]
+                self.props[camel_cased_key] = EventChain.create(
+                    value=value, args_spec=get_args_spec(key), key=key
                 )
-                self.props[format.to_camel_case(key)] = value
                 continue
 
-            # Handle subclasses of Base.
-            if isinstance(value, Base):
-                base_value = LiteralVar.create(value)
-
-                # Track hooks and imports associated with Component instances.
-                if base_value is not None and isinstance(value, Component):
-                    self.component_props[key] = value
-                    value = base_value._replace(
-                        merge_var_data=VarData(
-                            imports=value._get_all_imports(),
-                            hooks=value._get_all_hooks(),
-                        )
-                    )
-                else:
-                    value = base_value
-            else:
-                value = LiteralVar.create(value)
+            value = LiteralVar.create(value)
+            self.props[camel_cased_key] = value
+            setattr(self, camel_cased_key, value)
 
-            # Set the prop.
-            self.props[format.to_camel_case(key)] = value
+    @classmethod
+    def _are_fields_known(cls) -> bool:
+        """Check if the fields are known.
+
+        Returns:
+            Whether the fields are known.
+        """
+        return False
 
     def __eq__(self, other: Any) -> bool:
         """Check if the component is equal to another.
@@ -1698,7 +1780,7 @@ class CustomComponent(Component):
         return hash(self.tag)
 
     @classmethod
-    def get_props(cls) -> Set[str]:  # pyright: ignore [reportIncompatibleVariableOverride]
+    def get_props(cls) -> Set[str]:
         """Get the props for the component.
 
         Returns:
@@ -1735,27 +1817,8 @@ class CustomComponent(Component):
                 seen=seen
             )
 
-        # Fetch custom components from props as well.
-        for child_component in self.component_props.values():
-            if child_component.tag is None:
-                continue
-            if child_component.tag not in seen:
-                seen.add(child_component.tag)
-                if isinstance(child_component, CustomComponent):
-                    custom_components |= {child_component}
-                custom_components |= child_component._get_all_custom_components(
-                    seen=seen
-                )
         return custom_components
 
-    def _render(self) -> Tag:
-        """Define how to render the component in React.
-
-        Returns:
-            The tag to render.
-        """
-        return super()._render(props=self.props)
-
     def get_prop_vars(self) -> List[Var]:
         """Get the prop vars.
 
@@ -1765,29 +1828,19 @@ class CustomComponent(Component):
         return [
             Var(
                 _js_expr=name,
-                _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)),
+                _var_type=(
+                    prop._var_type
+                    if isinstance(prop, Var)
+                    else (
+                        type(prop)
+                        if not isinstance(prop, EventActionsMixin)
+                        else EventChain
+                    )
+                ),
             ).guess_type()
             for name, prop in self.props.items()
         ]
 
-    def _get_vars(
-        self, include_children: bool = False, ignore_ids: set[int] | None = None
-    ) -> Iterator[Var]:
-        """Walk all Vars used in this component.
-
-        Args:
-            include_children: Whether to include Vars from children.
-            ignore_ids: The ids to ignore.
-
-        Yields:
-            Each var referenced by the component (props, styles, event handlers).
-        """
-        ignore_ids = ignore_ids or set()
-        yield from super()._get_vars(
-            include_children=include_children, ignore_ids=ignore_ids
-        )
-        yield from filter(lambda prop: isinstance(prop, Var), self.props.values())
-
     @lru_cache(maxsize=None)  # noqa: B019
     def get_component(self) -> Component:
         """Render the component.
@@ -2475,6 +2528,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
             The VarData for the var.
         """
         return VarData.merge(
+            self._var_data,
             VarData(
                 imports={
                     "@emotion/react": [
@@ -2517,9 +2571,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
         Returns:
             The var.
         """
+        var_datas = [
+            var_data
+            for var in value._get_vars(include_children=True)
+            if (var_data := var._get_all_var_data())
+        ]
+
         return LiteralComponentVar(
             _js_expr="",
             _var_type=type(value),
-            _var_data=_var_data,
+            _var_data=VarData.merge(
+                _var_data,
+                *var_datas,
+                VarData(
+                    components=(value,),
+                ),
+            ),
             _var_value=value,
         )

+ 0 - 8
reflex/components/core/cond.py

@@ -61,14 +61,6 @@ class Cond(MemoizationLeaf):
             )
         )
 
-    def _get_props_imports(self):
-        """Get the imports needed for component's props.
-
-        Returns:
-            The imports for the component's props of the component.
-        """
-        return []
-
     def _render(self) -> Tag:
         return CondTag(
             cond=self.cond,

+ 21 - 10
reflex/vars/base.py

@@ -76,6 +76,7 @@ from reflex.utils.types import (
 )
 
 if TYPE_CHECKING:
+    from reflex.components.component import BaseComponent
     from reflex.state import BaseState
 
     from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
@@ -132,6 +133,9 @@ class VarData:
     # Position of the hook in the component
     position: Hooks.HookPosition | None = None
 
+    # Components that are part of this var
+    components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
+
     def __init__(
         self,
         state: str = "",
@@ -140,6 +144,7 @@ class VarData:
         hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
         deps: list[Var] | None = None,
         position: Hooks.HookPosition | None = None,
+        components: Iterable[BaseComponent] | None = None,
     ):
         """Initialize the var data.
 
@@ -150,6 +155,7 @@ class VarData:
             hooks: Hooks that need to be present in the component to render this var.
             deps: Dependencies of the var for useCallback.
             position: Position of the hook in the component.
+            components: Components that are part of this var.
         """
         if isinstance(hooks, str):
             hooks = [hooks]
@@ -164,6 +170,7 @@ class VarData:
         object.__setattr__(self, "hooks", tuple(hooks or {}))
         object.__setattr__(self, "deps", tuple(deps or []))
         object.__setattr__(self, "position", position or None)
+        object.__setattr__(self, "components", tuple(components or []))
 
         if hooks and any(hooks.values()):
             merged_var_data = VarData.merge(self, *hooks.values())
@@ -174,6 +181,7 @@ class VarData:
                 object.__setattr__(self, "hooks", merged_var_data.hooks)
                 object.__setattr__(self, "deps", merged_var_data.deps)
                 object.__setattr__(self, "position", merged_var_data.position)
+                object.__setattr__(self, "components", merged_var_data.components)
 
     def old_school_imports(self) -> ImportDict:
         """Return the imports as a mutable dict.
@@ -242,17 +250,19 @@ class VarData:
         else:
             position = None
 
-        if state or _imports or hooks or field_name or deps or position:
-            return VarData(
-                state=state,
-                field_name=field_name,
-                imports=_imports,
-                hooks=hooks,
-                deps=deps,
-                position=position,
-            )
+        components = tuple(
+            component for var_data in all_var_datas for component in var_data.components
+        )
 
-        return None
+        return VarData(
+            state=state,
+            field_name=field_name,
+            imports=_imports,
+            hooks=hooks,
+            deps=deps,
+            position=position,
+            components=components,
+        )
 
     def __bool__(self) -> bool:
         """Check if the var data is non-empty.
@@ -267,6 +277,7 @@ class VarData:
             or self.field_name
             or self.deps
             or self.position
+            or self.components
         )
 
     @classmethod

+ 1 - 1
tests/units/components/test_component.py

@@ -871,7 +871,7 @@ def test_create_custom_component(my_component):
     """
     component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
     assert component.tag == "MyComponent"
-    assert component.get_props() == set()
+    assert component.get_props() == {"prop1", "prop2"}
     assert component._get_all_custom_components() == {component}