浏览代码

improve backend var determination (#3587)

* cleanup reserved backend var check

* cleanup backend var determination mess

* add tests for hidden and class methods
benedikt-bartscher 11 月之前
父节点
当前提交
bcc7a61452
共有 4 个文件被更改,包括 48 次插入28 次删除
  1. 4 23
      reflex/state.py
  2. 26 0
      reflex/utils/types.py
  3. 8 4
      tests/test_state.py
  4. 10 1
      tests/utils/test_utils.py

+ 4 - 23
reflex/state.py

@@ -199,13 +199,6 @@ def _no_chain_background_task(
     raise TypeError(f"{fn} is marked as a background task, but is not async.")
     raise TypeError(f"{fn} is marked as a background task, but is not async.")
 
 
 
 
-RESERVED_BACKEND_VAR_NAMES = {
-    "_abc_impl",
-    "_backend_vars",
-    "_was_touched",
-}
-
-
 def _substate_key(
 def _substate_key(
     token: str,
     token: str,
     state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
     state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
@@ -500,10 +493,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             name: value
             name: value
             for name, value in cls.__dict__.items()
             for name, value in cls.__dict__.items()
             if types.is_backend_variable(name, cls)
             if types.is_backend_variable(name, cls)
-            and name not in RESERVED_BACKEND_VAR_NAMES
-            and name not in cls.inherited_backend_vars
-            and not isinstance(value, FunctionType)
-            and not isinstance(value, ComputedVar)
         }
         }
 
 
         # Get backend computed vars
         # Get backend computed vars
@@ -559,12 +548,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     cls.computed_vars[newcv._var_name] = newcv
                     cls.computed_vars[newcv._var_name] = newcv
                     cls.vars[newcv._var_name] = newcv
                     cls.vars[newcv._var_name] = newcv
                     continue
                     continue
-                if (
-                    types.is_backend_variable(name, cls)
-                    and name not in RESERVED_BACKEND_VAR_NAMES
-                    and name not in cls.inherited_backend_vars
-                    and not isinstance(value, FunctionType)
-                ):
+                if types.is_backend_variable(name, mixin):
                     cls.backend_vars[name] = copy.deepcopy(value)
                     cls.backend_vars[name] = copy.deepcopy(value)
                     continue
                     continue
                 if events.get(name) is not None:
                 if events.get(name) is not None:
@@ -750,7 +734,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 "dirty_substates",
                 "dirty_substates",
                 "router_data",
                 "router_data",
             }
             }
-            | RESERVED_BACKEND_VAR_NAMES
+            | types.RESERVED_BACKEND_VAR_NAMES
         )
         )
 
 
     @classmethod
     @classmethod
@@ -1103,10 +1087,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             setattr(self.parent_state, name, value)
             setattr(self.parent_state, name, value)
             return
             return
 
 
-        if (
-            types.is_backend_variable(name, self.__class__)
-            and name not in RESERVED_BACKEND_VAR_NAMES
-        ):
+        if types.is_backend_variable(name, type(self)):
             self._backend_vars.__setitem__(name, value)
             self._backend_vars.__setitem__(name, value)
             self.dirty_vars.add(name)
             self.dirty_vars.add(name)
             self._mark_dirty()
             self._mark_dirty()
@@ -1617,7 +1598,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         subdelta = {
         subdelta = {
             prop: getattr(self, prop)
             prop: getattr(self, prop)
             for prop in delta_vars
             for prop in delta_vars
-            if not types.is_backend_variable(prop, self.__class__)
+            if not types.is_backend_variable(prop, type(self))
         }
         }
         if len(subdelta) > 0:
         if len(subdelta) > 0:
             delta[self.get_full_name()] = subdelta
             delta[self.get_full_name()] = subdelta

+ 26 - 0
reflex/utils/types.py

@@ -100,6 +100,12 @@ PrimitiveToAnnotation = {
     dict: Dict,
     dict: Dict,
 }
 }
 
 
+RESERVED_BACKEND_VAR_NAMES = {
+    "_abc_impl",
+    "_backend_vars",
+    "_was_touched",
+}
+
 
 
 class Unset:
 class Unset:
     """A class to represent an unset value.
     """A class to represent an unset value.
@@ -414,6 +420,9 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
     Returns:
     Returns:
         bool: The result of the check
         bool: The result of the check
     """
     """
+    if name in RESERVED_BACKEND_VAR_NAMES:
+        return False
+
     if not name.startswith("_"):
     if not name.startswith("_"):
         return False
         return False
 
 
@@ -429,6 +438,23 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
             if hint == ClassVar:
             if hint == ClassVar:
                 return False
                 return False
 
 
+        if name in cls.inherited_backend_vars:
+            return False
+
+        if name in cls.__dict__:
+            value = cls.__dict__[name]
+            if type(value) == classmethod:
+                return False
+            if callable(value):
+                return False
+            if isinstance(value, types.FunctionType):
+                return False
+            # enable after #3573 is merged
+            # from reflex.vars import ComputedVar
+            #
+            # if isinstance(value, ComputedVar):
+            #     return False
+
     return True
     return True
 
 
 
 

+ 8 - 4
tests/test_state.py

@@ -965,7 +965,9 @@ def interdependent_state() -> BaseState:
     return s
     return s
 
 
 
 
-def test_not_dirty_computed_var_from_var(interdependent_state):
+def test_not_dirty_computed_var_from_var(
+    interdependent_state: InterdependentState,
+) -> None:
     """Set Var that no ComputedVar depends on, expect no recalculation.
     """Set Var that no ComputedVar depends on, expect no recalculation.
 
 
     Args:
     Args:
@@ -977,7 +979,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state):
     }
     }
 
 
 
 
-def test_dirty_computed_var_from_var(interdependent_state):
+def test_dirty_computed_var_from_var(interdependent_state: InterdependentState) -> None:
     """Set Var that ComputedVar depends on, expect recalculation.
     """Set Var that ComputedVar depends on, expect recalculation.
 
 
     The other ComputedVar depends on the changed ComputedVar and should also be
     The other ComputedVar depends on the changed ComputedVar and should also be
@@ -992,7 +994,9 @@ def test_dirty_computed_var_from_var(interdependent_state):
     }
     }
 
 
 
 
-def test_dirty_computed_var_from_backend_var(interdependent_state):
+def test_dirty_computed_var_from_backend_var(
+    interdependent_state: InterdependentState,
+) -> None:
     """Set backend var that ComputedVar depends on, expect recalculation.
     """Set backend var that ComputedVar depends on, expect recalculation.
 
 
     Args:
     Args:
@@ -1005,7 +1009,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
     assert "_v3" in InterdependentState.backend_vars
     assert "_v3" in InterdependentState.backend_vars
 
 
 
 
-def test_per_state_backend_var(interdependent_state):
+def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
     """Set backend var on one instance, expect no affect in other instances.
     """Set backend var on one instance, expect no affect in other instances.
 
 
     Args:
     Args:

+ 10 - 1
tests/utils/test_utils.py

@@ -146,7 +146,7 @@ def test_setup_frontend(tmp_path, mocker):
 
 
 @pytest.fixture
 @pytest.fixture
 def test_backend_variable_cls():
 def test_backend_variable_cls():
-    class TestBackendVariable:
+    class TestBackendVariable(BaseState):
         """Test backend variable."""
         """Test backend variable."""
 
 
         _classvar: ClassVar[int] = 0
         _classvar: ClassVar[int] = 0
@@ -154,6 +154,13 @@ def test_backend_variable_cls():
         not_hidden: int = 0
         not_hidden: int = 0
         __dunderattr__: int = 0
         __dunderattr__: int = 0
 
 
+        @classmethod
+        def _class_method(cls):
+            pass
+
+        def _hidden_method(self):
+            pass
+
     return TestBackendVariable
     return TestBackendVariable
 
 
 
 
@@ -161,6 +168,8 @@ def test_backend_variable_cls():
     "input, output",
     "input, output",
     [
     [
         ("_classvar", False),
         ("_classvar", False),
+        ("_class_method", False),
+        ("_hidden_method", False),
         ("_hidden", True),
         ("_hidden", True),
         ("not_hidden", False),
         ("not_hidden", False),
         ("__dundermethod__", False),
         ("__dundermethod__", False),