浏览代码

improve backend var determination (#3587)

* cleanup reserved backend var check

* cleanup backend var determination mess

* add tests for hidden and class methods
benedikt-bartscher 11 月之前
父节点
当前提交
bcc7a61452
共有 4 个文件被更改,包括 48 次插入28 次删除
  1. 4 23
      reflex/state.py
  2. 26 0
      reflex/utils/types.py
  3. 8 4
      tests/test_state.py
  4. 10 1
      tests/utils/test_utils.py

+ 4 - 23
reflex/state.py

@@ -199,13 +199,6 @@ def _no_chain_background_task(
     raise TypeError(f"{fn} is marked as a background task, but is not async.")
 
 
-RESERVED_BACKEND_VAR_NAMES = {
-    "_abc_impl",
-    "_backend_vars",
-    "_was_touched",
-}
-
-
 def _substate_key(
     token: str,
     state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
@@ -500,10 +493,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             name: value
             for name, value in cls.__dict__.items()
             if types.is_backend_variable(name, cls)
-            and name not in RESERVED_BACKEND_VAR_NAMES
-            and name not in cls.inherited_backend_vars
-            and not isinstance(value, FunctionType)
-            and not isinstance(value, ComputedVar)
         }
 
         # Get backend computed vars
@@ -559,12 +548,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     cls.computed_vars[newcv._var_name] = newcv
                     cls.vars[newcv._var_name] = newcv
                     continue
-                if (
-                    types.is_backend_variable(name, cls)
-                    and name not in RESERVED_BACKEND_VAR_NAMES
-                    and name not in cls.inherited_backend_vars
-                    and not isinstance(value, FunctionType)
-                ):
+                if types.is_backend_variable(name, mixin):
                     cls.backend_vars[name] = copy.deepcopy(value)
                     continue
                 if events.get(name) is not None:
@@ -750,7 +734,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 "dirty_substates",
                 "router_data",
             }
-            | RESERVED_BACKEND_VAR_NAMES
+            | types.RESERVED_BACKEND_VAR_NAMES
         )
 
     @classmethod
@@ -1103,10 +1087,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             setattr(self.parent_state, name, value)
             return
 
-        if (
-            types.is_backend_variable(name, self.__class__)
-            and name not in RESERVED_BACKEND_VAR_NAMES
-        ):
+        if types.is_backend_variable(name, type(self)):
             self._backend_vars.__setitem__(name, value)
             self.dirty_vars.add(name)
             self._mark_dirty()
@@ -1617,7 +1598,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         subdelta = {
             prop: getattr(self, prop)
             for prop in delta_vars
-            if not types.is_backend_variable(prop, self.__class__)
+            if not types.is_backend_variable(prop, type(self))
         }
         if len(subdelta) > 0:
             delta[self.get_full_name()] = subdelta

+ 26 - 0
reflex/utils/types.py

@@ -100,6 +100,12 @@ PrimitiveToAnnotation = {
     dict: Dict,
 }
 
+RESERVED_BACKEND_VAR_NAMES = {
+    "_abc_impl",
+    "_backend_vars",
+    "_was_touched",
+}
+
 
 class Unset:
     """A class to represent an unset value.
@@ -414,6 +420,9 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
     Returns:
         bool: The result of the check
     """
+    if name in RESERVED_BACKEND_VAR_NAMES:
+        return False
+
     if not name.startswith("_"):
         return False
 
@@ -429,6 +438,23 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
             if hint == ClassVar:
                 return False
 
+        if name in cls.inherited_backend_vars:
+            return False
+
+        if name in cls.__dict__:
+            value = cls.__dict__[name]
+            if type(value) == classmethod:
+                return False
+            if callable(value):
+                return False
+            if isinstance(value, types.FunctionType):
+                return False
+            # enable after #3573 is merged
+            # from reflex.vars import ComputedVar
+            #
+            # if isinstance(value, ComputedVar):
+            #     return False
+
     return True
 
 

+ 8 - 4
tests/test_state.py

@@ -965,7 +965,9 @@ def interdependent_state() -> BaseState:
     return s
 
 
-def test_not_dirty_computed_var_from_var(interdependent_state):
+def test_not_dirty_computed_var_from_var(
+    interdependent_state: InterdependentState,
+) -> None:
     """Set Var that no ComputedVar depends on, expect no recalculation.
 
     Args:
@@ -977,7 +979,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state):
     }
 
 
-def test_dirty_computed_var_from_var(interdependent_state):
+def test_dirty_computed_var_from_var(interdependent_state: InterdependentState) -> None:
     """Set Var that ComputedVar depends on, expect recalculation.
 
     The other ComputedVar depends on the changed ComputedVar and should also be
@@ -992,7 +994,9 @@ def test_dirty_computed_var_from_var(interdependent_state):
     }
 
 
-def test_dirty_computed_var_from_backend_var(interdependent_state):
+def test_dirty_computed_var_from_backend_var(
+    interdependent_state: InterdependentState,
+) -> None:
     """Set backend var that ComputedVar depends on, expect recalculation.
 
     Args:
@@ -1005,7 +1009,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
     assert "_v3" in InterdependentState.backend_vars
 
 
-def test_per_state_backend_var(interdependent_state):
+def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
     """Set backend var on one instance, expect no affect in other instances.
 
     Args:

+ 10 - 1
tests/utils/test_utils.py

@@ -146,7 +146,7 @@ def test_setup_frontend(tmp_path, mocker):
 
 @pytest.fixture
 def test_backend_variable_cls():
-    class TestBackendVariable:
+    class TestBackendVariable(BaseState):
         """Test backend variable."""
 
         _classvar: ClassVar[int] = 0
@@ -154,6 +154,13 @@ def test_backend_variable_cls():
         not_hidden: int = 0
         __dunderattr__: int = 0
 
+        @classmethod
+        def _class_method(cls):
+            pass
+
+        def _hidden_method(self):
+            pass
+
     return TestBackendVariable
 
 
@@ -161,6 +168,8 @@ def test_backend_variable_cls():
     "input, output",
     [
         ("_classvar", False),
+        ("_class_method", False),
+        ("_hidden_method", False),
         ("_hidden", True),
         ("not_hidden", False),
         ("__dundermethod__", False),