Browse Source

add deps and position field in VarData (#4518)

* fix memoized event trigger order

* allow to declare deps in event signature for memoized event triggers

* clean up the code to pass tests

* handle position of hooks

* clean up code

* revert test changes

* add future annotations

* remove non-necessary stuff

* reuse data_callback name if already set during first call to add_hooks

* remove HookVar and use Var with VarData instead

* remove test change

* readd removed line

* fix order of stmt for cleaner code

* fix typing

* something broke during the merge I guess

* remove hack and pass proper const for position

* oops, bad syntax in jinja

* use "hook_position" instead of "hook_positions"

match the name of the enum

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Thomas Brandého 5 tháng trước cách đây
mục cha
commit
1444421766

+ 5 - 1
reflex/.templates/jinja/web/pages/stateful_component.js.jinja2

@@ -5,11 +5,15 @@ export function {{tag_name}} () {
   {{ hook }}
   {% endfor %}
 
+  {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
+  {{ hook }}
+  {% endfor %}
+
   {% for hook in memo_trigger_hooks %}
   {{ hook }}
   {% endfor %}
 
-  {% for hook in component._get_all_hooks() %}
+  {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
   {{ hook }}
   {% endfor %}
 

+ 1 - 0
reflex/compiler/templates.py

@@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment):
             "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
             "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
             "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
+            "hook_position": constants.Hooks.HookPosition,
         }
 
 

+ 41 - 11
reflex/components/component.py

@@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC):
         if user_hooks_data is not None:
             other_imports.append(user_hooks_data.imports)
         other_imports.extend(
-            hook_imports for hook_imports in self._get_added_hooks().values()
+            hook_vardata.imports
+            for hook_vardata in self._get_added_hooks().values()
+            if hook_vardata is not None
         )
 
         return imports.merge_imports(_imports, *other_imports)
@@ -1516,7 +1518,7 @@ class Component(BaseComponent, ABC):
             **self._get_special_hooks(),
         }
 
-    def _get_added_hooks(self) -> dict[str, ImportDict]:
+    def _get_added_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks added via `add_hooks` method.
 
         Returns:
@@ -1525,17 +1527,15 @@ class Component(BaseComponent, ABC):
         code = {}
 
         def extract_var_hooks(hook: Var):
-            _imports = {}
             var_data = VarData.merge(hook._get_all_var_data())
             if var_data is not None:
                 for sub_hook in var_data.hooks:
-                    code[sub_hook] = {}
-                if var_data.imports:
-                    _imports = var_data.imports
+                    code[sub_hook] = None
+
             if str(hook) in code:
-                code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
+                code[str(hook)] = VarData.merge(var_data, code[str(hook)])
             else:
-                code[str(hook)] = _imports
+                code[str(hook)] = var_data
 
         # Add the hook code from add_hooks for each parent class (this is reversed to preserve
         # the order of the hooks in the final output)
@@ -1544,7 +1544,7 @@ class Component(BaseComponent, ABC):
                 if isinstance(hook, Var):
                     extract_var_hooks(hook)
                 else:
-                    code[hook] = {}
+                    code[hook] = None
 
         return code
 
@@ -1586,8 +1586,8 @@ class Component(BaseComponent, ABC):
         if hooks is not None:
             code[hooks] = None
 
-        for hook in self._get_added_hooks():
-            code[hook] = None
+        for hook, var_data in self._get_added_hooks().items():
+            code[hook] = var_data
 
         # Add the hook code for the children.
         for child in self.children:
@@ -2189,6 +2189,31 @@ class StatefulComponent(BaseComponent):
             ]
         return [var_name]
 
+    @staticmethod
+    def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
+        """Get the dependencies accessed by event triggers.
+
+        Args:
+            event: The event trigger to extract deps from.
+
+        Returns:
+            The dependencies accessed by the event triggers.
+        """
+        events: list = [event]
+        deps = set()
+
+        if isinstance(event, EventChain):
+            events.extend(event.events)
+
+        for ev in events:
+            if isinstance(ev, EventSpec):
+                for arg in ev.args:
+                    for a in arg:
+                        var_datas = VarData.merge(a._get_all_var_data())
+                        if var_datas and var_datas.deps is not None:
+                            deps |= {str(dep) for dep in var_datas.deps}
+        return deps
+
     @classmethod
     def _get_memoized_event_triggers(
         cls,
@@ -2225,6 +2250,11 @@ class StatefulComponent(BaseComponent):
 
             # Calculate Var dependencies accessed by the handler for useCallback dep array.
             var_deps = ["addEvents", "Event"]
+
+            # Get deps from event trigger var data.
+            var_deps.extend(cls._get_deps_from_event_trigger(event))
+
+            # Get deps from hooks.
             for arg in event_args:
                 var_data = arg._get_all_var_data()
                 if var_data is None:

+ 10 - 8
reflex/components/core/clipboard.py

@@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union
 
 from reflex.components.base.fragment import Fragment
 from reflex.components.tags.tag import Tag
+from reflex.constants.compiler import Hooks
 from reflex.event import EventChain, EventHandler, passthrough_event_spec
 from reflex.utils.format import format_prop, wrap
 from reflex.utils.imports import ImportVar
 from reflex.vars import get_unique_variable_name
-from reflex.vars.base import Var
+from reflex.vars.base import Var, VarData
 
 
 class Clipboard(Fragment):
@@ -72,7 +73,7 @@ class Clipboard(Fragment):
             ),
         }
 
-    def add_hooks(self) -> list[str]:
+    def add_hooks(self) -> list[str | Var[str]]:
         """Add hook to register paste event listener.
 
         Returns:
@@ -83,13 +84,14 @@ class Clipboard(Fragment):
             return []
         if isinstance(on_paste, EventChain):
             on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
+        hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
+
         return [
-            "usePasteHandler(%s, %s, %s)"
-            % (
-                str(self.targets),
-                str(self.on_paste_event_actions),
-                on_paste,
-            )
+            Var(
+                hook_expr,
+                _var_type="str",
+                _var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
+            ),
         ]
 
 

+ 1 - 1
reflex/components/core/clipboard.pyi

@@ -71,6 +71,6 @@ class Clipboard(Fragment):
         ...
 
     def add_imports(self) -> dict[str, ImportVar]: ...
-    def add_hooks(self) -> list[str]: ...
+    def add_hooks(self) -> list[str | Var[str]]: ...
 
 clipboard = Clipboard.create

+ 5 - 2
reflex/components/datadisplay/dataeditor.py

@@ -339,8 +339,11 @@ class DataEditor(NoSSRComponent):
         editor_id = get_unique_variable_name()
 
         # Define the name of the getData callback associated with this component and assign to get_cell_content.
-        data_callback = f"getData_{editor_id}"
-        self.get_cell_content = Var(_js_expr=data_callback)  # type: ignore
+        if self.get_cell_content is not None:
+            data_callback = self.get_cell_content._js_expr
+        else:
+            data_callback = f"getData_{editor_id}"
+            self.get_cell_content = Var(_js_expr=data_callback)  # type: ignore
 
         code = [f"function {data_callback}([col, row])" "{"]
 

+ 6 - 0
reflex/constants/compiler.py

@@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
                   }
                 })"""
 
+    class HookPosition(enum.Enum):
+        """The position of the hook in the component."""
+
+        PRE_TRIGGER = "pre_trigger"
+        POST_TRIGGER = "post_trigger"
+
 
 class MemoizationDisposition(enum.Enum):
     """The conditions under which a component should be memoized."""

+ 46 - 4
reflex/vars/base.py

@@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
 
 from reflex import constants
 from reflex.base import Base
-from reflex.utils import console, imports, serializers, types
+from reflex.constants.compiler import Hooks
+from reflex.utils import console, exceptions, imports, serializers, types
 from reflex.utils.exceptions import (
     VarAttributeError,
     VarDependencyError,
@@ -115,12 +116,20 @@ class VarData:
     # Hooks that need to be present in the component to render this var
     hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
 
+    # Dependencies of the var
+    deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
+
+    # Position of the hook in the component
+    position: Hooks.HookPosition | None = None
+
     def __init__(
         self,
         state: str = "",
         field_name: str = "",
         imports: ImportDict | ParsedImportDict | None = None,
         hooks: dict[str, None] | None = None,
+        deps: list[Var] | None = None,
+        position: Hooks.HookPosition | None = None,
     ):
         """Initialize the var data.
 
@@ -129,6 +138,8 @@ class VarData:
             field_name: The name of the field in the state.
             imports: Imports needed to render this var.
             hooks: Hooks that need to be present in the component to render this var.
+            deps: Dependencies of the var for useCallback.
+            position: Position of the hook in the component.
         """
         immutable_imports: ImmutableParsedImportDict = tuple(
             sorted(
@@ -139,6 +150,8 @@ class VarData:
         object.__setattr__(self, "field_name", field_name)
         object.__setattr__(self, "imports", immutable_imports)
         object.__setattr__(self, "hooks", tuple(hooks or {}))
+        object.__setattr__(self, "deps", tuple(deps or []))
+        object.__setattr__(self, "position", position or None)
 
     def old_school_imports(self) -> ImportDict:
         """Return the imports as a mutable dict.
@@ -154,6 +167,9 @@ class VarData:
         Args:
             *all: The var data objects to merge.
 
+        Raises:
+            ReflexError: If trying to merge VarData with different positions.
+
         Returns:
             The merged var data object.
 
@@ -184,12 +200,32 @@ class VarData:
             *(var_data.imports for var_data in all_var_datas)
         )
 
-        if state or _imports or hooks or field_name:
+        deps = [dep for var_data in all_var_datas for dep in var_data.deps]
+
+        positions = list(
+            {
+                var_data.position
+                for var_data in all_var_datas
+                if var_data.position is not None
+            }
+        )
+        if positions:
+            if len(positions) > 1:
+                raise exceptions.ReflexError(
+                    f"Cannot merge var data with different positions: {positions}"
+                )
+            position = positions[0]
+        else:
+            position = None
+
+        if state or _imports or hooks or field_name or deps or position:
             return VarData(
                 state=state,
                 field_name=field_name,
                 imports=_imports,
                 hooks=hooks,
+                deps=deps,
+                position=position,
             )
 
         return None
@@ -200,7 +236,14 @@ class VarData:
         Returns:
             True if any field is set to a non-default value.
         """
-        return bool(self.state or self.imports or self.hooks or self.field_name)
+        return bool(
+            self.state
+            or self.imports
+            or self.hooks
+            or self.field_name
+            or self.deps
+            or self.position
+        )
 
     @classmethod
     def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
@@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]):
             raise TypeError(
                 "The _var_full_name_needs_state_prefix argument is not supported for Var."
             )
-
         value_with_replaced = dataclasses.replace(
             self,
             _var_type=_var_type or self._var_type,