|
@@ -241,7 +241,7 @@ class Component(BaseComponent, ABC):
|
|
"""
|
|
"""
|
|
return {}
|
|
return {}
|
|
|
|
|
|
- def add_hooks(self) -> list[str]:
|
|
|
|
|
|
+ def add_hooks(self) -> list[str | Var]:
|
|
"""Add hooks inside the component function.
|
|
"""Add hooks inside the component function.
|
|
|
|
|
|
Hooks are pieces of literal Javascript code that is inserted inside the
|
|
Hooks are pieces of literal Javascript code that is inserted inside the
|
|
@@ -1265,11 +1265,20 @@ class Component(BaseComponent, ABC):
|
|
},
|
|
},
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ other_imports = []
|
|
user_hooks = self._get_hooks()
|
|
user_hooks = self._get_hooks()
|
|
- if user_hooks is not None and isinstance(user_hooks, Var):
|
|
|
|
- _imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore
|
|
|
|
|
|
+ if (
|
|
|
|
+ user_hooks is not None
|
|
|
|
+ and isinstance(user_hooks, Var)
|
|
|
|
+ and user_hooks._var_data is not None
|
|
|
|
+ and user_hooks._var_data.imports
|
|
|
|
+ ):
|
|
|
|
+ other_imports.append(user_hooks._var_data.imports)
|
|
|
|
+ other_imports.extend(
|
|
|
|
+ hook_imports for hook_imports in self._get_added_hooks().values()
|
|
|
|
+ )
|
|
|
|
|
|
- return _imports
|
|
|
|
|
|
+ return imports.merge_imports(_imports, *other_imports)
|
|
|
|
|
|
def _get_imports(self) -> imports.ImportDict:
|
|
def _get_imports(self) -> imports.ImportDict:
|
|
"""Get all the libraries and fields that are used by the component.
|
|
"""Get all the libraries and fields that are used by the component.
|
|
@@ -1416,6 +1425,36 @@ class Component(BaseComponent, ABC):
|
|
**self._get_special_hooks(),
|
|
**self._get_special_hooks(),
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
|
|
|
|
+ """Get the hooks added via `add_hooks` method.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The deduplicated hooks and imports added by the component and parent components.
|
|
|
|
+ """
|
|
|
|
+ code = {}
|
|
|
|
+
|
|
|
|
+ def extract_var_hooks(hook: Var):
|
|
|
|
+ _imports = {}
|
|
|
|
+ if hook._var_data is not None:
|
|
|
|
+ for sub_hook in hook._var_data.hooks:
|
|
|
|
+ code[sub_hook] = {}
|
|
|
|
+ if hook._var_data.imports:
|
|
|
|
+ _imports = hook._var_data.imports
|
|
|
|
+ if str(hook) in code:
|
|
|
|
+ code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
|
|
|
|
+ else:
|
|
|
|
+ code[str(hook)] = _imports
|
|
|
|
+
|
|
|
|
+ # Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
|
|
|
+ # the order of the hooks in the final output)
|
|
|
|
+ for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
|
|
|
|
+ for hook in clz.add_hooks(self):
|
|
|
|
+ if isinstance(hook, Var):
|
|
|
|
+ extract_var_hooks(hook)
|
|
|
|
+ else:
|
|
|
|
+ code[hook] = {}
|
|
|
|
+ return code
|
|
|
|
+
|
|
def _get_hooks(self) -> str | None:
|
|
def _get_hooks(self) -> str | None:
|
|
"""Get the React hooks for this component.
|
|
"""Get the React hooks for this component.
|
|
|
|
|
|
@@ -1454,11 +1493,7 @@ class Component(BaseComponent, ABC):
|
|
if hooks is not None:
|
|
if hooks is not None:
|
|
code[hooks] = None
|
|
code[hooks] = None
|
|
|
|
|
|
- # Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
|
|
|
- # the order of the hooks in the final output)
|
|
|
|
- for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
|
|
|
|
- for hook in clz.add_hooks(self):
|
|
|
|
- code[hook] = None
|
|
|
|
|
|
+ code.update(self._get_added_hooks())
|
|
|
|
|
|
# Add the hook code for the children.
|
|
# Add the hook code for the children.
|
|
for child in self.children:
|
|
for child in self.children:
|
|
@@ -2092,8 +2127,8 @@ class StatefulComponent(BaseComponent):
|
|
var_deps.extend(cls._get_hook_deps(hook))
|
|
var_deps.extend(cls._get_hook_deps(hook))
|
|
memo_var_data = VarData.merge(
|
|
memo_var_data = VarData.merge(
|
|
*[var._var_data for var in event_args],
|
|
*[var._var_data for var in event_args],
|
|
- VarData( # type: ignore
|
|
|
|
- imports={"react": {ImportVar(tag="useCallback")}},
|
|
|
|
|
|
+ VarData(
|
|
|
|
+ imports={"react": [ImportVar(tag="useCallback")]},
|
|
),
|
|
),
|
|
)
|
|
)
|
|
|
|
|