Kaynağa Gözat

use dict instead of set to store hooks (#2995)

Thomas Brandého 1 yıl önce
ebeveyn
işleme
34ee07ecd1

+ 0 - 4
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -8,10 +8,6 @@
 
 
 {% block export %}
 {% block export %}
 export default function Component() {
 export default function Component() {
-  {% for ref_hook in ref_hooks %}
-  {{ ref_hook }}
-  {% endfor %}
-  
   {% for hook in hooks %}
   {% for hook in hooks %}
   {{ hook }}
   {{ hook }}
   {% endfor %}
   {% endfor %}

+ 0 - 4
reflex/.templates/jinja/web/pages/stateful_component.js.jinja2

@@ -1,10 +1,6 @@
 {% import 'web/pages/utils.js.jinja2' as utils %}
 {% import 'web/pages/utils.js.jinja2' as utils %}
 
 
 export function {{tag_name}} () {
 export function {{tag_name}} () {
-  {% for ref_hook in component.get_ref_hooks() %}
-  {{ ref_hook }}
-  {% endfor %}
-
   {% for hook in component.get_hooks_internal() %}
   {% for hook in component.get_hooks_internal() %}
   {{ hook }}
   {{ hook }}
   {% endfor %}
   {% endfor %}

+ 2 - 3
reflex/compiler/compiler.py

@@ -50,7 +50,7 @@ def _compile_app(app_root: Component) -> str:
     return templates.APP_ROOT.render(
     return templates.APP_ROOT.render(
         imports=utils.compile_imports(app_root.get_imports()),
         imports=utils.compile_imports(app_root.get_imports()),
         custom_codes=app_root.get_custom_code(),
         custom_codes=app_root.get_custom_code(),
-        hooks=app_root.get_hooks_internal() | app_root.get_hooks(),
+        hooks={**app_root.get_hooks_internal(), **app_root.get_hooks()},
         render=app_root.render(),
         render=app_root.render(),
     )
     )
 
 
@@ -119,8 +119,7 @@ def _compile_page(
         imports=imports,
         imports=imports,
         dynamic_imports=component.get_dynamic_imports(),
         dynamic_imports=component.get_dynamic_imports(),
         custom_codes=component.get_custom_code(),
         custom_codes=component.get_custom_code(),
-        ref_hooks=component.get_ref_hooks(),
-        hooks=component.get_hooks_internal() | component.get_hooks(),
+        hooks={**component.get_hooks_internal(), **component.get_hooks()},
         render=component.render(),
         render=component.render(),
         **kwargs,
         **kwargs,
     )
     )

+ 1 - 1
reflex/compiler/utils.py

@@ -265,7 +265,7 @@ def compile_custom_component(
             "name": component.tag,
             "name": component.tag,
             "props": props,
             "props": props,
             "render": render.render(),
             "render": render.render(),
-            "hooks": render.get_hooks_internal() | render.get_hooks(),
+            "hooks": {**render.get_hooks_internal(), **render.get_hooks()},
             "custom_code": render.get_custom_code(),
             "custom_code": render.get_custom_code(),
         },
         },
         imports,
         imports,

+ 32 - 70
reflex/components/component.py

@@ -76,15 +76,7 @@ class BaseComponent(Base, ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def get_ref_hooks(self) -> set[str]:
-        """Get the hooks required by refs in this component.
-
-        Returns:
-            The hooks for the refs.
-        """
-
-    @abstractmethod
-    def get_hooks_internal(self) -> set[str]:
+    def get_hooks_internal(self) -> dict[str, None]:
         """Get the reflex internal hooks for the component and its children.
         """Get the reflex internal hooks for the component and its children.
 
 
         Returns:
         Returns:
@@ -92,7 +84,7 @@ class BaseComponent(Base, ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def get_hooks(self) -> set[str]:
+    def get_hooks(self) -> dict[str, None]:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
 
 
         Returns:
         Returns:
@@ -929,7 +921,7 @@ class Component(BaseComponent, ABC):
         """
         """
         return None
         return None
 
 
-    def get_custom_code(self) -> Set[str]:
+    def get_custom_code(self) -> set[str]:
         """Get custom code for the component and its children.
         """Get custom code for the component and its children.
 
 
         Returns:
         Returns:
@@ -1108,62 +1100,53 @@ class Component(BaseComponent, ABC):
         if ref is not None:
         if ref is not None:
             return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};"
             return f"const {ref} = useRef(null); {str(Var.create_safe(ref).as_ref())} = {ref};"
 
 
-    def _get_vars_hooks(self) -> set[str]:
+    def _get_vars_hooks(self) -> dict[str, None]:
         """Get the hooks required by vars referenced in this component.
         """Get the hooks required by vars referenced in this component.
 
 
         Returns:
         Returns:
             The hooks for the vars.
             The hooks for the vars.
         """
         """
-        vars_hooks = set()
+        vars_hooks = {}
         for var in self._get_vars():
         for var in self._get_vars():
             if var._var_data:
             if var._var_data:
                 vars_hooks.update(var._var_data.hooks)
                 vars_hooks.update(var._var_data.hooks)
         return vars_hooks
         return vars_hooks
 
 
-    def _get_events_hooks(self) -> set[str]:
+    def _get_events_hooks(self) -> dict[str, None]:
         """Get the hooks required by events referenced in this component.
         """Get the hooks required by events referenced in this component.
 
 
         Returns:
         Returns:
             The hooks for the events.
             The hooks for the events.
         """
         """
-        if self.event_triggers:
-            return {Hooks.EVENTS}
-        return set()
+        return {Hooks.EVENTS: None} if self.event_triggers else {}
 
 
-    def _get_special_hooks(self) -> set[str]:
+    def _get_special_hooks(self) -> dict[str, None]:
         """Get the hooks required by special actions referenced in this component.
         """Get the hooks required by special actions referenced in this component.
 
 
         Returns:
         Returns:
             The hooks for special actions.
             The hooks for special actions.
         """
         """
-        if self.autofocus:
-            return {
-                """
-                // Set focus to the specified element.
-                const focusRef = useRef(null)
-                useEffect(() => {
-                  if (focusRef.current) {
-                    focusRef.current.focus();
-                  }
-                })""",
-            }
-        return set()
+        return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
 
 
-    def _get_hooks_internal(self) -> Set[str]:
+    def _get_hooks_internal(self) -> dict[str, None]:
         """Get the React hooks for this component managed by the framework.
         """Get the React hooks for this component managed by the framework.
 
 
         Downstream components should NOT override this method to avoid breaking
         Downstream components should NOT override this method to avoid breaking
         framework functionality.
         framework functionality.
 
 
         Returns:
         Returns:
-            Set of internally managed hooks.
+            The internally managed hooks.
         """
         """
-        return (
-            self._get_vars_hooks()
-            | self._get_events_hooks()
-            | self._get_special_hooks()
-            | set(hook for hook in [self._get_mount_lifecycle_hook()] if hook)
-        )
+        return {
+            **{
+                hook: None
+                for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
+                if hook is not None
+            },
+            **self._get_vars_hooks(),
+            **self._get_events_hooks(),
+            **self._get_special_hooks(),
+        }
 
 
     def _get_hooks(self) -> str | None:
     def _get_hooks(self) -> str | None:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
@@ -1175,20 +1158,7 @@ class Component(BaseComponent, ABC):
         """
         """
         return
         return
 
 
-    def get_ref_hooks(self) -> Set[str]:
-        """Get the ref hooks for the component and its children.
-
-        Returns:
-            The ref hooks.
-        """
-        ref_hook = self._get_ref_hook()
-        hooks = set() if ref_hook is None else {ref_hook}
-
-        for child in self.children:
-            hooks |= child.get_ref_hooks()
-        return hooks
-
-    def get_hooks_internal(self) -> set[str]:
+    def get_hooks_internal(self) -> dict[str, None]:
         """Get the reflex internal hooks for the component and its children.
         """Get the reflex internal hooks for the component and its children.
 
 
         Returns:
         Returns:
@@ -1199,26 +1169,26 @@ class Component(BaseComponent, ABC):
 
 
         # Add the hook code for the children.
         # Add the hook code for the children.
         for child in self.children:
         for child in self.children:
-            code |= child.get_hooks_internal()
+            code = {**code, **child.get_hooks_internal()}
 
 
         return code
         return code
 
 
-    def get_hooks(self) -> Set[str]:
+    def get_hooks(self) -> dict[str, None]:
         """Get the React hooks for this component and its children.
         """Get the React hooks for this component and its children.
 
 
         Returns:
         Returns:
             The code that should appear just before returning the rendered component.
             The code that should appear just before returning the rendered component.
         """
         """
-        code = set()
+        code = {}
 
 
         # Add the hook code for this component.
         # Add the hook code for this component.
         hooks = self._get_hooks()
         hooks = self._get_hooks()
         if hooks is not None:
         if hooks is not None:
-            code.add(hooks)
+            code[hooks] = None
 
 
         # Add the hook code for the children.
         # Add the hook code for the children.
         for child in self.children:
         for child in self.children:
-            code |= child.get_hooks()
+            code = {**code, **child.get_hooks()}
 
 
         return code
         return code
 
 
@@ -1233,7 +1203,7 @@ class Component(BaseComponent, ABC):
             return None
             return None
         return format.format_ref(self.id)
         return format.format_ref(self.id)
 
 
-    def get_refs(self) -> Set[str]:
+    def get_refs(self) -> set[str]:
         """Get the refs for the children of the component.
         """Get the refs for the children of the component.
 
 
         Returns:
         Returns:
@@ -1854,29 +1824,21 @@ class StatefulComponent(BaseComponent):
             )
             )
         return trigger_memo
         return trigger_memo
 
 
-    def get_ref_hooks(self) -> set[str]:
-        """Get the ref hooks for the component and its children.
-
-        Returns:
-            The ref hooks.
-        """
-        return set()
-
-    def get_hooks_internal(self) -> set[str]:
+    def get_hooks_internal(self) -> dict[str, None]:
         """Get the reflex internal hooks for the component and its children.
         """Get the reflex internal hooks for the component and its children.
 
 
         Returns:
         Returns:
             The code that should appear just before user-defined hooks.
             The code that should appear just before user-defined hooks.
         """
         """
-        return set()
+        return {}
 
 
-    def get_hooks(self) -> set[str]:
+    def get_hooks(self) -> dict[str, None]:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
 
 
         Returns:
         Returns:
             The code that should appear just before returning the rendered component.
             The code that should appear just before returning the rendered component.
         """
         """
-        return set()
+        return {}
 
 
     def get_imports(self) -> imports.ImportDict:
     def get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.
         """Get all the libraries and fields that are used by the component.

+ 1 - 1
reflex/components/core/banner.py

@@ -22,7 +22,7 @@ from reflex.vars import Var, VarData
 
 
 connect_error_var_data: VarData = VarData(  # type: ignore
 connect_error_var_data: VarData = VarData(  # type: ignore
     imports=Imports.EVENTS,
     imports=Imports.EVENTS,
-    hooks={Hooks.EVENTS},
+    hooks={Hooks.EVENTS: None},
 )
 )
 
 
 connection_error: Var = Var.create_safe(
 connection_error: Var = Var.create_safe(

+ 2 - 1
reflex/components/core/upload.py

@@ -1,4 +1,5 @@
 """A file upload component."""
 """A file upload component."""
+
 from __future__ import annotations
 from __future__ import annotations
 
 
 import os
 import os
@@ -31,7 +32,7 @@ upload_files_context_var_data: VarData = VarData(  # type: ignore
         },
         },
     },
     },
     hooks={
     hooks={
-        "const [filesById, setFilesById] = useContext(UploadFilesContext);",
+        "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
     },
     },
 )
 )
 
 

+ 2 - 1
reflex/components/el/elements/forms.py

@@ -1,4 +1,5 @@
 """Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
 """Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
+
 from __future__ import annotations
 from __future__ import annotations
 
 
 from hashlib import md5
 from hashlib import md5
@@ -162,7 +163,7 @@ class Form(BaseHTML):
         props["handle_submit_unique_name"] = ""
         props["handle_submit_unique_name"] = ""
         form = super().create(*children, **props)
         form = super().create(*children, **props)
         form.handle_submit_unique_name = md5(
         form.handle_submit_unique_name = md5(
-            str(form.get_hooks_internal().union(form.get_hooks())).encode("utf-8")
+            str({**form.get_hooks_internal(), **form.get_hooks()}).encode("utf-8")
         ).hexdigest()
         ).hexdigest()
         return form
         return form
 
 

+ 2 - 2
reflex/components/markdown/markdown.py

@@ -293,8 +293,8 @@ class Markdown(Component):
         hooks = set()
         hooks = set()
         for _component in self.component_map.values():
         for _component in self.component_map.values():
             comp = _component(_MOCK_ARG)
             comp = _component(_MOCK_ARG)
-            hooks |= comp.get_hooks_internal()
-            hooks |= comp.get_hooks()
+            hooks.update(comp.get_hooks_internal())
+            hooks.update(comp.get_hooks())
         formatted_hooks = "\n".join(hooks)
         formatted_hooks = "\n".join(hooks)
         return f"""
         return f"""
         function {self._get_component_map_name()} () {{
         function {self._get_component_map_name()} () {{

+ 8 - 0
reflex/constants/compiler.py

@@ -111,6 +111,14 @@ class Hooks(SimpleNamespace):
     """Common sets of hook declarations."""
     """Common sets of hook declarations."""
 
 
     EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
     EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
+    AUTOFOCUS = """
+                // Set focus to the specified element.
+                const focusRef = useRef(null)
+                useEffect(() => {
+                  if (focusRef.current) {
+                    focusRef.current.focus();
+                  }
+                })"""
 
 
 
 
 class MemoizationDisposition(enum.Enum):
 class MemoizationDisposition(enum.Enum):

+ 4 - 4
reflex/style.py

@@ -22,7 +22,7 @@ color_mode_var_data = VarData(  # type: ignore
         "react": {ImportVar(tag="useContext")},
         "react": {ImportVar(tag="useContext")},
     },
     },
     hooks={
     hooks={
-        f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
+        f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
     },
     },
 )
 )
 # Var resolves to the current color mode for the app ("light" or "dark")
 # Var resolves to the current color mode for the app ("light" or "dark")
@@ -240,9 +240,9 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None:
         if isinstance(value, list):
         if isinstance(value, list):
             # Apply media queries from responsive value list.
             # Apply media queries from responsive value list.
             mbps = {
             mbps = {
-                media_query(bp): bp_value
-                if isinstance(bp_value, dict)
-                else {key: bp_value}
+                media_query(bp): (
+                    bp_value if isinstance(bp_value, dict) else {key: bp_value}
+                )
                 for bp, bp_value in enumerate(value)
                 for bp, bp_value in enumerate(value)
             }
             }
             if key.startswith("&:"):
             if key.startswith("&:"):

+ 6 - 6
reflex/vars.py

@@ -1,4 +1,5 @@
 """Define a state var."""
 """Define a state var."""
+
 from __future__ import annotations
 from __future__ import annotations
 
 
 import contextlib
 import contextlib
@@ -21,7 +22,6 @@ from typing import (
     List,
     List,
     Literal,
     Literal,
     Optional,
     Optional,
-    Set,
     Tuple,
     Tuple,
     Type,
     Type,
     Union,
     Union,
@@ -119,7 +119,7 @@ class VarData(Base):
     imports: ImportDict = {}
     imports: ImportDict = {}
 
 
     # Hooks that need to be present in the component to render this var
     # Hooks that need to be present in the component to render this var
-    hooks: Set[str] = set()
+    hooks: Dict[str, None] = {}
 
 
     # Positions of interpolated strings. This is used by the decoder to figure
     # Positions of interpolated strings. This is used by the decoder to figure
     # out where the interpolations are and only escape the non-interpolated
     # out where the interpolations are and only escape the non-interpolated
@@ -138,7 +138,7 @@ class VarData(Base):
         """
         """
         state = ""
         state = ""
         _imports = {}
         _imports = {}
-        hooks = set()
+        hooks = {}
         interpolations = []
         interpolations = []
         for var_data in others:
         for var_data in others:
             if var_data is None:
             if var_data is None:
@@ -182,7 +182,7 @@ class VarData(Base):
         # not part of the vardata itself.
         # not part of the vardata itself.
         return (
         return (
             self.state == other.state
             self.state == other.state
-            and self.hooks == other.hooks
+            and self.hooks.keys() == other.hooks.keys()
             and imports.collapse_imports(self.imports)
             and imports.collapse_imports(self.imports)
             == imports.collapse_imports(other.imports)
             == imports.collapse_imports(other.imports)
         )
         )
@@ -200,7 +200,7 @@ class VarData(Base):
                 lib: [import_var.dict() for import_var in import_vars]
                 lib: [import_var.dict() for import_var in import_vars]
                 for lib, import_vars in self.imports.items()
                 for lib, import_vars in self.imports.items()
             },
             },
-            "hooks": list(self.hooks),
+            "hooks": self.hooks,
         }
         }
 
 
 
 
@@ -1659,7 +1659,7 @@ class Var:
             hooks={
             hooks={
                 "const {0} = useContext(StateContexts.{0})".format(
                 "const {0} = useContext(StateContexts.{0})".format(
                     format.format_state_name(state_name)
                     format.format_state_name(state_name)
-                )
+                ): None
             },
             },
             imports={
             imports={
                 f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
                 f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],

+ 2 - 1
reflex/vars.pyi

@@ -1,4 +1,5 @@
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
 """ Generated with stubgen from mypy, then manually edited, do not regen."""
+
 from __future__ import annotations
 from __future__ import annotations
 
 
 from dataclasses import dataclass
 from dataclasses import dataclass
@@ -35,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 class VarData(Base):
 class VarData(Base):
     state: str
     state: str
     imports: dict[str, set[ImportVar]]
     imports: dict[str, set[ImportVar]]
-    hooks: set[str]
+    hooks: Dict[str, None]
     interpolations: List[Tuple[int, int]]
     interpolations: List[Tuple[int, int]]
     @classmethod
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None: ...
     def merge(cls, *others: VarData | None) -> VarData | None: ...

+ 2 - 2
tests/components/test_component.py

@@ -596,7 +596,7 @@ def test_get_hooks_nested2(component3, component4):
         component3: component with hooks defined.
         component3: component with hooks defined.
         component4: component with different hooks defined.
         component4: component with different hooks defined.
     """
     """
-    exp_hooks = component3().get_hooks().union(component4().get_hooks())
+    exp_hooks = {**component3().get_hooks(), **component4().get_hooks()}
     assert component3.create(component4.create()).get_hooks() == exp_hooks
     assert component3.create(component4.create()).get_hooks() == exp_hooks
     assert component4.create(component3.create()).get_hooks() == exp_hooks
     assert component4.create(component3.create()).get_hooks() == exp_hooks
     assert (
     assert (
@@ -725,7 +725,7 @@ def test_stateful_banner():
 
 
 TEST_VAR = Var.create_safe("test")._replace(
 TEST_VAR = Var.create_safe("test")._replace(
     merge_var_data=VarData(
     merge_var_data=VarData(
-        hooks={"useTest"},
+        hooks={"useTest": None},
         imports={"test": {ImportVar(tag="test")}},
         imports={"test": {ImportVar(tag="test")}},
         state="Test",
         state="Test",
         interpolations=[],
         interpolations=[],

+ 1 - 1
tests/test_var.py

@@ -836,7 +836,7 @@ def test_state_with_initial_computed_var(
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
         (
         (
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
             f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
-            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}',
+            'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": {"const state = useContext(StateContexts.state)": null}, "string_length": 13}</reflex.Var>{state.myvar}',
         ),
         ),
         (
         (
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
             f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",