Kaynağa Gözat

use position in vardata to mark internal hooks (#4549)

* use position in vardata to mark internal hooks

* update all render to use position

* use macros for rendering

* reduce number of iterations over hooks during rendering

* cleanup code and add typing

* add __future__

* use new macros to render component maps in markdown

* remove calls to _get_all_hooks_internal

* fix typo

* forgot to replace this

* unnecessary expand in utils.py
Thomas Brandého 4 ay önce
ebeveyn
işleme
9fafb6d526

+ 2 - 4
reflex/.templates/jinja/web/pages/_app.js.jinja2

@@ -1,4 +1,5 @@
 {% extends "web/pages/base_page.js.jinja2" %}
 {% extends "web/pages/base_page.js.jinja2" %}
+{% from "web/pages/macros.js.jinja2" import renderHooks %}
 
 
 {% block early_imports %}
 {% block early_imports %}
 import '$/styles/styles.css'
 import '$/styles/styles.css'
@@ -18,10 +19,7 @@ import * as {{library_alias}} from "{{library_path}}";
 
 
 {% block export %}
 {% block export %}
 function AppWrap({children}) {
 function AppWrap({children}) {
-
-  {% for hook in hooks %}
-  {{ hook }}
-  {% endfor %}
+  {{ renderHooks(hooks) }}
 
 
   return (
   return (
     {{utils.render(render, indent_width=0)}}
     {{utils.render(render, indent_width=0)}}

+ 3 - 4
reflex/.templates/jinja/web/pages/custom_component.js.jinja2

@@ -1,5 +1,5 @@
 {% extends "web/pages/base_page.js.jinja2" %}
 {% extends "web/pages/base_page.js.jinja2" %}
-
+{% from "web/pages/macros.js.jinja2" import renderHooks %}
 {% block export %}
 {% block export %}
 {% for component in components %}
 {% for component in components %}
 
 
@@ -8,9 +8,8 @@
 {% endfor %}
 {% endfor %}
 
 
 export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
 export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
-    {% for hook in component.hooks %}
-    {{ hook }}
-    {% endfor %}
+    {{ renderHooks(component.hooks) }}
+
     return(
     return(
         {{utils.render(component.render)}}
         {{utils.render(component.render)}}
       )
       )

+ 2 - 3
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -1,4 +1,5 @@
 {% extends "web/pages/base_page.js.jinja2" %}
 {% extends "web/pages/base_page.js.jinja2" %}
+{% from "web/pages/macros.js.jinja2" import renderHooks %}
 
 
 {% block declaration %}
 {% block declaration %}
 {% for custom_code in custom_codes %}
 {% for custom_code in custom_codes %}
@@ -8,9 +9,7 @@
 
 
 {% block export %}
 {% block export %}
 export default function Component() {
 export default function Component() {
-  {% for hook in hooks %}
-  {{ hook }}
-  {% endfor %}
+    {{ renderHooks(hooks)}}
 
 
   return (
   return (
     {{utils.render(render, indent_width=0)}}
     {{utils.render(render, indent_width=0)}}

+ 38 - 0
reflex/.templates/jinja/web/pages/macros.js.jinja2

@@ -0,0 +1,38 @@
+{% macro renderHooks(hooks) %}
+  {% set sorted_hooks = sort_hooks(hooks) %}
+
+  {# Render the grouped hooks #}
+   {% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+{% endmacro %}
+
+{% macro renderHooksWithMemo(hooks, memo)%}
+  {% set sorted_hooks = sort_hooks(hooks) %}
+
+  {# Render the grouped hooks #}
+  {% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook in memo %}
+  {{ hook }}
+  {% endfor %}
+
+  {% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
+  {{ hook }}
+  {% endfor %}
+
+{% endmacro %}

+ 4 - 16
reflex/.templates/jinja/web/pages/stateful_component.js.jinja2

@@ -1,22 +1,10 @@
 {% import 'web/pages/utils.js.jinja2' as utils %}
 {% import 'web/pages/utils.js.jinja2' as utils %}
+{% from 'web/pages/macros.js.jinja2' import renderHooksWithMemo %}
+{% set all_hooks = component._get_all_hooks() %}
 
 
 export function {{tag_name}} () {
 export function {{tag_name}} () {
-  {% for hook in component._get_all_hooks_internal() %}
-  {{ hook }}
-  {% endfor %}
-
-  {% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
-  {{ hook }}
-  {% endfor %}
-
-  {% for hook in memo_trigger_hooks %}
-  {{ hook }}
-  {% endfor %}
-
-  {% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
-  {{ hook }}
-  {% endfor %}
-
+  {{ renderHooksWithMemo(all_hooks, memo_trigger_hooks) }}
+  
   return (
   return (
     {{utils.render(component.render(), indent_width=0)}}
     {{utils.render(component.render(), indent_width=0)}}
   )
   )

+ 2 - 2
reflex/compiler/compiler.py

@@ -75,7 +75,7 @@ def _compile_app(app_root: Component) -> str:
     return templates.APP_ROOT.render(
     return templates.APP_ROOT.render(
         imports=utils.compile_imports(app_root._get_all_imports()),
         imports=utils.compile_imports(app_root._get_all_imports()),
         custom_codes=app_root._get_all_custom_code(),
         custom_codes=app_root._get_all_custom_code(),
-        hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()},
+        hooks=app_root._get_all_hooks(),
         window_libraries=window_libraries,
         window_libraries=window_libraries,
         render=app_root.render(),
         render=app_root.render(),
     )
     )
@@ -149,7 +149,7 @@ def _compile_page(
         imports=imports,
         imports=imports,
         dynamic_imports=component._get_all_dynamic_imports(),
         dynamic_imports=component._get_all_dynamic_imports(),
         custom_codes=component._get_all_custom_code(),
         custom_codes=component._get_all_custom_code(),
-        hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()},
+        hooks=component._get_all_hooks(),
         render=component.render(),
         render=component.render(),
         **kwargs,
         **kwargs,
     )
     )

+ 41 - 0
reflex/compiler/templates.py

@@ -1,9 +1,46 @@
 """Templates to use in the reflex compiler."""
 """Templates to use in the reflex compiler."""
 
 
+from __future__ import annotations
+
 from jinja2 import Environment, FileSystemLoader, Template
 from jinja2 import Environment, FileSystemLoader, Template
 
 
 from reflex import constants
 from reflex import constants
+from reflex.constants import Hooks
 from reflex.utils.format import format_state_name, json_dumps
 from reflex.utils.format import format_state_name, json_dumps
+from reflex.vars.base import VarData
+
+
+def _sort_hooks(hooks: dict[str, VarData | None]):
+    """Sort the hooks by their position.
+
+    Args:
+        hooks: The hooks to sort.
+
+    Returns:
+        The sorted hooks.
+    """
+    sorted_hooks = {
+        Hooks.HookPosition.INTERNAL: [],
+        Hooks.HookPosition.PRE_TRIGGER: [],
+        Hooks.HookPosition.POST_TRIGGER: [],
+    }
+
+    for hook, data in hooks.items():
+        if data and data.position and data.position == Hooks.HookPosition.INTERNAL:
+            sorted_hooks[Hooks.HookPosition.INTERNAL].append((hook, data))
+        elif not data or (
+            not data.position
+            or data.position == constants.Hooks.HookPosition.PRE_TRIGGER
+        ):
+            sorted_hooks[Hooks.HookPosition.PRE_TRIGGER].append((hook, data))
+        elif (
+            data
+            and data.position
+            and data.position == constants.Hooks.HookPosition.POST_TRIGGER
+        ):
+            sorted_hooks[Hooks.HookPosition.POST_TRIGGER].append((hook, data))
+
+    return sorted_hooks
 
 
 
 
 class ReflexJinjaEnvironment(Environment):
 class ReflexJinjaEnvironment(Environment):
@@ -47,6 +84,7 @@ class ReflexJinjaEnvironment(Environment):
             "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
             "frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
             "hook_position": constants.Hooks.HookPosition,
             "hook_position": constants.Hooks.HookPosition,
         }
         }
+        self.globals["sort_hooks"] = _sort_hooks
 
 
 
 
 def get_template(name: str) -> Template:
 def get_template(name: str) -> Template:
@@ -103,6 +141,9 @@ STYLE = get_template("web/styles/styles.css.jinja2")
 # Code that generate the package json file
 # Code that generate the package json file
 PACKAGE_JSON = get_template("web/package.json.jinja2")
 PACKAGE_JSON = get_template("web/package.json.jinja2")
 
 
+# Template containing some macros used in the web pages.
+MACROS = get_template("web/pages/macros.js.jinja2")
+
 # Code that generate the pyproject.toml file for custom components.
 # Code that generate the pyproject.toml file for custom components.
 CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template(
 CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template(
     "custom_components/pyproject.toml.jinja2"
     "custom_components/pyproject.toml.jinja2"

+ 1 - 1
reflex/compiler/utils.py

@@ -290,7 +290,7 @@ def compile_custom_component(
             "name": component.tag,
             "name": component.tag,
             "props": props,
             "props": props,
             "render": render.render(),
             "render": render.render(),
-            "hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
+            "hooks": render._get_all_hooks(),
             "custom_code": render._get_all_custom_code(),
             "custom_code": render._get_all_custom_code(),
         },
         },
         imports,
         imports,

+ 3 - 2
reflex/components/base/bare.py

@@ -9,6 +9,7 @@ from reflex.components.tags import Tag
 from reflex.components.tags.tagless import Tagless
 from reflex.components.tags.tagless import Tagless
 from reflex.utils.imports import ParsedImportDict
 from reflex.utils.imports import ParsedImportDict
 from reflex.vars import BooleanVar, ObjectVar, Var
 from reflex.vars import BooleanVar, ObjectVar, Var
+from reflex.vars.base import VarData
 
 
 
 
 class Bare(Component):
 class Bare(Component):
@@ -32,7 +33,7 @@ class Bare(Component):
             contents = str(contents) if contents is not None else ""
             contents = str(contents) if contents is not None else ""
         return cls(contents=contents)  # type: ignore
         return cls(contents=contents)  # type: ignore
 
 
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Include the hooks for the component.
         """Include the hooks for the component.
 
 
         Returns:
         Returns:
@@ -43,7 +44,7 @@ class Bare(Component):
             hooks |= self.contents._var_value._get_all_hooks_internal()
             hooks |= self.contents._var_value._get_all_hooks_internal()
         return hooks
         return hooks
 
 
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Include the hooks for the component.
         """Include the hooks for the component.
 
 
         Returns:
         Returns:

+ 34 - 19
reflex/components/component.py

@@ -102,7 +102,7 @@ class BaseComponent(Base, ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the reflex internal hooks for the component and its children.
         """Get the reflex internal hooks for the component and its children.
 
 
         Returns:
         Returns:
@@ -110,7 +110,7 @@ class BaseComponent(Base, ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
 
 
         Returns:
         Returns:
@@ -1272,7 +1272,7 @@ class Component(BaseComponent, ABC):
         """
         """
         _imports = {}
         _imports = {}
 
 
-        if self._get_ref_hook():
+        if self._get_ref_hook() is not None:
             # Handle hooks needed for attaching react refs to DOM nodes.
             # Handle hooks needed for attaching react refs to DOM nodes.
             _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
             _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
             _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
             _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
@@ -1388,7 +1388,7 @@ class Component(BaseComponent, ABC):
                     }}
                     }}
                 }}, []);"""
                 }}, []);"""
 
 
-    def _get_ref_hook(self) -> str | None:
+    def _get_ref_hook(self) -> Var | None:
         """Generate the ref hook for the component.
         """Generate the ref hook for the component.
 
 
         Returns:
         Returns:
@@ -1396,11 +1396,12 @@ class Component(BaseComponent, ABC):
         """
         """
         ref = self.get_ref()
         ref = self.get_ref()
         if ref is not None:
         if ref is not None:
-            return (
-                f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};"
+            return Var(
+                f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
+                _var_data=VarData(position=Hooks.HookPosition.INTERNAL),
             )
             )
 
 
-    def _get_vars_hooks(self) -> dict[str, None]:
+    def _get_vars_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks required by vars referenced in this component.
         """Get the hooks required by vars referenced in this component.
 
 
         Returns:
         Returns:
@@ -1413,27 +1414,38 @@ class Component(BaseComponent, ABC):
                 vars_hooks.update(
                 vars_hooks.update(
                     var_data.hooks
                     var_data.hooks
                     if isinstance(var_data.hooks, dict)
                     if isinstance(var_data.hooks, dict)
-                    else {k: None for k in var_data.hooks}
+                    else {
+                        k: VarData(position=Hooks.HookPosition.INTERNAL)
+                        for k in var_data.hooks
+                    }
                 )
                 )
         return vars_hooks
         return vars_hooks
 
 
-    def _get_events_hooks(self) -> dict[str, None]:
+    def _get_events_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks required by events referenced in this component.
         """Get the hooks required by events referenced in this component.
 
 
         Returns:
         Returns:
             The hooks for the events.
             The hooks for the events.
         """
         """
-        return {Hooks.EVENTS: None} if self.event_triggers else {}
+        return (
+            {Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
+            if self.event_triggers
+            else {}
+        )
 
 
-    def _get_special_hooks(self) -> dict[str, None]:
+    def _get_special_hooks(self) -> dict[str, VarData | None]:
         """Get the hooks required by special actions referenced in this component.
         """Get the hooks required by special actions referenced in this component.
 
 
         Returns:
         Returns:
             The hooks for special actions.
             The hooks for special actions.
         """
         """
-        return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
+        return (
+            {Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
+            if self.autofocus
+            else {}
+        )
 
 
-    def _get_hooks_internal(self) -> dict[str, None]:
+    def _get_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component managed by the framework.
         """Get the React hooks for this component managed by the framework.
 
 
         Downstream components should NOT override this method to avoid breaking
         Downstream components should NOT override this method to avoid breaking
@@ -1444,7 +1456,7 @@ class Component(BaseComponent, ABC):
         """
         """
         return {
         return {
             **{
             **{
-                hook: None
+                str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
                 for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
                 for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
                 if hook is not None
                 if hook is not None
             },
             },
@@ -1493,7 +1505,7 @@ class Component(BaseComponent, ABC):
         """
         """
         return
         return
 
 
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the reflex internal hooks for the component and its children.
         """Get the reflex internal hooks for the component and its children.
 
 
         Returns:
         Returns:
@@ -1508,7 +1520,7 @@ class Component(BaseComponent, ABC):
 
 
         return code
         return code
 
 
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component and its children.
         """Get the React hooks for this component and its children.
 
 
         Returns:
         Returns:
@@ -1516,6 +1528,9 @@ class Component(BaseComponent, ABC):
         """
         """
         code = {}
         code = {}
 
 
+        # Add the internal hooks for this component.
+        code.update(self._get_hooks_internal())
+
         # Add the hook code for this component.
         # Add the hook code for this component.
         hooks = self._get_hooks()
         hooks = self._get_hooks()
         if hooks is not None:
         if hooks is not None:
@@ -2211,7 +2226,7 @@ class StatefulComponent(BaseComponent):
             )
             )
         return trigger_memo
         return trigger_memo
 
 
-    def _get_all_hooks_internal(self) -> dict[str, None]:
+    def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Get the reflex internal hooks for the component and its children.
         """Get the reflex internal hooks for the component and its children.
 
 
         Returns:
         Returns:
@@ -2219,7 +2234,7 @@ class StatefulComponent(BaseComponent):
         """
         """
         return {}
         return {}
 
 
-    def _get_all_hooks(self) -> dict[str, None]:
+    def _get_all_hooks(self) -> dict[str, VarData | None]:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
 
 
         Returns:
         Returns:
@@ -2337,7 +2352,7 @@ class MemoizationLeaf(Component):
             The memoization leaf
             The memoization leaf
         """
         """
         comp = super().create(*children, **props)
         comp = super().create(*children, **props)
-        if comp._get_all_hooks() or comp._get_all_hooks_internal():
+        if comp._get_all_hooks():
             comp._memoization_mode = cls._memoization_mode.copy(
             comp._memoization_mode = cls._memoization_mode.copy(
                 update={"disposition": MemoizationDisposition.ALWAYS}
                 update={"disposition": MemoizationDisposition.ALWAYS}
             )
             )

+ 1 - 3
reflex/components/el/elements/forms.py

@@ -182,9 +182,7 @@ class Form(BaseHTML):
         props["handle_submit_unique_name"] = ""
         props["handle_submit_unique_name"] = ""
         form = super().create(*children, **props)
         form = super().create(*children, **props)
         form.handle_submit_unique_name = md5(
         form.handle_submit_unique_name = md5(
-            str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode(
-                "utf-8"
-            )
+            str(form._get_all_hooks()).encode("utf-8")
         ).hexdigest()
         ).hexdigest()
         return form
         return form
 
 

+ 3 - 2
reflex/components/markdown/markdown.py

@@ -420,11 +420,12 @@ const {_LANGUAGE!s} = match ? match[1] : '';
 
 
     def _get_custom_code(self) -> str | None:
     def _get_custom_code(self) -> str | None:
         hooks = {}
         hooks = {}
+        from reflex.compiler.templates import MACROS
+
         for _component in self.component_map.values():
         for _component in self.component_map.values():
             comp = _component(_MOCK_ARG)
             comp = _component(_MOCK_ARG)
-            hooks.update(comp._get_all_hooks_internal())
             hooks.update(comp._get_all_hooks())
             hooks.update(comp._get_all_hooks())
-        formatted_hooks = "\n".join(hooks.keys())
+        formatted_hooks = MACROS.module.renderHooks(hooks)  # type: ignore
         return f"""
         return f"""
         function {self._get_component_map_name()} () {{
         function {self._get_component_map_name()} () {{
             {formatted_hooks}
             {formatted_hooks}

+ 1 - 0
reflex/constants/compiler.py

@@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
     class HookPosition(enum.Enum):
     class HookPosition(enum.Enum):
         """The position of the hook in the component."""
         """The position of the hook in the component."""
 
 
+        INTERNAL = "internal"
         PRE_TRIGGER = "pre_trigger"
         PRE_TRIGGER = "pre_trigger"
         POST_TRIGGER = "post_trigger"
         POST_TRIGGER = "post_trigger"
 
 

+ 1 - 1
reflex/experimental/client_state.py

@@ -105,7 +105,7 @@ class ClientStateVar(Var):
         else:
         else:
             default_var = default
             default_var = default
         setter_name = f"set{var_name.capitalize()}"
         setter_name = f"set{var_name.capitalize()}"
-        hooks = {
+        hooks: dict[str, VarData | None] = {
             f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
             f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
         }
         }
         imports = {
         imports = {

+ 4 - 2
reflex/vars/base.py

@@ -127,7 +127,7 @@ class VarData:
         state: str = "",
         state: str = "",
         field_name: str = "",
         field_name: str = "",
         imports: ImportDict | ParsedImportDict | None = None,
         imports: ImportDict | ParsedImportDict | None = None,
-        hooks: dict[str, None] | None = None,
+        hooks: dict[str, VarData | None] | None = None,
         deps: list[Var] | None = None,
         deps: list[Var] | None = None,
         position: Hooks.HookPosition | None = None,
         position: Hooks.HookPosition | None = None,
     ):
     ):
@@ -194,7 +194,9 @@ class VarData:
             (var_data.state for var_data in all_var_datas if var_data.state), ""
             (var_data.state for var_data in all_var_datas if var_data.state), ""
         )
         )
 
 
-        hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks}
+        hooks: dict[str, VarData | None] = {
+            hook: None for var_data in all_var_datas for hook in var_data.hooks
+        }
 
 
         _imports = imports.merge_imports(
         _imports = imports.merge_imports(
             *(var_data.imports for var_data in all_var_datas)
             *(var_data.imports for var_data in all_var_datas)