Преглед на файлове

use position in vardata to mark internal hooks (#4549)

* use position in vardata to mark internal hooks

* update all render to use position

* use macros for rendering

* reduce number of iterations over hooks during rendering

* cleanup code and add typing

* add __future__

* use new macros to render component maps in markdown

* remove calls to _get_all_hooks_internal

* fix typo

* forgot to replace this

* unnecessary expand in utils.py
Thomas Brandého преди 4 месеца
родител
ревизия
9fafb6d526

+ 2 - 4
reflex/.templates/jinja/web/pages/_app.js.jinja2

@@ -1,4 +1,5 @@
 {% extends "web/pages/base_page.js.jinja2" %}
+{% from "web/pages/macros.js.jinja2" import renderHooks %}
 
 {% block early_imports %}
 import '$/styles/styles.css'
@@ -18,10 +19,7 @@ import * as {{library_alias}} from "{{library_path}}";
 
 {% block export %}
 function AppWrap({children}) {
-
-  {% for hook in hooks %}
-  {{ hook }}
-  {% endfor %}
+  {{ renderHooks(hooks) }}
 
   return (
     {{utils.render(render, indent_width=0)}}

+ 3 - 4
reflex/.templates/jinja/web/pages/custom_component.js.jinja2

@@ -1,5 +1,5 @@
 {% extends "web/pages/base_page.js.jinja2" %}
-
+{% from "web/pages/macros.js.jinja2" import renderHooks %}
 {% block export %}
 {% for component in components %}
 
@@ -8,9 +8,8 @@
 {% endfor %}
 
 export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
-    {% for hook in component.hooks %}
-    {{ hook }}
-    {% endfor %}
+    {{ renderHooks(component.hooks) }}
+
     return(
         {{utils.render(component.render)}}
       )

+ 2 - 3
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -1,4 +1,5 @@
 {% extends "web/pages/base_page.js.jinja2" %}
+{% from "web/pages/macros.js.jinja2" import renderHooks %}
 
 {% block declaration %}
 {% for custom_code in custom_codes %}
@@ -8,9 +9,7 @@
 
 {% block export %}
 export default function Component() {
-  {% for hook in hooks %}
-  {{ hook }}
-  {% endfor %}
+    {{ renderHooks(hooks)}}
 
   return (
     {{utils.render(render, indent_width=0)}}

+ 38 - 0
reflex/.templates/jinja/web/pages/macros.js.jinja2

@@ -0,0 +1,38 @@
+{% macro renderHooks(hooks) %}
+  {% set sorted_hooks = sort_hooks(hooks) %}
+
+  {# Render the grouped hooks #}
+   {% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+{% endmacro %}
+
+{% macro renderHooksWithMemo(hooks, memo)%}
+  {% set sorted_hooks = sort_hooks(hooks) %}
+
+  {# Render the grouped hooks #}
+  {% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook in memo %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+
+{% endmacro %}

+ 4 - 16
reflex/.templates/jinja/web/pages/stateful_component.js.jinja2

@@ -1,22 +1,10 @@
 {% import 'web/pages/utils.js.jinja2' as utils %}
+{% from 'web/pages/macros.js.jinja2' import renderHooksWithMemo %}
+{% set all_hooks = component._get_all_hooks() %}
 
 export function {{tag_name}} () {
-  {% for hook in component._get_all_hooks_internal() %}
-  {{ hook }}
-  {% endfor %}
-
-  {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
-  {{ hook }}
-  {% endfor %}
-
-  {% for hook in memo_trigger_hooks %}
-  {{ hook }}
-  {% endfor %}
-
-  {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
-  {{ hook }}
-  {% endfor %}
-
+  {{ renderHooksWithMemo(all_hooks, memo_trigger_hooks) }}
+  
   return (
     {{utils.render(component.render(), indent_width=0)}}
   )

+ 2 - 2
reflex/compiler/compiler.py

@@ -75,7 +75,7 @@ def _compile_app(app_root: Component) -> str:
     return templates.APP_ROOT.render(
         imports=utils.compile_imports(app_root._get_all_imports()),
         custom_codes=app_root._get_all_custom_code(),
-        hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()},
+        hooks=app_root._get_all_hooks(),
         window_libraries=window_libraries,
         render=app_root.render(),
     )
@@ -149,7 +149,7 @@ def _compile_page(
         imports=imports,
         dynamic_imports=component._get_all_dynamic_imports(),
         custom_codes=component._get_all_custom_code(),
-        hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()},
+        hooks=component._get_all_hooks(),
         render=component.render(),
         **kwargs,
     )

+ 41 - 0
reflex/compiler/templates.py

@@ -1,9 +1,46 @@
 """Templates to use in the reflex compiler."""
 
+from __future__ import annotations
+
 from jinja2 import Environment, FileSystemLoader, Template
 
 from reflex import constants
+from reflex.constants import Hooks
 from reflex.utils.format import format_state_name, json_dumps
+from reflex.vars.base import VarData
+
+
+def _sort_hooks(hooks: dict[str, VarData | None]):
+    """Sort the hooks by their position.
+
+    Args:
+        hooks: The hooks to sort.
+
+    Returns:
+        The sorted hooks.
+    """
+    sorted_hooks = {
+        Hooks.HookPosition.INTERNAL: [],
+        Hooks.HookPosition.PRE_TRIGGER: [],
+        Hooks.HookPosition.POST_TRIGGER: [],
+    }
+
+    for hook, data in hooks.items():
+        if data and data.position and data.position == Hooks.HookPosition.INTERNAL:
+            sorted_hooks[Hooks.HookPosition.INTERNAL].append((hook, data))
+        elif not data or (
+            not data.position
+            or data.position == constants.Hooks.HookPosition.PRE_TRIGGER
+        ):
+            sorted_hooks[Hooks.HookPosition.PRE_TRIGGER].append((hook, data))
+        elif (
+            data
+            and data.position
+            and data.position == constants.Hooks.HookPosition.POST_TRIGGER
+        ):
+            sorted_hooks[Hooks.HookPosition.POST_TRIGGER].append((hook, data))
+
+    return sorted_hooks
 
 
 class ReflexJinjaEnvironment(Environment):
@@ -47,6 +84,7 @@ class ReflexJinjaEnvironment(Environment):
             "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
             "hook_position": constants.Hooks.HookPosition,
         }
+        self.globals["sort_hooks"] = _sort_hooks
 
 
 def get_template(name: str) -> Template:
@@ -103,6 +141,9 @@ STYLE = get_template("web/styles/styles.css.jinja2")
 # Code that generate the package json file
 PACKAGE_JSON = get_template("web/package.json.jinja2")
 
+# Template containing some macros used in the web pages.
+MACROS = get_template("web/pages/macros.js.jinja2")
+
 # Code that generate the pyproject.toml file for custom components.
 CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template(
     "custom_components/pyproject.toml.jinja2"

+ 1 - 1
reflex/compiler/utils.py

@@ -290,7 +290,7 @@ def compile_custom_component(
             "name": component.tag,
             "props": props,
             "render": render.render(),
-            "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
+            "hooks": render._get_all_hooks(),
             "custom_code": render._get_all_custom_code(),
         },
         imports,

+ 3 - 2
reflex/components/base/bare.py

@@ -9,6 +9,7 @@ from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.utils.imports import ParsedImportDict
 from reflex.vars import BooleanVar, ObjectVar, Var
+from reflex.vars.base import VarData
 
 
 class Bare(Component):
@@ -32,7 +33,7 @@ class Bare(Component):
             contents = str(contents) if contents is not None else ""
         return cls(contents=contents)  # type: ignore
 
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Include the hooks for the component.
 
         Returns:
@@ -43,7 +44,7 @@ class Bare(Component):
             hooks |= self.contents._var_value._get_all_hooks_internal()
         return hooks
 
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Include the hooks for the component.
 
         Returns:

+ 34 - 19
reflex/components/component.py

@@ -102,7 +102,7 @@ class BaseComponent(Base, ABC):
         """
 
     @abstractmethod
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the reflex internal hooks for the component and its children.
 
         Returns:
@@ -110,7 +110,7 @@ class BaseComponent(Base, ABC):
         """
 
     @abstractmethod
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component.
 
         Returns:
@@ -1272,7 +1272,7 @@ class Component(BaseComponent, ABC):
         """
         _imports = {}
 
-        if self._get_ref_hook():
+        if self._get_ref_hook() is not None:
             # Handle hooks needed for attaching react refs to DOM nodes.
             _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
             _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
@@ -1388,7 +1388,7 @@ class Component(BaseComponent, ABC):
                     }}
                 }}, []);"""
 
-    def _get_ref_hook(self) -> str | None:
+    def _get_ref_hook(self) -> Var | None:
         """Generate the ref hook for the component.
 
         Returns:
@@ -1396,11 +1396,12 @@ class Component(BaseComponent, ABC):
         """
         ref = self.get_ref()
         if ref is not None:
-            return (
-                f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};"
+            return Var(
+                f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
+                _var_data=VarData(position=Hooks.HookPosition.INTERNAL),
             )
 
-    def _get_vars_hooks(self) -> dict[str, None]:
+    def _get_vars_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks required by vars referenced in this component.
 
         Returns:
@@ -1413,27 +1414,38 @@ class Component(BaseComponent, ABC):
                 vars_hooks.update(
                     var_data.hooks
                     if isinstance(var_data.hooks, dict)
-                    else {k: None for k in var_data.hooks}
+                    else {
+                        k: VarData(position=Hooks.HookPosition.INTERNAL)
+                        for k in var_data.hooks
+                    }
                 )
         return vars_hooks
 
-    def _get_events_hooks(self) -> dict[str, None]:
+    def _get_events_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks required by events referenced in this component.
 
         Returns:
             The hooks for the events.
         """
-        return {Hooks.EVENTS: None} if self.event_triggers else {}
+        return (
+            {Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
+            if self.event_triggers
+            else {}
+        )
 
-    def _get_special_hooks(self) -> dict[str, None]:
+    def _get_special_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks required by special actions referenced in this component.
 
         Returns:
             The hooks for special actions.
         """
-        return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
+        return (
+            {Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
+            if self.autofocus
+            else {}
+        )
 
-    def _get_hooks_internal(self) -> dict[str, None]:
+    def _get_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component managed by the framework.
 
         Downstream components should NOT override this method to avoid breaking
@@ -1444,7 +1456,7 @@ class Component(BaseComponent, ABC):
         """
         return {
             **{
-                hook: None
+                str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
                 for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
                 if hook is not None
             },
@@ -1493,7 +1505,7 @@ class Component(BaseComponent, ABC):
         """
         return
 
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the reflex internal hooks for the component and its children.
 
         Returns:
@@ -1508,7 +1520,7 @@ class Component(BaseComponent, ABC):
 
         return code
 
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component and its children.
 
         Returns:
@@ -1516,6 +1528,9 @@ class Component(BaseComponent, ABC):
         """
         code = {}
 
+        # Add the internal hooks for this component.
+        code.update(self._get_hooks_internal())
+
         # Add the hook code for this component.
         hooks = self._get_hooks()
         if hooks is not None:
@@ -2211,7 +2226,7 @@ class StatefulComponent(BaseComponent):
             )
         return trigger_memo
 
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the reflex internal hooks for the component and its children.
 
         Returns:
@@ -2219,7 +2234,7 @@ class StatefulComponent(BaseComponent):
         """
         return {}
 
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component.
 
         Returns:
@@ -2337,7 +2352,7 @@ class MemoizationLeaf(Component):
             The memoization leaf
         """
         comp = super().create(*children, **props)
-        if comp._get_all_hooks() or comp._get_all_hooks_internal():
+        if comp._get_all_hooks():
             comp._memoization_mode = cls._memoization_mode.copy(
                 update={"disposition": MemoizationDisposition.ALWAYS}
             )

+ 1 - 3
reflex/components/el/elements/forms.py

@@ -182,9 +182,7 @@ class Form(BaseHTML):
         props["handle_submit_unique_name"] = ""
         form = super().create(*children, **props)
         form.handle_submit_unique_name = md5(
-            str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode(
-                "utf-8"
-            )
+            str(form._get_all_hooks()).encode("utf-8")
         ).hexdigest()
         return form
 

+ 3 - 2
reflex/components/markdown/markdown.py

@@ -420,11 +420,12 @@ const {_LANGUAGE!s} = match ? match[1] : '';
 
     def _get_custom_code(self) -> str | None:
         hooks = {}
+        from reflex.compiler.templates import MACROS
+
         for _component in self.component_map.values():
             comp = _component(_MOCK_ARG)
-            hooks.update(comp._get_all_hooks_internal())
             hooks.update(comp._get_all_hooks())
-        formatted_hooks = "\n".join(hooks.keys())
+        formatted_hooks = MACROS.module.renderHooks(hooks)  # type: ignore
         return f"""
         function {self._get_component_map_name()} () {{
             {formatted_hooks}

+ 1 - 0
reflex/constants/compiler.py

@@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
     class HookPosition(enum.Enum):
         """The position of the hook in the component."""
 
+        INTERNAL = "internal"
         PRE_TRIGGER = "pre_trigger"
         POST_TRIGGER = "post_trigger"
 

+ 1 - 1
reflex/experimental/client_state.py

@@ -105,7 +105,7 @@ class ClientStateVar(Var):
         else:
             default_var = default
         setter_name = f"set{var_name.capitalize()}"
-        hooks = {
+        hooks: dict[str, VarData | None] = {
             f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
         }
         imports = {

+ 4 - 2
reflex/vars/base.py

@@ -127,7 +127,7 @@ class VarData:
         state: str = "",
         field_name: str = "",
         imports: ImportDict | ParsedImportDict | None = None,
-        hooks: dict[str, None] | None = None,
+        hooks: dict[str, VarData | None] | None = None,
         deps: list[Var] | None = None,
         position: Hooks.HookPosition | None = None,
     ):
@@ -194,7 +194,9 @@ class VarData:
             (var_data.state for var_data in all_var_datas if var_data.state), ""
         )
 
-        hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks}
+        hooks: dict[str, VarData | None] = {
+            hook: None for var_data in all_var_datas for hook in var_data.hooks
+        }
 
         _imports = imports.merge_imports(
             *(var_data.imports for var_data in all_var_datas)