Bladeren bron

components deserve to be first class props (#4827)

* components deserve to be first class props

* default back to {}

* smarter yield

* how much does caching help?

* only hit the slower path on _are_fields_known

* remove the cache thingy

* cache the inner _get_component_prop_names

* oops

* dang it darglint

* refactor things a bit

* fix events
Khaleel Al-Adhami 3 maanden geleden
bovenliggende
commit
abab18e165

+ 69 - 16
reflex/components/base/bare.py

@@ -2,9 +2,9 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any, Iterator
+from typing import Any, Iterator, Sequence
 
 
-from reflex.components.component import Component, LiteralComponentVar
+from reflex.components.component import BaseComponent, Component, ComponentStyle
 from reflex.components.tags import Tag
 from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.components.tags.tagless import Tagless
 from reflex.config import PerformanceMode, environment
 from reflex.config import PerformanceMode, environment
@@ -12,7 +12,7 @@ from reflex.utils import console
 from reflex.utils.decorator import once
 from reflex.utils.decorator import once
 from reflex.utils.imports import ParsedImportDict
 from reflex.utils.imports import ParsedImportDict
 from reflex.vars import BooleanVar, ObjectVar, Var
 from reflex.vars import BooleanVar, ObjectVar, Var
-from reflex.vars.base import VarData
+from reflex.vars.base import GLOBAL_CACHE, VarData
 from reflex.vars.sequence import LiteralStringVar
 from reflex.vars.sequence import LiteralStringVar
 
 
 
 
@@ -47,6 +47,11 @@ def validate_str(value: str):
             )
             )
 
 
 
 
+def _components_from_var(var: Var) -> Sequence[BaseComponent]:
+    var_data = var._get_all_var_data()
+    return var_data.components if var_data else ()
+
+
 class Bare(Component):
 class Bare(Component):
     """A component with no tag."""
     """A component with no tag."""
 
 
@@ -80,8 +85,9 @@ class Bare(Component):
             The hooks for the component.
             The hooks for the component.
         """
         """
         hooks = super()._get_all_hooks_internal()
         hooks = super()._get_all_hooks_internal()
-        if isinstance(self.contents, LiteralComponentVar):
-            hooks |= self.contents._var_value._get_all_hooks_internal()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                hooks |= component._get_all_hooks_internal()
         return hooks
         return hooks
 
 
     def _get_all_hooks(self) -> dict[str, VarData | None]:
     def _get_all_hooks(self) -> dict[str, VarData | None]:
@@ -91,18 +97,22 @@ class Bare(Component):
             The hooks for the component.
             The hooks for the component.
         """
         """
         hooks = super()._get_all_hooks()
         hooks = super()._get_all_hooks()
-        if isinstance(self.contents, LiteralComponentVar):
-            hooks |= self.contents._var_value._get_all_hooks()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                hooks |= component._get_all_hooks()
         return hooks
         return hooks
 
 
-    def _get_all_imports(self) -> ParsedImportDict:
+    def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
         """Include the imports for the component.
         """Include the imports for the component.
 
 
+        Args:
+            collapse: Whether to collapse the imports.
+
         Returns:
         Returns:
             The imports for the component.
             The imports for the component.
         """
         """
-        imports = super()._get_all_imports()
-        if isinstance(self.contents, LiteralComponentVar):
+        imports = super()._get_all_imports(collapse=collapse)
+        if isinstance(self.contents, Var):
             var_data = self.contents._get_all_var_data()
             var_data = self.contents._get_all_var_data()
             if var_data:
             if var_data:
                 imports |= {k: list(v) for k, v in var_data.imports}
                 imports |= {k: list(v) for k, v in var_data.imports}
@@ -115,8 +125,9 @@ class Bare(Component):
             The dynamic imports.
             The dynamic imports.
         """
         """
         dynamic_imports = super()._get_all_dynamic_imports()
         dynamic_imports = super()._get_all_dynamic_imports()
-        if isinstance(self.contents, LiteralComponentVar):
-            dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                dynamic_imports |= component._get_all_dynamic_imports()
         return dynamic_imports
         return dynamic_imports
 
 
     def _get_all_custom_code(self) -> set[str]:
     def _get_all_custom_code(self) -> set[str]:
@@ -126,10 +137,24 @@ class Bare(Component):
             The custom code.
             The custom code.
         """
         """
         custom_code = super()._get_all_custom_code()
         custom_code = super()._get_all_custom_code()
-        if isinstance(self.contents, LiteralComponentVar):
-            custom_code |= self.contents._var_value._get_all_custom_code()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                custom_code |= component._get_all_custom_code()
         return custom_code
         return custom_code
 
 
+    def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
+        """Get the components that should be wrapped in the app.
+
+        Returns:
+            The components that should be wrapped in the app.
+        """
+        app_wrap_components = super()._get_all_app_wrap_components()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                if isinstance(component, Component):
+                    app_wrap_components |= component._get_all_app_wrap_components()
+        return app_wrap_components
+
     def _get_all_refs(self) -> set[str]:
     def _get_all_refs(self) -> set[str]:
         """Get the refs for the children of the component.
         """Get the refs for the children of the component.
 
 
@@ -137,8 +162,9 @@ class Bare(Component):
             The refs for the children.
             The refs for the children.
         """
         """
         refs = super()._get_all_refs()
         refs = super()._get_all_refs()
-        if isinstance(self.contents, LiteralComponentVar):
-            refs |= self.contents._var_value._get_all_refs()
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                refs |= component._get_all_refs()
         return refs
         return refs
 
 
     def _render(self) -> Tag:
     def _render(self) -> Tag:
@@ -148,6 +174,33 @@ class Bare(Component):
             return Tagless(contents=f"{{{self.contents!s}}}")
             return Tagless(contents=f"{{{self.contents!s}}}")
         return Tagless(contents=str(self.contents))
         return Tagless(contents=str(self.contents))
 
 
+    def _add_style_recursive(
+        self, style: ComponentStyle, theme: Component | None = None
+    ) -> Component:
+        """Add style to the component and its children.
+
+        Args:
+            style: The style to add.
+            theme: The theme to add.
+
+        Returns:
+            The component with the style added.
+        """
+        new_self = super()._add_style_recursive(style, theme)
+
+        are_components_touched = False
+
+        if isinstance(self.contents, Var):
+            for component in _components_from_var(self.contents):
+                if isinstance(component, Component):
+                    component._add_style_recursive(style, theme)
+                    are_components_touched = True
+
+        if are_components_touched:
+            GLOBAL_CACHE.clear()
+
+        return new_self
+
     def _get_vars(
     def _get_vars(
         self, include_children: bool = False, ignore_ids: set[int] | None = None
         self, include_children: bool = False, ignore_ids: set[int] | None = None
     ) -> Iterator[Var]:
     ) -> Iterator[Var]:

+ 168 - 102
reflex/components/component.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 
 import copy
 import copy
 import dataclasses
 import dataclasses
+import inspect
 import typing
 import typing
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from functools import lru_cache, wraps
 from functools import lru_cache, wraps
@@ -21,6 +22,8 @@ from typing import (
     Set,
     Set,
     Type,
     Type,
     Union,
     Union,
+    get_args,
+    get_origin,
 )
 )
 
 
 from typing_extensions import Self
 from typing_extensions import Self
@@ -43,6 +46,7 @@ from reflex.constants import (
 from reflex.constants.compiler import SpecialAttributes
 from reflex.constants.compiler import SpecialAttributes
 from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.constants.state import FRONTEND_EVENT_STATE
 from reflex.event import (
 from reflex.event import (
+    EventActionsMixin,
     EventCallback,
     EventCallback,
     EventChain,
     EventChain,
     EventHandler,
     EventHandler,
@@ -191,6 +195,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
     return types._isinstance(obj, type_hint, nested=1)
     return types._isinstance(obj, type_hint, nested=1)
 
 
 
 
+def _components_from(
+    component_or_var: Union[BaseComponent, Var],
+) -> tuple[BaseComponent, ...]:
+    """Get the components from a component or Var.
+
+    Args:
+        component_or_var: The component or Var to get the components from.
+
+    Returns:
+        The components.
+    """
+    if isinstance(component_or_var, Var):
+        var_data = component_or_var._get_all_var_data()
+        return var_data.components if var_data else ()
+    if isinstance(component_or_var, BaseComponent):
+        return (component_or_var,)
+    return ()
+
+
 class Component(BaseComponent, ABC):
 class Component(BaseComponent, ABC):
     """A component with style, event trigger and other props."""
     """A component with style, event trigger and other props."""
 
 
@@ -489,7 +512,7 @@ class Component(BaseComponent, ABC):
 
 
         # Remove any keys that were added as events.
         # Remove any keys that were added as events.
         for key in kwargs["event_triggers"]:
         for key in kwargs["event_triggers"]:
-            del kwargs[key]
+            kwargs.pop(key, None)
 
 
         # Place data_ and aria_ attributes into custom_attrs
         # Place data_ and aria_ attributes into custom_attrs
         special_attributes = tuple(
         special_attributes = tuple(
@@ -665,13 +688,22 @@ class Component(BaseComponent, ABC):
         """
         """
         return set()
         return set()
 
 
+    @classmethod
+    def _are_fields_known(cls) -> bool:
+        """Check if all fields are known at compile time. True for most components.
+
+        Returns:
+            Whether all fields are known at compile time.
+        """
+        return True
+
     @classmethod
     @classmethod
     @lru_cache(maxsize=None)
     @lru_cache(maxsize=None)
-    def get_component_props(cls) -> set[str]:
-        """Get the props that expected a component as value.
+    def _get_component_prop_names(cls) -> Set[str]:
+        """Get the names of the component props. NOTE: This assumes all fields are known.
 
 
         Returns:
         Returns:
-            The components props.
+            The names of the component props.
         """
         """
         return {
         return {
             name
             name
@@ -680,6 +712,26 @@ class Component(BaseComponent, ABC):
             and types._issubclass(field.outer_type_, Component)
             and types._issubclass(field.outer_type_, Component)
         }
         }
 
 
+    def _get_components_in_props(self) -> Sequence[BaseComponent]:
+        """Get the components in the props.
+
+        Returns:
+            The components in the props
+        """
+        if self._are_fields_known():
+            return [
+                component
+                for name in self._get_component_prop_names()
+                for component in _components_from(getattr(self, name))
+            ]
+        return [
+            component
+            for prop in self.get_props()
+            if (value := getattr(self, prop)) is not None
+            and isinstance(value, (BaseComponent, Var))
+            for component in _components_from(value)
+        ]
+
     @classmethod
     @classmethod
     def create(cls, *children, **props) -> Self:
     def create(cls, *children, **props) -> Self:
         """Create the component.
         """Create the component.
@@ -1136,6 +1188,9 @@ class Component(BaseComponent, ABC):
         if custom_code is not None:
         if custom_code is not None:
             code.add(custom_code)
             code.add(custom_code)
 
 
+        for component in self._get_components_in_props():
+            code |= component._get_all_custom_code()
+
         # Add the custom code from add_custom_code method.
         # Add the custom code from add_custom_code method.
         for clz in self._iter_parent_classes_with_method("add_custom_code"):
         for clz in self._iter_parent_classes_with_method("add_custom_code"):
             for item in clz.add_custom_code(self):
             for item in clz.add_custom_code(self):
@@ -1163,7 +1218,7 @@ class Component(BaseComponent, ABC):
             The dynamic imports.
             The dynamic imports.
         """
         """
         # Store the import in a set to avoid duplicates.
         # Store the import in a set to avoid duplicates.
-        dynamic_imports = set()
+        dynamic_imports: set[str] = set()
 
 
         # Get dynamic import for this component.
         # Get dynamic import for this component.
         dynamic_import = self._get_dynamic_imports()
         dynamic_import = self._get_dynamic_imports()
@@ -1174,25 +1229,12 @@ class Component(BaseComponent, ABC):
         for child in self.children:
         for child in self.children:
             dynamic_imports |= child._get_all_dynamic_imports()
             dynamic_imports |= child._get_all_dynamic_imports()
 
 
-        for prop in self.get_component_props():
-            if getattr(self, prop) is not None:
-                dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports()
+        for component in self._get_components_in_props():
+            dynamic_imports |= component._get_all_dynamic_imports()
 
 
         # Return the dynamic imports
         # Return the dynamic imports
         return dynamic_imports
         return dynamic_imports
 
 
-    def _get_props_imports(self) -> List[ParsedImportDict]:
-        """Get the imports needed for components props.
-
-        Returns:
-            The  imports for the components props of the component.
-        """
-        return [
-            getattr(self, prop)._get_all_imports()
-            for prop in self.get_component_props()
-            if getattr(self, prop) is not None
-        ]
-
     def _should_transpile(self, dep: str | None) -> bool:
     def _should_transpile(self, dep: str | None) -> bool:
         """Check if a dependency should be transpiled.
         """Check if a dependency should be transpiled.
 
 
@@ -1303,7 +1345,6 @@ class Component(BaseComponent, ABC):
             )
             )
 
 
         return imports.merge_imports(
         return imports.merge_imports(
-            *self._get_props_imports(),
             self._get_dependencies_imports(),
             self._get_dependencies_imports(),
             self._get_hooks_imports(),
             self._get_hooks_imports(),
             _imports,
             _imports,
@@ -1380,6 +1421,8 @@ class Component(BaseComponent, ABC):
                         for k in var_data.hooks
                         for k in var_data.hooks
                     }
                     }
                 )
                 )
+                for component in var_data.components:
+                    vars_hooks.update(component._get_all_hooks())
         return vars_hooks
         return vars_hooks
 
 
     def _get_events_hooks(self) -> dict[str, VarData | None]:
     def _get_events_hooks(self) -> dict[str, VarData | None]:
@@ -1528,6 +1571,9 @@ class Component(BaseComponent, ABC):
             refs.add(ref)
             refs.add(ref)
         for child in self.children:
         for child in self.children:
             refs |= child._get_all_refs()
             refs |= child._get_all_refs()
+        for component in self._get_components_in_props():
+            refs |= component._get_all_refs()
+
         return refs
         return refs
 
 
     def _get_all_custom_components(
     def _get_all_custom_components(
@@ -1551,6 +1597,9 @@ class Component(BaseComponent, ABC):
             if not isinstance(child, Component):
             if not isinstance(child, Component):
                 continue
                 continue
             custom_components |= child._get_all_custom_components(seen=seen)
             custom_components |= child._get_all_custom_components(seen=seen)
+        for component in self._get_components_in_props():
+            if isinstance(component, Component) and component.tag is not None:
+                custom_components |= component._get_all_custom_components(seen=seen)
         return custom_components
         return custom_components
 
 
     @property
     @property
@@ -1614,17 +1663,65 @@ class CustomComponent(Component):
     # The props of the component.
     # The props of the component.
     props: Dict[str, Any] = {}
     props: Dict[str, Any] = {}
 
 
-    # Props that reference other components.
-    component_props: Dict[str, Component] = {}
-
-    def __init__(self, *args, **kwargs):
+    def __init__(self, **kwargs):
         """Initialize the custom component.
         """Initialize the custom component.
 
 
         Args:
         Args:
-            *args: The args to pass to the component.
             **kwargs: The kwargs to pass to the component.
             **kwargs: The kwargs to pass to the component.
         """
         """
-        super().__init__(*args, **kwargs)
+        component_fn = kwargs.get("component_fn")
+
+        # Set the props.
+        props_types = typing.get_type_hints(component_fn) if component_fn else {}
+        props = {key: value for key, value in kwargs.items() if key in props_types}
+        kwargs = {key: value for key, value in kwargs.items() if key not in props_types}
+
+        event_types = {
+            key
+            for key in props
+            if (
+                (get_origin((annotation := props_types.get(key))) or annotation)
+                == EventHandler
+            )
+        }
+
+        def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]:
+            type_ = props_types[key]
+
+            return (
+                args[0]
+                if (args := get_args(type_))
+                else (
+                    annotation_args[1]
+                    if get_origin(
+                        (
+                            annotation := inspect.getfullargspec(
+                                component_fn
+                            ).annotations[key]
+                        )
+                    )
+                    is typing.Annotated
+                    and (annotation_args := get_args(annotation))
+                    else no_args_event_spec
+                )
+            )
+
+        super().__init__(
+            event_triggers={
+                key: EventChain.create(
+                    value=props[key],
+                    args_spec=get_args_spec(key),
+                    key=key,
+                )
+                for key in event_types
+            },
+            **kwargs,
+        )
+
+        to_camel_cased_props = {
+            format.to_camel_case(key) for key in props if key not in event_types
+        }
+        self.get_props = lambda: to_camel_cased_props  # pyright: ignore [reportIncompatibleVariableOverride]
 
 
         # Unset the style.
         # Unset the style.
         self.style = Style()
         self.style = Style()
@@ -1632,51 +1729,36 @@ class CustomComponent(Component):
         # Set the tag to the name of the function.
         # Set the tag to the name of the function.
         self.tag = format.to_title_case(self.component_fn.__name__)
         self.tag = format.to_title_case(self.component_fn.__name__)
 
 
-        # Get the event triggers defined in the component declaration.
-        event_triggers_in_component_declaration = self.get_event_triggers()
-
-        # Set the props.
-        props = typing.get_type_hints(self.component_fn)
-        for key, value in kwargs.items():
+        for key, value in props.items():
             # Skip kwargs that are not props.
             # Skip kwargs that are not props.
-            if key not in props:
+            if key not in props_types:
                 continue
                 continue
 
 
+            camel_cased_key = format.to_camel_case(key)
+
             # Get the type based on the annotation.
             # Get the type based on the annotation.
-            type_ = props[key]
+            type_ = props_types[key]
 
 
             # Handle event chains.
             # Handle event chains.
-            if types._issubclass(type_, EventChain):
-                value = EventChain.create(
-                    value=value,
-                    args_spec=event_triggers_in_component_declaration.get(
-                        key, no_args_event_spec
-                    ),
-                    key=key,
+            if types._issubclass(type_, EventActionsMixin):
+                inspect.getfullargspec(component_fn).annotations[key]
+                self.props[camel_cased_key] = EventChain.create(
+                    value=value, args_spec=get_args_spec(key), key=key
                 )
                 )
-                self.props[format.to_camel_case(key)] = value
                 continue
                 continue
 
 
-            # Handle subclasses of Base.
-            if isinstance(value, Base):
-                base_value = LiteralVar.create(value)
-
-                # Track hooks and imports associated with Component instances.
-                if base_value is not None and isinstance(value, Component):
-                    self.component_props[key] = value
-                    value = base_value._replace(
-                        merge_var_data=VarData(
-                            imports=value._get_all_imports(),
-                            hooks=value._get_all_hooks(),
-                        )
-                    )
-                else:
-                    value = base_value
-            else:
-                value = LiteralVar.create(value)
+            value = LiteralVar.create(value)
+            self.props[camel_cased_key] = value
+            setattr(self, camel_cased_key, value)
 
 
-            # Set the prop.
-            self.props[format.to_camel_case(key)] = value
+    @classmethod
+    def _are_fields_known(cls) -> bool:
+        """Check if the fields are known.
+
+        Returns:
+            Whether the fields are known.
+        """
+        return False
 
 
     def __eq__(self, other: Any) -> bool:
     def __eq__(self, other: Any) -> bool:
         """Check if the component is equal to another.
         """Check if the component is equal to another.
@@ -1698,7 +1780,7 @@ class CustomComponent(Component):
         return hash(self.tag)
         return hash(self.tag)
 
 
     @classmethod
     @classmethod
-    def get_props(cls) -> Set[str]:  # pyright: ignore [reportIncompatibleVariableOverride]
+    def get_props(cls) -> Set[str]:
         """Get the props for the component.
         """Get the props for the component.
 
 
         Returns:
         Returns:
@@ -1735,27 +1817,8 @@ class CustomComponent(Component):
                 seen=seen
                 seen=seen
             )
             )
 
 
-        # Fetch custom components from props as well.
-        for child_component in self.component_props.values():
-            if child_component.tag is None:
-                continue
-            if child_component.tag not in seen:
-                seen.add(child_component.tag)
-                if isinstance(child_component, CustomComponent):
-                    custom_components |= {child_component}
-                custom_components |= child_component._get_all_custom_components(
-                    seen=seen
-                )
         return custom_components
         return custom_components
 
 
-    def _render(self) -> Tag:
-        """Define how to render the component in React.
-
-        Returns:
-            The tag to render.
-        """
-        return super()._render(props=self.props)
-
     def get_prop_vars(self) -> List[Var]:
     def get_prop_vars(self) -> List[Var]:
         """Get the prop vars.
         """Get the prop vars.
 
 
@@ -1765,29 +1828,19 @@ class CustomComponent(Component):
         return [
         return [
             Var(
             Var(
                 _js_expr=name,
                 _js_expr=name,
-                _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)),
+                _var_type=(
+                    prop._var_type
+                    if isinstance(prop, Var)
+                    else (
+                        type(prop)
+                        if not isinstance(prop, EventActionsMixin)
+                        else EventChain
+                    )
+                ),
             ).guess_type()
             ).guess_type()
             for name, prop in self.props.items()
             for name, prop in self.props.items()
         ]
         ]
 
 
-    def _get_vars(
-        self, include_children: bool = False, ignore_ids: set[int] | None = None
-    ) -> Iterator[Var]:
-        """Walk all Vars used in this component.
-
-        Args:
-            include_children: Whether to include Vars from children.
-            ignore_ids: The ids to ignore.
-
-        Yields:
-            Each var referenced by the component (props, styles, event handlers).
-        """
-        ignore_ids = ignore_ids or set()
-        yield from super()._get_vars(
-            include_children=include_children, ignore_ids=ignore_ids
-        )
-        yield from filter(lambda prop: isinstance(prop, Var), self.props.values())
-
     @lru_cache(maxsize=None)  # noqa: B019
     @lru_cache(maxsize=None)  # noqa: B019
     def get_component(self) -> Component:
     def get_component(self) -> Component:
         """Render the component.
         """Render the component.
@@ -2475,6 +2528,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
             The VarData for the var.
             The VarData for the var.
         """
         """
         return VarData.merge(
         return VarData.merge(
+            self._var_data,
             VarData(
             VarData(
                 imports={
                 imports={
                     "@emotion/react": [
                     "@emotion/react": [
@@ -2517,9 +2571,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
         Returns:
         Returns:
             The var.
             The var.
         """
         """
+        var_datas = [
+            var_data
+            for var in value._get_vars(include_children=True)
+            if (var_data := var._get_all_var_data())
+        ]
+
         return LiteralComponentVar(
         return LiteralComponentVar(
             _js_expr="",
             _js_expr="",
             _var_type=type(value),
             _var_type=type(value),
-            _var_data=_var_data,
+            _var_data=VarData.merge(
+                _var_data,
+                *var_datas,
+                VarData(
+                    components=(value,),
+                ),
+            ),
             _var_value=value,
             _var_value=value,
         )
         )

+ 0 - 8
reflex/components/core/cond.py

@@ -61,14 +61,6 @@ class Cond(MemoizationLeaf):
             )
             )
         )
         )
 
 
-    def _get_props_imports(self):
-        """Get the imports needed for component's props.
-
-        Returns:
-            The imports for the component's props of the component.
-        """
-        return []
-
     def _render(self) -> Tag:
     def _render(self) -> Tag:
         return CondTag(
         return CondTag(
             cond=self.cond,
             cond=self.cond,

+ 21 - 10
reflex/vars/base.py

@@ -76,6 +76,7 @@ from reflex.utils.types import (
 )
 )
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
+    from reflex.components.component import BaseComponent
     from reflex.state import BaseState
     from reflex.state import BaseState
 
 
     from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
     from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
@@ -132,6 +133,9 @@ class VarData:
     # Position of the hook in the component
     # Position of the hook in the component
     position: Hooks.HookPosition | None = None
     position: Hooks.HookPosition | None = None
 
 
+    # Components that are part of this var
+    components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
+
     def __init__(
     def __init__(
         self,
         self,
         state: str = "",
         state: str = "",
@@ -140,6 +144,7 @@ class VarData:
         hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
         hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
         deps: list[Var] | None = None,
         deps: list[Var] | None = None,
         position: Hooks.HookPosition | None = None,
         position: Hooks.HookPosition | None = None,
+        components: Iterable[BaseComponent] | None = None,
     ):
     ):
         """Initialize the var data.
         """Initialize the var data.
 
 
@@ -150,6 +155,7 @@ class VarData:
             hooks: Hooks that need to be present in the component to render this var.
             hooks: Hooks that need to be present in the component to render this var.
             deps: Dependencies of the var for useCallback.
             deps: Dependencies of the var for useCallback.
             position: Position of the hook in the component.
             position: Position of the hook in the component.
+            components: Components that are part of this var.
         """
         """
         if isinstance(hooks, str):
         if isinstance(hooks, str):
             hooks = [hooks]
             hooks = [hooks]
@@ -164,6 +170,7 @@ class VarData:
         object.__setattr__(self, "hooks", tuple(hooks or {}))
         object.__setattr__(self, "hooks", tuple(hooks or {}))
         object.__setattr__(self, "deps", tuple(deps or []))
         object.__setattr__(self, "deps", tuple(deps or []))
         object.__setattr__(self, "position", position or None)
         object.__setattr__(self, "position", position or None)
+        object.__setattr__(self, "components", tuple(components or []))
 
 
         if hooks and any(hooks.values()):
         if hooks and any(hooks.values()):
             merged_var_data = VarData.merge(self, *hooks.values())
             merged_var_data = VarData.merge(self, *hooks.values())
@@ -174,6 +181,7 @@ class VarData:
                 object.__setattr__(self, "hooks", merged_var_data.hooks)
                 object.__setattr__(self, "hooks", merged_var_data.hooks)
                 object.__setattr__(self, "deps", merged_var_data.deps)
                 object.__setattr__(self, "deps", merged_var_data.deps)
                 object.__setattr__(self, "position", merged_var_data.position)
                 object.__setattr__(self, "position", merged_var_data.position)
+                object.__setattr__(self, "components", merged_var_data.components)
 
 
     def old_school_imports(self) -> ImportDict:
     def old_school_imports(self) -> ImportDict:
         """Return the imports as a mutable dict.
         """Return the imports as a mutable dict.
@@ -242,17 +250,19 @@ class VarData:
         else:
         else:
             position = None
             position = None
 
 
-        if state or _imports or hooks or field_name or deps or position:
-            return VarData(
-                state=state,
-                field_name=field_name,
-                imports=_imports,
-                hooks=hooks,
-                deps=deps,
-                position=position,
-            )
+        components = tuple(
+            component for var_data in all_var_datas for component in var_data.components
+        )
 
 
-        return None
+        return VarData(
+            state=state,
+            field_name=field_name,
+            imports=_imports,
+            hooks=hooks,
+            deps=deps,
+            position=position,
+            components=components,
+        )
 
 
     def __bool__(self) -> bool:
     def __bool__(self) -> bool:
         """Check if the var data is non-empty.
         """Check if the var data is non-empty.
@@ -267,6 +277,7 @@ class VarData:
             or self.field_name
             or self.field_name
             or self.deps
             or self.deps
             or self.position
             or self.position
+            or self.components
         )
         )
 
 
     @classmethod
     @classmethod

+ 1 - 1
tests/units/components/test_component.py

@@ -871,7 +871,7 @@ def test_create_custom_component(my_component):
     """
     """
     component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
     component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
     assert component.tag == "MyComponent"
     assert component.tag == "MyComponent"
-    assert component.get_props() == set()
+    assert component.get_props() == {"prop1", "prop2"}
     assert component._get_all_custom_components() == {component}
     assert component._get_all_custom_components() == {component}