Browse Source

[REF-2787] add_hooks supports Var-wrapped hooks (#3248)

* [REF-2787] add_hooks supports Var-wrapped hooks

* Fix VarData definition in .pyi file to allow removal of type ignore comments
* Var.create and Var.create_safe accept _var_data parameter
* Replace instances where a set of imports was being passed to VarData
* Update code throughout reduce use of `._replace` to add VarData

* Fixup: user hooks _var_data.imports will never be iterable, just a single ImportDict
Masen Furer 1 year ago
parent
commit
c5f32db756

+ 46 - 11
reflex/components/component.py

@@ -241,7 +241,7 @@ class Component(BaseComponent, ABC):
         """
         return {}
 
-    def add_hooks(self) -> list[str]:
+    def add_hooks(self) -> list[str | Var]:
         """Add hooks inside the component function.
 
         Hooks are pieces of literal Javascript code that is inserted inside the
@@ -1265,11 +1265,20 @@ class Component(BaseComponent, ABC):
                 },
             )
 
+        other_imports = []
         user_hooks = self._get_hooks()
-        if user_hooks is not None and isinstance(user_hooks, Var):
-            _imports = imports.merge_imports(_imports, user_hooks._var_data.imports)  # type: ignore
+        if (
+            user_hooks is not None
+            and isinstance(user_hooks, Var)
+            and user_hooks._var_data is not None
+            and user_hooks._var_data.imports
+        ):
+            other_imports.append(user_hooks._var_data.imports)
+        other_imports.extend(
+            hook_imports for hook_imports in self._get_added_hooks().values()
+        )
 
-        return _imports
+        return imports.merge_imports(_imports, *other_imports)
 
     def _get_imports(self) -> imports.ImportDict:
         """Get all the libraries and fields that are used by the component.
@@ -1416,6 +1425,36 @@ class Component(BaseComponent, ABC):
             **self._get_special_hooks(),
         }
 
+    def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
+        """Get the hooks added via `add_hooks` method.
+
+        Returns:
+            The deduplicated hooks and imports added by the component and parent components.
+        """
+        code = {}
+
+        def extract_var_hooks(hook: Var):
+            _imports = {}
+            if hook._var_data is not None:
+                for sub_hook in hook._var_data.hooks:
+                    code[sub_hook] = {}
+                if hook._var_data.imports:
+                    _imports = hook._var_data.imports
+            if str(hook) in code:
+                code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
+            else:
+                code[str(hook)] = _imports
+
+        # Add the hook code from add_hooks for each parent class (this is reversed to preserve
+        # the order of the hooks in the final output)
+        for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
+            for hook in clz.add_hooks(self):
+                if isinstance(hook, Var):
+                    extract_var_hooks(hook)
+                else:
+                    code[hook] = {}
+        return code
+
     def _get_hooks(self) -> str | None:
         """Get the React hooks for this component.
 
@@ -1454,11 +1493,7 @@ class Component(BaseComponent, ABC):
         if hooks is not None:
             code[hooks] = None
 
-        # Add the hook code from add_hooks for each parent class (this is reversed to preserve
-        # the order of the hooks in the final output)
-        for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
-            for hook in clz.add_hooks(self):
-                code[hook] = None
+        code.update(self._get_added_hooks())
 
         # Add the hook code for the children.
         for child in self.children:
@@ -2092,8 +2127,8 @@ class StatefulComponent(BaseComponent):
                     var_deps.extend(cls._get_hook_deps(hook))
             memo_var_data = VarData.merge(
                 *[var._var_data for var in event_args],
-                VarData(  # type: ignore
-                    imports={"react": {ImportVar(tag="useCallback")}},
+                VarData(
+                    imports={"react": [ImportVar(tag="useCallback")]},
                 ),
             )
 

+ 8 - 4
reflex/components/core/banner.py

@@ -29,23 +29,27 @@ connection_error: Var = Var.create_safe(
     value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
     _var_is_local=False,
     _var_is_string=False,
-)._replace(merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+)
 
 connection_errors_count: Var = Var.create_safe(
     value="connectErrors.length",
     _var_is_string=False,
     _var_is_local=False,
-)._replace(merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+)
 
 has_connection_errors: Var = Var.create_safe(
     value="connectErrors.length > 0",
     _var_is_string=False,
-)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+).to(bool)
 
 has_too_many_connection_errors: Var = Var.create_safe(
     value="connectErrors.length >= 2",
     _var_is_string=False,
-)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
+    _var_data=connect_error_var_data,
+).to(bool)
 
 
 class WebsocketTargetURL(Bare):

+ 1 - 1
reflex/components/core/cond.py

@@ -13,7 +13,7 @@ from reflex.utils import format, imports
 from reflex.vars import BaseVar, Var, VarData
 
 _IS_TRUE_IMPORT = {
-    f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
+    f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
 }
 
 

+ 2 - 4
reflex/components/core/debounce.py

@@ -109,13 +109,11 @@ class DebounceInput(Component):
                 "{%s}" % (child.alias or child.tag),
                 _var_is_local=False,
                 _var_is_string=False,
-            )._replace(
-                _var_type=Type[Component],
-                merge_var_data=VarData(  # type: ignore
+                _var_data=VarData(
                     imports=child._get_imports(),
                     hooks=child._get_hooks_internal(),
                 ),
-            ),
+            ).to(Type[Component]),
         )
 
         component = super().create(**props)

+ 9 - 10
reflex/components/core/upload.py

@@ -24,12 +24,12 @@ from reflex.vars import BaseVar, CallableVar, Var, VarData
 
 DEFAULT_UPLOAD_ID: str = "default"
 
-upload_files_context_var_data: VarData = VarData(  # type: ignore
+upload_files_context_var_data: VarData = VarData(
     imports={
-        "react": {imports.ImportVar(tag="useContext")},
-        f"/{Dirs.CONTEXTS_PATH}": {
+        "react": [imports.ImportVar(tag="useContext")],
+        f"/{Dirs.CONTEXTS_PATH}": [
             imports.ImportVar(tag="UploadFilesContext"),
-        },
+        ],
     },
     hooks={
         "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@@ -118,14 +118,13 @@ def get_upload_dir() -> Path:
 
 
 uploaded_files_url_prefix: Var = Var.create_safe(
-    "${getBackendURL(env.UPLOAD)}"
-)._replace(
-    merge_var_data=VarData(  # type: ignore
+    "${getBackendURL(env.UPLOAD)}",
+    _var_data=VarData(
         imports={
-            f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")},
-            "/env.json": {imports.ImportVar(tag="env", is_default=True)},
+            f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
+            "/env.json": [imports.ImportVar(tag="env", is_default=True)],
         }
-    )
+    ),
 )
 
 

+ 12 - 6
reflex/components/el/elements/forms.py

@@ -216,13 +216,17 @@ class Form(BaseHTML):
             if ref.startswith("refs_"):
                 ref_var = Var.create_safe(ref[:-3]).as_ref()
                 form_refs[ref[5:-3]] = Var.create_safe(
-                    f"getRefValues({str(ref_var)})", _var_is_local=False
-                )._replace(merge_var_data=ref_var._var_data)
+                    f"getRefValues({str(ref_var)})",
+                    _var_is_local=False,
+                    _var_data=ref_var._var_data,
+                )
             else:
                 ref_var = Var.create_safe(ref).as_ref()
                 form_refs[ref[4:]] = Var.create_safe(
-                    f"getRefValue({str(ref_var)})", _var_is_local=False
-                )._replace(merge_var_data=ref_var._var_data)
+                    f"getRefValue({str(ref_var)})",
+                    _var_is_local=False,
+                    _var_data=ref_var._var_data,
+                )
         return form_refs
 
     def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
@@ -619,14 +623,16 @@ class Textarea(BaseHTML):
                 on_key_down=Var.create_safe(
                     f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
                     _var_is_local=False,
-                )._replace(merge_var_data=self.enter_key_submit._var_data),
+                    _var_data=self.enter_key_submit._var_data,
+                )
             )
         if self.auto_height is not None:
             tag.add_props(
                 on_input=Var.create_safe(
                     f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
                     _var_is_local=False,
-                )._replace(merge_var_data=self.auto_height._var_data),
+                    _var_data=self.auto_height._var_data,
+                )
             )
         return tag
 

+ 4 - 2
reflex/components/gridjs/datatable.py

@@ -114,12 +114,14 @@ class DataTable(Gridjs):
                 _var_name=f"{self.data._var_name}.columns",
                 _var_type=List[Any],
                 _var_full_name_needs_state_prefix=True,
-            )._replace(merge_var_data=self.data._var_data)
+                _var_data=self.data._var_data,
+            )
             self.data = BaseVar(
                 _var_name=f"{self.data._var_name}.data",
                 _var_type=List[List[Any]],
                 _var_full_name_needs_state_prefix=True,
-            )._replace(merge_var_data=self.data._var_data)
+                _var_data=self.data._var_data,
+            )
         if types.is_dataframe(type(self.data)):
             # If given a pandas df break up the data and columns
             data = serialize(self.data)

+ 1 - 1
reflex/components/radix/themes/components/tabs.py

@@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent):
     _valid_parents: List[str] = ["TabsList"]
 
     @classmethod
-    def create(self, *children, **props) -> Component:
+    def create(cls, *children, **props) -> Component:
         """Create a TabsTrigger component.
 
         Args:

+ 10 - 7
reflex/components/sonner/toast.py

@@ -162,7 +162,7 @@ class ToastProps(PropsBase):
 class Toaster(Component):
     """A Toaster Component for displaying toast notifications."""
 
-    library = "sonner@1.4.41"
+    library: str = "sonner@1.4.41"
 
     tag = "Toaster"
 
@@ -209,12 +209,15 @@ class Toaster(Component):
     pause_when_page_is_hidden: Var[bool]
 
     def _get_hooks(self) -> Var[str]:
-        hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True)
-        hook._var_data = VarData(  # type: ignore
-            imports={
-                "/utils/state": [ImportVar(tag="refs")],
-                self.library: [ImportVar(tag="toast", install=False)],
-            }
+        hook = Var.create_safe(
+            f"{toast_ref} = toast",
+            _var_is_local=True,
+            _var_data=VarData(
+                imports={
+                    "/utils/state": [ImportVar(tag="refs")],
+                    self.library: [ImportVar(tag="toast", install=False)],
+                }
+            ),
         )
         return hook
 

+ 3 - 3
reflex/constants/compiler.py

@@ -103,9 +103,9 @@ class Imports(SimpleNamespace):
     """Common sets of import vars."""
 
     EVENTS = {
-        "react": {ImportVar(tag="useContext")},
-        f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
-        f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
+        "react": [ImportVar(tag="useContext")],
+        f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
+        f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
     }
 
 

+ 3 - 3
reflex/style.py

@@ -16,10 +16,10 @@ LIGHT_COLOR_MODE: str = "light"
 DARK_COLOR_MODE: str = "dark"
 
 # Reference the global ColorModeContext
-color_mode_var_data = VarData(  # type: ignore
+color_mode_var_data = VarData(
     imports={
-        f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
-        "react": {ImportVar(tag="useContext")},
+        f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
+        "react": [ImportVar(tag="useContext")],
     },
     hooks={
         f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,

+ 14 - 4
reflex/vars.py

@@ -341,7 +341,11 @@ class Var:
 
     @classmethod
     def create(
-        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
+        cls,
+        value: Any,
+        _var_is_local: bool = True,
+        _var_is_string: bool = False,
+        _var_data: Optional[VarData] = None,
     ) -> Var | None:
         """Create a var from a value.
 
@@ -349,6 +353,7 @@ class Var:
             value: The value to create the var from.
             _var_is_local: Whether the var is local.
             _var_is_string: Whether the var is a string literal.
+            _var_data: Additional hooks and imports associated with the Var.
 
         Returns:
             The var.
@@ -365,9 +370,8 @@ class Var:
             return value
 
         # Try to pull the imports and hooks from contained values.
-        _var_data = None
         if not isinstance(value, str):
-            _var_data = VarData.merge(*_extract_var_data(value))
+            _var_data = VarData.merge(*_extract_var_data(value), _var_data)
 
         # Try to serialize the value.
         type_ = type(value)
@@ -388,7 +392,11 @@ class Var:
 
     @classmethod
     def create_safe(
-        cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
+        cls,
+        value: Any,
+        _var_is_local: bool = True,
+        _var_is_string: bool = False,
+        _var_data: Optional[VarData] = None,
     ) -> Var:
         """Create a var from a value, asserting that it is not None.
 
@@ -396,6 +404,7 @@ class Var:
             value: The value to create the var from.
             _var_is_local: Whether the var is local.
             _var_is_string: Whether the var is a string literal.
+            _var_data: Additional hooks and imports associated with the Var.
 
         Returns:
             The var.
@@ -404,6 +413,7 @@ class Var:
             value,
             _var_is_local=_var_is_local,
             _var_is_string=_var_is_string,
+            _var_data=_var_data,
         )
         assert var is not None
         return var

+ 6 - 6
reflex/vars.pyi

@@ -34,10 +34,10 @@ def _decode_var(value: str) -> tuple[VarData, str]: ...
 def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
 
 class VarData(Base):
-    state: str
-    imports: dict[str, set[ImportVar]]
-    hooks: Dict[str, None]
-    interpolations: List[Tuple[int, int]]
+    state: str = ""
+    imports: dict[str, List[ImportVar]] = {}
+    hooks: Dict[str, None] = {}
+    interpolations: List[Tuple[int, int]] = []
     @classmethod
     def merge(cls, *others: VarData | None) -> VarData | None: ...
 
@@ -50,11 +50,11 @@ class Var:
     _var_data: VarData | None = None
     @classmethod
     def create(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
+        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
     ) -> Optional[Var]: ...
     @classmethod
     def create_safe(
-        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
+        cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
     ) -> Var: ...
     @classmethod
     def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

+ 39 - 1
tests/components/test_component.py

@@ -1063,7 +1063,7 @@ def test_stateful_banner():
 TEST_VAR = Var.create_safe("test")._replace(
     merge_var_data=VarData(
         hooks={"useTest": None},
-        imports={"test": {ImportVar(tag="test")}},
+        imports={"test": [ImportVar(tag="test")]},
         state="Test",
         interpolations=[],
     )
@@ -1953,6 +1953,44 @@ def test_component_add_custom_code():
     }
 
 
+def test_component_add_hooks_var():
+    class HookComponent(Component):
+        def add_hooks(self):
+            return [
+                "const hook3 = useRef(null)",
+                "const hook1 = 42",
+                Var.create(
+                    "useEffect(() => () => {}, [])",
+                    _var_data=VarData(
+                        hooks={
+                            "const hook2 = 43": None,
+                            "const hook3 = useRef(null)": None,
+                        },
+                        imports={"react": [ImportVar(tag="useEffect")]},
+                    ),
+                ),
+                Var.create(
+                    "const hook3 = useRef(null)",
+                    _var_data=VarData(
+                        imports={"react": [ImportVar(tag="useRef")]},
+                    ),
+                ),
+            ]
+
+    assert list(HookComponent()._get_all_hooks()) == [
+        "const hook3 = useRef(null)",
+        "const hook1 = 42",
+        "const hook2 = 43",
+        "useEffect(() => () => {}, [])",
+    ]
+    imports = HookComponent()._get_all_imports()
+    assert len(imports) == 1
+    assert "react" in imports
+    assert len(imports["react"]) == 2
+    assert ImportVar(tag="useRef") in imports["react"]
+    assert ImportVar(tag="useEffect") in imports["react"]
+
+
 def test_add_style_embedded_vars(test_state: BaseState):
     """Test that add_style works with embedded vars when returning a plain dict.