瀏覽代碼

properly replace ComputedVars (#3254)

* properly replace ComputedVars

* provide dummy override decorator for older python versions

* adjust pyi

* fix darglint

* cleanup var_data for computed vars inherited from mixin

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
benedikt-bartscher 1 年之前
父節點
當前提交
48c504666e
共有 4 個文件被更改,包括 54 次插入7 次删除
  1. 3 1
      reflex/state.py
  2. 16 0
      reflex/utils/types.py
  3. 34 6
      reflex/vars.py
  4. 1 0
      reflex/vars.pyi

+ 3 - 1
reflex/state.py

@@ -550,7 +550,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             for name, value in mixin.__dict__.items():
                 if isinstance(value, ComputedVar):
                     fget = cls._copy_fn(value.fget)
-                    newcv = ComputedVar(fget=fget, _var_name=value._var_name)
+                    newcv = value._replace(fget=fget)
+                    # cleanup refs to mixin cls in var_data
+                    newcv._var_data = None
                     newcv._var_set_state(cls)
                     setattr(cls, name, newcv)
                     cls.computed_vars[newcv._var_name] = newcv

+ 16 - 0
reflex/utils/types.py

@@ -44,6 +44,22 @@ from reflex import constants
 from reflex.base import Base
 from reflex.utils import console, serializers
 
+if sys.version_info >= (3, 12):
+    from typing import override
+else:
+
+    def override(func: Callable) -> Callable:
+        """Fallback for @override decorator.
+
+        Args:
+            func: The function to decorate.
+
+        Returns:
+            The unmodified function.
+        """
+        return func
+
+
 # Potential GenericAlias types for isinstance checks.
 GenericAliasTypes = [_GenericAlias]
 

+ 34 - 6
reflex/vars.py

@@ -39,10 +39,12 @@ from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueErr
 
 # This module used to export ImportVar itself, so we still import it for export here
 from reflex.utils.imports import ImportDict, ImportVar
+from reflex.utils.types import override
 
 if TYPE_CHECKING:
     from reflex.state import BaseState
 
+
 # Set of unique variable names.
 USED_VARIABLES = set()
 
@@ -832,19 +834,19 @@ class Var:
                 if invoke_fn:
                     # invoke the function on left operand.
                     operation_name = (
-                        f"{left_operand_full_name}.{fn}({right_operand_full_name})"
-                    )  # type: ignore
+                        f"{left_operand_full_name}.{fn}({right_operand_full_name})"  # type: ignore
+                    )
                 else:
                     # pass the operands as arguments to the function.
                     operation_name = (
-                        f"{left_operand_full_name} {op} {right_operand_full_name}"
-                    )  # type: ignore
+                        f"{left_operand_full_name} {op} {right_operand_full_name}"  # type: ignore
+                    )
                     operation_name = f"{fn}({operation_name})"
             else:
                 # apply operator to operands (left operand <operator> right_operand)
                 operation_name = (
-                    f"{left_operand_full_name} {op} {right_operand_full_name}"
-                )  # type: ignore
+                    f"{left_operand_full_name} {op} {right_operand_full_name}"  # type: ignore
+                )
                 operation_name = format.wrap(operation_name, "(")
         else:
             # apply operator to left operand (<operator> left_operand)
@@ -1882,6 +1884,32 @@ class ComputedVar(Var, property):
         kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
         BaseVar.__init__(self, **kwargs)  # type: ignore
 
+    @override
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar:
+        """Replace the attributes of the ComputedVar.
+
+        Args:
+            merge_var_data: VarData to merge into the existing VarData.
+            **kwargs: Var fields to update.
+
+        Returns:
+            The new ComputedVar instance.
+        """
+        return ComputedVar(
+            fget=kwargs.get("fget", self.fget),
+            initial_value=kwargs.get("initial_value", self._initial_value),
+            cache=kwargs.get("cache", self._cache),
+            _var_name=kwargs.get("_var_name", self._var_name),
+            _var_type=kwargs.get("_var_type", self._var_type),
+            _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
+            _var_is_string=kwargs.get("_var_is_string", self._var_is_string),
+            _var_full_name_needs_state_prefix=kwargs.get(
+                "_var_full_name_needs_state_prefix",
+                self._var_full_name_needs_state_prefix,
+            ),
+            _var_data=VarData.merge(self._var_data, merge_var_data),
+        )
+
     @property
     def _cache_attr(self) -> str:
         """Get the attribute used to cache the value on the instance.

+ 1 - 0
reflex/vars.pyi

@@ -139,6 +139,7 @@ class ComputedVar(Var):
     def _cache_attr(self) -> str: ...
     def __get__(self, instance, owner): ...
     def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
     def mark_dirty(self, instance) -> None: ...
     def _determine_var_type(self) -> Type: ...
     @overload