瀏覽代碼

use dict instead of set to store hooks (#2995)

Thomas Brandého 1 年之前
父節點
當前提交
34ee07ecd1

+ 0 - 4
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -8,10 +8,6 @@
 
 {% block export %}
 export default function Component() {
-  {% for ref_hook in ref_hooks %}
-  {{ ref_hook }}
-  {% endfor %}
-  
   {% for hook in hooks %}
   {{ hook }}
   {% endfor %}

+ 0 - 4
reflex/.templates/jinja/web/pages/stateful_component.js.jinja2

@@ -1,10 +1,6 @@
 {% import 'web/pages/utils.js.jinja2' as utils %}
 
 export function {{tag_name}} () {
-  {% for ref_hook in component.get_ref_hooks() %}
-  {{ ref_hook }}
-  {% endfor %}
-
   {% for hook in component.get_hooks_internal() %}
   {{ hook }}
   {% endfor %}

+ 2 - 3
reflex/compiler/compiler.py

@@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str:
     return templates.APP_ROOT.render(
         imports=utils.compile_imports(app_root.get_imports()),
         custom_codes=app_root.get_custom_code(),
-        hooks=app_root.get_hooks_internal() | app_root.get_hooks(),
+        hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()},
         render=app_root.render(),
     )
 
@@ -119,8 +119,7 @@ def _compile_page(
         imports=imports,
         dynamic_imports=component.get_dynamic_imports(),
         custom_codes=component.get_custom_code(),
-        ref_hooks=component.get_ref_hooks(),
-        hooks=component.get_hooks_internal() | component.get_hooks(),
+        hooks={**component.get_hooks_internal(), **component.get_hooks()},
         render=component.render(),
         **kwargs,
     )

+ 1 - 1
reflex/compiler/utils.py

@@ -265,7 +265,7 @@ def compile_custom_component(
             "name": component.tag,
             "props": props,
             "render": render.render(),
-            "hooks": render.get_hooks_internal() | render.get_hooks(),
+            "hooks": {**render.get_hooks_internal(), **render.get_hooks()},
             "custom_code": render.get_custom_code(),
         },
         imports,

+ 32 - 70
reflex/components/component.py

@@ -76,15 +76,7 @@ class BaseComponent(Base, ABC):
         """
 
     @abstractmethod
-    def get_ref_hooks(self) -> set[str]:
-        """Get the hooks required by refs in this component.
-
-        Returns:
-            The hooks for the refs.
-        """
-
-    @abstractmethod
-    def get_hooks_internal(self) -> set[str]:
+    def get_hooks_internal(self) -> dict[str, None]:
         """Get the reflex internal hooks for the component and its children.
 
         Returns:
@@ -92,7 +84,7 @@ class BaseComponent(Base, ABC):
         """
 
     @abstractmethod
-    def get_hooks(self) -> set[str]:
+    def get_hooks(self) -> dict[str, None]:
         """Get the React hooks for this component.
 
         Returns:
@@ -929,7 +921,7 @@ class Component(BaseComponent, ABC):
         """
         return None
 
-    def get_custom_code(self) -> Set[str]:
+    def get_custom_code(self) -> set[str]:
         """Get custom code for the component and its children.
 
         Returns:
@@ -1108,62 +1100,53 @@ class Component(BaseComponent, ABC):
         if ref is not None:
             return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};"
 
-    def _get_vars_hooks(self) -> set[str]:
+    def _get_vars_hooks(self) -> dict[str, None]:
         """Get the hooks required by vars referenced in this component.
 
         Returns:
             The hooks for the vars.
         """
-        vars_hooks = set()
+        vars_hooks = {}
         for var in self._get_vars():
             if var._var_data:
                 vars_hooks.update(var._var_data.hooks)
         return vars_hooks
 
-    def _get_events_hooks(self) -> set[str]:
+    def _get_events_hooks(self) -> dict[str, None]:
         """Get the hooks required by events referenced in this component.
 
         Returns:
             The hooks for the events.
         """
-        if self.event_triggers:
-            return {Hooks.EVENTS}
-        return set()
+        return {Hooks.EVENTS: None} if self.event_triggers else {}
 
-    def _get_special_hooks(self) -> set[str]:
+    def _get_special_hooks(self) -> dict[str, None]:
         """Get the hooks required by special actions referenced in this component.
 
         Returns:
             The hooks for special actions.
         """
-        if self.autofocus:
-            return {
-                """
-                // Set focus to the specified element.
-                const focusRef = useRef(null)
-                useEffect(() => {
-                  if (focusRef.current) {
-                    focusRef.current.focus();
-                  }
-                })""",
-            }
-        return set()
+        return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
 
-    def _get_hooks_internal(self) -> Set[str]:
+    def _get_hooks_internal(self) -> dict[str, None]:
         """Get the React hooks for this component managed by the framework.
 
         Downstream components should NOT override this method to avoid breaking
         framework functionality.
 
         Returns:
-            Set of internally managed hooks.
+            The internally managed hooks.
         """
-        return (
-            self._get_vars_hooks()
-            | self._get_events_hooks()
-            | self._get_special_hooks()
-            | set(hook for hook in [self._get_mount_lifecycle_hook()] if hook)
-        )
+        return {
+            **{
+                hook: None
+                for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
+                if hook is not None
+            },
+            **self._get_vars_hooks(),
+            **self._get_events_hooks(),
+            **self._get_special_hooks(),
+        }
 
     def _get_hooks(self) -> str | None:
         """Get the React hooks for this component.
@@ -1175,20 +1158,7 @@ class Component(BaseComponent, ABC):
         """
         return
 
-    def get_ref_hooks(self) -> Set[str]:
-        """Get the ref hooks for the component and its children.
-
-        Returns:
-            The ref hooks.
-        """
-        ref_hook = self._get_ref_hook()
-        hooks = set() if ref_hook is None else {ref_hook}
-
-        for child in self.children:
-            hooks |= child.get_ref_hooks()
-        return hooks
-
-    def get_hooks_internal(self) -> set[str]:
+    def get_hooks_internal(self) -> dict[str, None]:
         """Get the reflex internal hooks for the component and its children.
 
         Returns:
@@ -1199,26 +1169,26 @@ class Component(BaseComponent, ABC):
 
         # Add the hook code for the children.
         for child in self.children:
-            code |= child.get_hooks_internal()
+            code = {**code, **child.get_hooks_internal()}
 
         return code
 
-    def get_hooks(self) -> Set[str]:
+    def get_hooks(self) -> dict[str, None]:
         """Get the React hooks for this component and its children.
 
         Returns:
             The code that should appear just before returning the rendered component.
         """
-        code = set()
+        code = {}
 
         # Add the hook code for this component.
         hooks = self._get_hooks()
         if hooks is not None:
-            code.add(hooks)
+            code[hooks] = None
 
         # Add the hook code for the children.
         for child in self.children:
-            code |= child.get_hooks()
+            code = {**code, **child.get_hooks()}
 
         return code
 
@@ -1233,7 +1203,7 @@ class Component(BaseComponent, ABC):
             return None
         return format.format_ref(self.id)
 
-    def get_refs(self) -> Set[str]:
+    def get_refs(self) -> set[str]:
         """Get the refs for the children of the component.
 
         Returns:
@@ -1854,29 +1824,21 @@ class StatefulComponent(BaseComponent):
             )
         return trigger_memo
 
-    def get_ref_hooks(self) -> set[str]:
-        """Get the ref hooks for the component and its children.
-
-        Returns:
-            The ref hooks.
-        """
-        return set()
-
-    def get_hooks_internal(self) -> set[str]:
+    def get_hooks_internal(self) -> dict[str, None]:
         """Get the reflex internal hooks for the component and its children.
 
         Returns:
             The code that should appear just before user-defined hooks.
         """
-        return set()
+        return {}
 
-    def get_hooks(self) -> set[str]:
+    def get_hooks(self) -> dict[str, None]:
         """Get the React hooks for this component.
 
         Returns:
             The code that should appear just before returning the rendered component.
         """
-        return set()
+        return {}
 
     def get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.

+ 1 - 1
reflex/components/core/banner.py

@@ -22,7 +22,7 @@ from reflex.vars import Var, VarData
 
 connect_error_var_data: VarData = VarData(  # type: ignore
     imports=Imports.EVENTS,
-    hooks={Hooks.EVENTS},
+    hooks={Hooks.EVENTS: None},
 )
 
 connection_error: Var = Var.create_safe(

+ 2 - 1
reflex/components/core/upload.py

@@ -1,4 +1,5 @@
 """A file upload component."""
+
 from __future__ import annotations
 
 import os
@@ -31,7 +32,7 @@ upload_files_context_var_data: VarData = VarData(  # type: ignore
         },
     },
     hooks={
-        "const [filesById, setFilesById] = useContext(UploadFilesContext);",
+        "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
     },
 )
 

+ 2 - 1
reflex/components/el/elements/forms.py

@@ -1,4 +1,5 @@
 """Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
+
 from __future__ import annotations
 
 from hashlib import md5
@@ -162,7 +163,7 @@ class Form(BaseHTML):
         props["handle_submit_unique_name"] = ""
         form = super().create(*children, **props)
         form.handle_submit_unique_name = md5(
-            str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8")
+            str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8")
         ).hexdigest()
         return form
 

+ 2 - 2
reflex/components/markdown/markdown.py

@@ -293,8 +293,8 @@ class Markdown(Component):
         hooks = set()
         for _component in self.component_map.values():
             comp = _component(_MOCK_ARG)
-            hooks |= comp.get_hooks_internal()
-            hooks |= comp.get_hooks()
+            hooks.update(comp.get_hooks_internal())
+            hooks.update(comp.get_hooks())
         formatted_hooks = "\n".join(hooks)
         return f"""
         function {self._get_component_map_name()} () {{

+ 8 - 0
reflex/constants/compiler.py

@@ -111,6 +111,14 @@ class Hooks(SimpleNamespace):
     """Common sets of hook declarations."""
 
     EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
+    AUTOFOCUS = """
+                // Set focus to the specified element.
+                const focusRef = useRef(null)
+                useEffect(() => {
+                  if (focusRef.current) {
+                    focusRef.current.focus();
+                  }
+                })"""
 
 
 class MemoizationDisposition(enum.Enum):

+ 4 - 4
reflex/style.py

@@ -22,7 +22,7 @@ color_mode_var_data = VarData(  # type: ignore
         "react": {ImportVar(tag="useContext")},
     },
     hooks={
-        f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
+        f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
     },
 )
 # Var resolves to the current color mode for the app ("light" or "dark")
@@ -240,9 +240,9 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
         if isinstance(value, list):
             # Apply media queries from responsive value list.
             mbps = {
-                media_query(bp): bp_value
-                if isinstance(bp_value, dict)
-                else {key: bp_value}
+                media_query(bp): (
+                    bp_value if isinstance(bp_value, dict) else {key: bp_value}
+                )
                 for bp, bp_value in enumerate(value)
             }
             if key.startswith("&:"):

+ 6 - 6
reflex/vars.py

@@ -1,4 +1,5 @@
 """Define a state var."""
+
 from __future__ import annotations
 
 import contextlib
@@ -21,7 +22,6 @@ from typing import (
     List,
     Literal,
     Optional,
-    Set,
     Tuple,
     Type,
     Union,
@@ -119,7 +119,7 @@ class VarData(Base):
     imports: ImportDict = {}
 
     # Hooks that need to be present in the component to render this var
-    hooks: Set[str] = set()
+    hooks: Dict[str, None] = {}
 
     # Positions of interpolated strings. This is used by the decoder to figure
     # out where the interpolations are and only escape the non-interpolated
@@ -138,7 +138,7 @@ class VarData(Base):
         """
         state = ""
         _imports = {}
-        hooks = set()
+        hooks = {}
         interpolations = []
         for var_data in others:
             if var_data is None:
@@ -182,7 +182,7 @@ class VarData(Base):
         # not part of the vardata itself.
         return (
             self.state == other.state
-            and self.hooks == other.hooks
+            and self.hooks.keys() == other.hooks.keys()
             and imports.collapse_imports(self.imports)
             == imports.collapse_imports(other.imports)
         )
@@ -200,7 +200,7 @@ class VarData(Base):
                 lib: [import_var.dict() for import_var in import_vars]
                 for lib, import_vars in self.imports.items()
             },
-            "hooks": list(self.hooks),
+            "hooks": self.hooks,
         }
 
 
@@ -1659,7 +1659,7 @@ class Var:
             hooks={
                 "const {0} = useContext(StateContexts.{0})".format(
                     format.format_state_name(state_name)
-                )
+                ): None
             },
             imports={
                 f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],

+ 2 - 1
reflex/vars.pyi

@@ -1,4 +1,5 @@
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
+
 from __future__ import annotations
 
 from dataclasses import dataclass
@@ -35,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 class VarData(Base):
     state: str
     imports: dict[str, set[ImportVar]]
-    hooks: set[str]
+    hooks: Dict[str, None]
     interpolations: List[Tuple[int, int]]
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None: ...

+ 2 - 2
tests/components/test_component.py

@@ -596,7 +596,7 @@ def test_get_hooks_nested2(component3, component4):
         component3: component with hooks defined.
         component4: component with different hooks defined.
     """
-    exp_hooks = component3().get_hooks().union(component4().get_hooks())
+    exp_hooks = {**component3().get_hooks(), **component4().get_hooks()}
     assert component3.create(component4.create()).get_hooks() == exp_hooks
     assert component4.create(component3.create()).get_hooks() == exp_hooks
     assert (
@@ -725,7 +725,7 @@ def test_stateful_banner():
 
 TEST_VAR = Var.create_safe("test")._replace(
     merge_var_data=VarData(
-        hooks={"useTest"},
+        hooks={"useTest": None},
         imports={"test": {ImportVar(tag="test")}},
         state="Test",
         interpolations=[],

+ 1 - 1
tests/test_var.py

@@ -836,7 +836,7 @@ def test_state_with_initial_computed_var(
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
-            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}',
+            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
         ),
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",