Prechádzať zdrojové kódy

[REF-2787] add_hooks supports Var-wrapped hooks (#3248)

* [REF-2787] add_hooks supports Var-wrapped hooks

* Fix VarData definition in .pyi file to allow removal of type ignore comments
* Var.create and Var.create_safe accept _var_data parameter
* Replace instances where a set of imports was being passed to VarData
* Update code throughout reduce use of `._replace` to add VarData

* Fixup: user hooks _var_data.imports will never be iterable, just a single ImportDict
Masen Furer 1 rok pred
rodič
commit
c5f32db756

+ 46 - 11
reflex/components/component.py

@@ -241,7 +241,7 @@ class Component(BaseComponent, ABC):
         """
         """
         return {}
         return {}
 
 
-    def add_hooks(self) -> list[str]:
+    def add_hooks(self) -> list[str | Var]:
         """Add hooks inside the component function.
         """Add hooks inside the component function.
 
 
         Hooks are pieces of literal Javascript code that is inserted inside the
         Hooks are pieces of literal Javascript code that is inserted inside the
@@ -1265,11 +1265,20 @@ class Component(BaseComponent, ABC):
                 },
                 },
             )
             )
 
 
+        other_imports = []
         user_hooks = self._get_hooks()
         user_hooks = self._get_hooks()
-        if user_hooks is not None and isinstance(user_hooks, Var):
-            _imports = imports.merge_imports(_imports, user_hooks._var_data.imports)  # type: ignore
+        if (
+            user_hooks is not None
+            and isinstance(user_hooks, Var)
+            and user_hooks._var_data is not None
+            and user_hooks._var_data.imports
+        ):
+            other_imports.append(user_hooks._var_data.imports)
+        other_imports.extend(
+            hook_imports for hook_imports in self._get_added_hooks().values()
+        )
 
 
-        return _imports
+        return imports.merge_imports(_imports, *other_imports)
 
 
     def _get_imports(self) -> imports.ImportDict:
     def _get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.
         """Get all the libraries and fields that are used by the component.
@@ -1416,6 +1425,36 @@ class Component(BaseComponent, ABC):
             **self._get_special_hooks(),
             **self._get_special_hooks(),
         }
         }
 
 
+    def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
+        """Get the hooks added via `add_hooks` method.
+
+        Returns:
+            The deduplicated hooks and imports added by the component and parent components.
+        """
+        code = {}
+
+        def extract_var_hooks(hook: Var):
+            _imports = {}
+            if hook._var_data is not None:
+                for sub_hook in hook._var_data.hooks:
+                    code[sub_hook] = {}
+                if hook._var_data.imports:
+                    _imports = hook._var_data.imports
+            if str(hook) in code:
+                code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
+            else:
+                code[str(hook)] = _imports
+
+        # Add the hook code from add_hooks for each parent class (this is reversed to preserve
+        # the order of the hooks in the final output)
+        for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
+            for hook in clz.add_hooks(self):
+                if isinstance(hook, Var):
+                    extract_var_hooks(hook)
+                else:
+                    code[hook] = {}
+        return code
+
     def _get_hooks(self) -> str | None:
     def _get_hooks(self) -> str | None:
         """Get the React hooks for this component.
         """Get the React hooks for this component.
 
 
@@ -1454,11 +1493,7 @@ class Component(BaseComponent, ABC):
         if hooks is not None:
         if hooks is not None:
             code[hooks] = None
             code[hooks] = None
 
 
-        # Add the hook code from add_hooks for each parent class (this is reversed to preserve
-        # the order of the hooks in the final output)
-        for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
-            for hook in clz.add_hooks(self):
-                code[hook] = None
+        code.update(self._get_added_hooks())
 
 
         # Add the hook code for the children.
         # Add the hook code for the children.
         for child in self.children:
         for child in self.children:
@@ -2092,8 +2127,8 @@ class StatefulComponent(BaseComponent):
                     var_deps.extend(cls._get_hook_deps(hook))
                     var_deps.extend(cls._get_hook_deps(hook))
             memo_var_data = VarData.merge(
             memo_var_data = VarData.merge(
                 *[var._var_data for var in event_args],
                 *[var._var_data for var in event_args],
-                VarData(  # type: ignore
-                    imports={"react": {ImportVar(tag="useCallback")}},
+                VarData(
+                    imports={"react": [ImportVar(tag="useCallback")]},
                 ),
                 ),
             )
             )
 
 

+ 8 - 4
reflex/components/core/banner.py

@@ -29,23 +29,27 @@ connection_error: Var = Var.create_safe(
     value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
     value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
     _var_is_local=False,
     _var_is_local=False,
     _var_is_string=False,
     _var_is_string=False,
-)._replace(merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+)
 
 
 connection_errors_count: Var = Var.create_safe(
 connection_errors_count: Var = Var.create_safe(
     value="connectErrors.length",
     value="connectErrors.length",
     _var_is_string=False,
     _var_is_string=False,
     _var_is_local=False,
     _var_is_local=False,
-)._replace(merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+)
 
 
 has_connection_errors: Var = Var.create_safe(
 has_connection_errors: Var = Var.create_safe(
     value="connectErrors.length > 0",
     value="connectErrors.length > 0",
     _var_is_string=False,
     _var_is_string=False,
-)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+).to(bool)
 
 
 has_too_many_connection_errors: Var = Var.create_safe(
 has_too_many_connection_errors: Var = Var.create_safe(
     value="connectErrors.length >= 2",
     value="connectErrors.length >= 2",
     _var_is_string=False,
     _var_is_string=False,
-)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+).to(bool)
 
 
 
 
 class WebsocketTargetURL(Bare):
 class WebsocketTargetURL(Bare):

+ 1 - 1
reflex/components/core/cond.py

@@ -13,7 +13,7 @@ from reflex.utils import format, imports
 from reflex.vars import BaseVar, Var, VarData
 from reflex.vars import BaseVar, Var, VarData
 
 
 _IS_TRUE_IMPORT = {
 _IS_TRUE_IMPORT = {
-    f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
+    f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
 }
 }
 
 
 
 

+ 2 - 4
reflex/components/core/debounce.py

@@ -109,13 +109,11 @@ class DebounceInput(Component):
                 "{%s}" % (child.alias or child.tag),
                 "{%s}" % (child.alias or child.tag),
                 _var_is_local=False,
                 _var_is_local=False,
                 _var_is_string=False,
                 _var_is_string=False,
-            )._replace(
-                _var_type=Type[Component],
-                merge_var_data=VarData(  # type: ignore
+                _var_data=VarData(
                     imports=child._get_imports(),
                     imports=child._get_imports(),
                     hooks=child._get_hooks_internal(),
                     hooks=child._get_hooks_internal(),
                 ),
                 ),
-            ),
+            ).to(Type[Component]),
         )
         )
 
 
         component = super().create(**props)
         component = super().create(**props)

+ 9 - 10
reflex/components/core/upload.py

@@ -24,12 +24,12 @@ from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 
 DEFAULT_UPLOAD_ID: str = "default"
 DEFAULT_UPLOAD_ID: str = "default"
 
 
-upload_files_context_var_data: VarData = VarData(  # type: ignore
+upload_files_context_var_data: VarData = VarData(
     imports={
     imports={
-        "react": {imports.ImportVar(tag="useContext")},
-        f"/{Dirs.CONTEXTS_PATH}": {
+        "react": [imports.ImportVar(tag="useContext")],
+        f"/{Dirs.CONTEXTS_PATH}": [
             imports.ImportVar(tag="UploadFilesContext"),
             imports.ImportVar(tag="UploadFilesContext"),
-        },
+        ],
     },
     },
     hooks={
     hooks={
         "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
         "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@@ -118,14 +118,13 @@ def get_upload_dir() -> Path:
 
 
 
 
 uploaded_files_url_prefix: Var = Var.create_safe(
 uploaded_files_url_prefix: Var = Var.create_safe(
-    "${getBackendURL(env.UPLOAD)}"
-)._replace(
-    merge_var_data=VarData(  # type: ignore
+    "${getBackendURL(env.UPLOAD)}",
+    _var_data=VarData(
         imports={
         imports={
-            f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")},
-            "/env.json": {imports.ImportVar(tag="env", is_default=True)},
+            f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
+            "/env.json": [imports.ImportVar(tag="env", is_default=True)],
         }
         }
-    )
+    ),
 )
 )
 
 
 
 

+ 12 - 6
reflex/components/el/elements/forms.py

@@ -216,13 +216,17 @@ class Form(BaseHTML):
             if ref.startswith("refs_"):
             if ref.startswith("refs_"):
                 ref_var = Var.create_safe(ref[:-3]).as_ref()
                 ref_var = Var.create_safe(ref[:-3]).as_ref()
                 form_refs[ref[5:-3]] = Var.create_safe(
                 form_refs[ref[5:-3]] = Var.create_safe(
-                    f"getRefValues({str(ref_var)})", _var_is_local=False
-                )._replace(merge_var_data=ref_var._var_data)
+                    f"getRefValues({str(ref_var)})",
+                    _var_is_local=False,
+                    _var_data=ref_var._var_data,
+                )
             else:
             else:
                 ref_var = Var.create_safe(ref).as_ref()
                 ref_var = Var.create_safe(ref).as_ref()
                 form_refs[ref[4:]] = Var.create_safe(
                 form_refs[ref[4:]] = Var.create_safe(
-                    f"getRefValue({str(ref_var)})", _var_is_local=False
-                )._replace(merge_var_data=ref_var._var_data)
+                    f"getRefValue({str(ref_var)})",
+                    _var_is_local=False,
+                    _var_data=ref_var._var_data,
+                )
         return form_refs
         return form_refs
 
 
     def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
     def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
@@ -619,14 +623,16 @@ class Textarea(BaseHTML):
                 on_key_down=Var.create_safe(
                 on_key_down=Var.create_safe(
                     f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
                     f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
                     _var_is_local=False,
                     _var_is_local=False,
-                )._replace(merge_var_data=self.enter_key_submit._var_data),
+                    _var_data=self.enter_key_submit._var_data,
+                )
             )
             )
         if self.auto_height is not None:
         if self.auto_height is not None:
             tag.add_props(
             tag.add_props(
                 on_input=Var.create_safe(
                 on_input=Var.create_safe(
                     f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
                     f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
                     _var_is_local=False,
                     _var_is_local=False,
-                )._replace(merge_var_data=self.auto_height._var_data),
+                    _var_data=self.auto_height._var_data,
+                )
             )
             )
         return tag
         return tag
 
 

+ 4 - 2
reflex/components/gridjs/datatable.py

@@ -114,12 +114,14 @@ class DataTable(Gridjs):
                 _var_name=f"{self.data._var_name}.columns",
                 _var_name=f"{self.data._var_name}.columns",
                 _var_type=List[Any],
                 _var_type=List[Any],
                 _var_full_name_needs_state_prefix=True,
                 _var_full_name_needs_state_prefix=True,
-            )._replace(merge_var_data=self.data._var_data)
+                _var_data=self.data._var_data,
+            )
             self.data = BaseVar(
             self.data = BaseVar(
                 _var_name=f"{self.data._var_name}.data",
                 _var_name=f"{self.data._var_name}.data",
                 _var_type=List[List[Any]],
                 _var_type=List[List[Any]],
                 _var_full_name_needs_state_prefix=True,
                 _var_full_name_needs_state_prefix=True,
-            )._replace(merge_var_data=self.data._var_data)
+                _var_data=self.data._var_data,
+            )
         if types.is_dataframe(type(self.data)):
         if types.is_dataframe(type(self.data)):
             # If given a pandas df break up the data and columns
             # If given a pandas df break up the data and columns
             data = serialize(self.data)
             data = serialize(self.data)

+ 1 - 1
reflex/components/radix/themes/components/tabs.py

@@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent):
     _valid_parents: List[str] = ["TabsList"]
     _valid_parents: List[str] = ["TabsList"]
 
 
     @classmethod
     @classmethod
-    def create(self, *children, **props) -> Component:
+    def create(cls, *children, **props) -> Component:
         """Create a TabsTrigger component.
         """Create a TabsTrigger component.
 
 
         Args:
         Args:

+ 10 - 7
reflex/components/sonner/toast.py

@@ -162,7 +162,7 @@ class ToastProps(PropsBase):
 class Toaster(Component):
 class Toaster(Component):
     """A Toaster Component for displaying toast notifications."""
     """A Toaster Component for displaying toast notifications."""
 
 
-    library = "sonner@1.4.41"
+    library: str = "sonner@1.4.41"
 
 
     tag = "Toaster"
     tag = "Toaster"
 
 
@@ -209,12 +209,15 @@ class Toaster(Component):
     pause_when_page_is_hidden: Var[bool]
     pause_when_page_is_hidden: Var[bool]
 
 
     def _get_hooks(self) -> Var[str]:
     def _get_hooks(self) -> Var[str]:
-        hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True)
-        hook._var_data = VarData(  # type: ignore
-            imports={
-                "/utils/state": [ImportVar(tag="refs")],
-                self.library: [ImportVar(tag="toast", install=False)],
-            }
+        hook = Var.create_safe(
+            f"{toast_ref} = toast",
+            _var_is_local=True,
+            _var_data=VarData(
+                imports={
+                    "/utils/state": [ImportVar(tag="refs")],
+                    self.library: [ImportVar(tag="toast", install=False)],
+                }
+            ),
         )
         )
         return hook
         return hook
 
 

+ 3 - 3
reflex/constants/compiler.py

@@ -103,9 +103,9 @@ class Imports(SimpleNamespace):
     """Common sets of import vars."""
     """Common sets of import vars."""
 
 
     EVENTS = {
     EVENTS = {
-        "react": {ImportVar(tag="useContext")},
-        f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
-        f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
+        "react": [ImportVar(tag="useContext")],
+        f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
+        f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
     }
     }
 
 
 
 

+ 3 - 3
reflex/style.py

@@ -16,10 +16,10 @@ LIGHT_COLOR_MODE: str = "light"
 DARK_COLOR_MODE: str = "dark"
 DARK_COLOR_MODE: str = "dark"
 
 
 # Reference the global ColorModeContext
 # Reference the global ColorModeContext
-color_mode_var_data = VarData(  # type: ignore
+color_mode_var_data = VarData(
     imports={
     imports={
-        f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
-        "react": {ImportVar(tag="useContext")},
+        f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
+        "react": [ImportVar(tag="useContext")],
     },
     },
     hooks={
     hooks={
         f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
         f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,

+ 14 - 4
reflex/vars.py

@@ -341,7 +341,11 @@ class Var:
 
 
     @classmethod
     @classmethod
     def create(
     def create(
-        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
+        cls,
+        value: Any,
+        _var_is_local: bool = True,
+        _var_is_string: bool = False,
+        _var_data: Optional[VarData] = None,
     ) -> Var | None:
     ) -> Var | None:
         """Create a var from a value.
         """Create a var from a value.
 
 
@@ -349,6 +353,7 @@ class Var:
             value: The value to create the var from.
             value: The value to create the var from.
             _var_is_local: Whether the var is local.
             _var_is_local: Whether the var is local.
             _var_is_string: Whether the var is a string literal.
             _var_is_string: Whether the var is a string literal.
+            _var_data: Additional hooks and imports associated with the Var.
 
 
         Returns:
         Returns:
             The var.
             The var.
@@ -365,9 +370,8 @@ class Var:
             return value
             return value
 
 
         # Try to pull the imports and hooks from contained values.
         # Try to pull the imports and hooks from contained values.
-        _var_data = None
         if not isinstance(value, str):
         if not isinstance(value, str):
-            _var_data = VarData.merge(*_extract_var_data(value))
+            _var_data = VarData.merge(*_extract_var_data(value), _var_data)
 
 
         # Try to serialize the value.
         # Try to serialize the value.
         type_ = type(value)
         type_ = type(value)
@@ -388,7 +392,11 @@ class Var:
 
 
     @classmethod
     @classmethod
     def create_safe(
     def create_safe(
-        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
+        cls,
+        value: Any,
+        _var_is_local: bool = True,
+        _var_is_string: bool = False,
+        _var_data: Optional[VarData] = None,
     ) -> Var:
     ) -> Var:
         """Create a var from a value, asserting that it is not None.
         """Create a var from a value, asserting that it is not None.
 
 
@@ -396,6 +404,7 @@ class Var:
             value: The value to create the var from.
             value: The value to create the var from.
             _var_is_local: Whether the var is local.
             _var_is_local: Whether the var is local.
             _var_is_string: Whether the var is a string literal.
             _var_is_string: Whether the var is a string literal.
+            _var_data: Additional hooks and imports associated with the Var.
 
 
         Returns:
         Returns:
             The var.
             The var.
@@ -404,6 +413,7 @@ class Var:
             value,
             value,
             _var_is_local=_var_is_local,
             _var_is_local=_var_is_local,
             _var_is_string=_var_is_string,
             _var_is_string=_var_is_string,
+            _var_data=_var_data,
         )
         )
         assert var is not None
         assert var is not None
         return var
         return var

+ 6 - 6
reflex/vars.pyi

@@ -34,10 +34,10 @@ def _decode_var(value: str) -> tuple[VarData, str]: ...
 def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 
 
 class VarData(Base):
 class VarData(Base):
-    state: str
-    imports: dict[str, set[ImportVar]]
-    hooks: Dict[str, None]
-    interpolations: List[Tuple[int, int]]
+    state: str = ""
+    imports: dict[str, List[ImportVar]] = {}
+    hooks: Dict[str, None] = {}
+    interpolations: List[Tuple[int, int]] = []
     @classmethod
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None: ...
     def merge(cls, *others: VarData | None) -> VarData | None: ...
 
 
@@ -50,11 +50,11 @@ class Var:
     _var_data: VarData | None = None
     _var_data: VarData | None = None
     @classmethod
     @classmethod
     def create(
     def create(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
+        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
     ) -> Optional[Var]: ...
     ) -> Optional[Var]: ...
     @classmethod
     @classmethod
     def create_safe(
     def create_safe(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
+        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
     ) -> Var: ...
     ) -> Var: ...
     @classmethod
     @classmethod
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

+ 39 - 1
tests/components/test_component.py

@@ -1063,7 +1063,7 @@ def test_stateful_banner():
 TEST_VAR = Var.create_safe("test")._replace(
 TEST_VAR = Var.create_safe("test")._replace(
     merge_var_data=VarData(
     merge_var_data=VarData(
         hooks={"useTest": None},
         hooks={"useTest": None},
-        imports={"test": {ImportVar(tag="test")}},
+        imports={"test": [ImportVar(tag="test")]},
         state="Test",
         state="Test",
         interpolations=[],
         interpolations=[],
     )
     )
@@ -1953,6 +1953,44 @@ def test_component_add_custom_code():
     }
     }
 
 
 
 
+def test_component_add_hooks_var():
+    class HookComponent(Component):
+        def add_hooks(self):
+            return [
+                "const hook3 = useRef(null)",
+                "const hook1 = 42",
+                Var.create(
+                    "useEffect(() => () => {}, [])",
+                    _var_data=VarData(
+                        hooks={
+                            "const hook2 = 43": None,
+                            "const hook3 = useRef(null)": None,
+                        },
+                        imports={"react": [ImportVar(tag="useEffect")]},
+                    ),
+                ),
+                Var.create(
+                    "const hook3 = useRef(null)",
+                    _var_data=VarData(
+                        imports={"react": [ImportVar(tag="useRef")]},
+                    ),
+                ),
+            ]
+
+    assert list(HookComponent()._get_all_hooks()) == [
+        "const hook3 = useRef(null)",
+        "const hook1 = 42",
+        "const hook2 = 43",
+        "useEffect(() => () => {}, [])",
+    ]
+    imports = HookComponent()._get_all_imports()
+    assert len(imports) == 1
+    assert "react" in imports
+    assert len(imports["react"]) == 2
+    assert ImportVar(tag="useRef") in imports["react"]
+    assert ImportVar(tag="useEffect") in imports["react"]
+
+
 def test_add_style_embedded_vars(test_state: BaseState):
 def test_add_style_embedded_vars(test_state: BaseState):
     """Test that add_style works with embedded vars when returning a plain dict.
     """Test that add_style works with embedded vars when returning a plain dict.