浏览代码

Get default for backend var defined in mixin (#4060)

* Get default for backend var defined in mixin

If the backend var is defined in a mixin class, it won't appear in
`cls.__dict__`, but the value is still retrievable via `getattr` on `cls`.
Prefer to use the actual defined default before using
`Var.get_default_value()`.

If `Var.get_default_value()` fails, set the default to `None` such that the
backend var still gets recognized as a backend var when it is used on `self`.

----

Update test_component_state to include backend vars

Extra coverage for backend vars with and without defaults, defined in a
ComponentState/mixin class.

* fix integration test
Masen Furer 7 月之前
父节点
当前提交
08a493882d
共有 2 个文件被更改,包括 67 次插入3 次删除
  1. 22 2
      reflex/state.py
  2. 45 1
      tests/integration/test_component_state.py

+ 22 - 2
reflex/state.py

@@ -461,10 +461,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             for name, value in cls.__dict__.items()
             if types.is_backend_base_variable(name, cls)
         }
-        # Add annotated backend vars that do not have a default value.
+        # Add annotated backend vars that may not have a default value.
         new_backend_vars.update(
             {
-                name: Var("", _var_type=annotation_value).get_default_value()
+                name: cls._get_var_default(name, annotation_value)
                 for name, annotation_value in get_type_hints(cls).items()
                 if name not in new_backend_vars
                 and types.is_backend_base_variable(name, cls)
@@ -990,6 +990,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             # Ensure frontend uses null coalescing when accessing.
             object.__setattr__(prop, "_var_type", Optional[prop._var_type])
 
+    @classmethod
+    def _get_var_default(cls, name: str, annotation_value: Any) -> Any:
+        """Get the default value of a (backend) var.
+
+        Args:
+            name: The name of the var.
+            annotation_value: The annotation value of the var.
+
+        Returns:
+            The default value of the var or None.
+        """
+        try:
+            return getattr(cls, name)
+        except AttributeError:
+            try:
+                return Var("", _var_type=annotation_value).get_default_value()
+            except TypeError:
+                pass
+        return None
+
     @staticmethod
     def _get_base_functions() -> dict[str, FunctionType]:
         """Get all functions of the state class excluding dunder methods.

+ 45 - 1
tests/integration/test_component_state.py

@@ -5,6 +5,7 @@ from typing import Generator
 import pytest
 from selenium.webdriver.common.by import By
 
+from reflex.state import State, _substate_key
 from reflex.testing import AppHarness
 
 from . import utils
@@ -12,13 +13,21 @@ from . import utils
 
 def ComponentStateApp():
     """App using per component state."""
+    from typing import Generic, TypeVar
+
     import reflex as rx
 
-    class MultiCounter(rx.ComponentState):
+    E = TypeVar("E")
+
+    class MultiCounter(rx.ComponentState, Generic[E]):
         count: int = 0
+        _be: E
+        _be_int: int
+        _be_str: str = "42"
 
         def increment(self):
             self.count += 1
+            self._be = self.count  # type: ignore
 
         @classmethod
         def get_component(cls, *children, **props):
@@ -48,6 +57,14 @@ def ComponentStateApp():
                 on_click=mc_a.State.increment,  # type: ignore
                 id="inc-a",
             ),
+            rx.text(
+                mc_a.State.get_name() if mc_a.State is not None else "",
+                id="a_state_name",
+            ),
+            rx.text(
+                mc_b.State.get_name() if mc_b.State is not None else "",
+                id="b_state_name",
+            ),
         )
 
 
@@ -80,6 +97,7 @@ async def test_component_state_app(component_state_app: AppHarness):
 
     ss = utils.SessionStorage(driver)
     assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
+    root_state_token = _substate_key(ss.get("token"), State)
 
     count_a = driver.find_element(By.ID, "count-a")
     count_b = driver.find_element(By.ID, "count-b")
@@ -87,6 +105,18 @@ async def test_component_state_app(component_state_app: AppHarness):
     button_b = driver.find_element(By.ID, "button-b")
     button_inc_a = driver.find_element(By.ID, "inc-a")
 
+    # Check that backend vars in mixins are okay
+    a_state_name = driver.find_element(By.ID, "a_state_name").text
+    b_state_name = driver.find_element(By.ID, "b_state_name").text
+    root_state = await component_state_app.get_state(root_state_token)
+    a_state = root_state.substates[a_state_name]
+    b_state = root_state.substates[b_state_name]
+    assert a_state._backend_vars == a_state.backend_vars
+    assert a_state._backend_vars == b_state._backend_vars
+    assert a_state._backend_vars["_be"] is None
+    assert a_state._backend_vars["_be_int"] == 0
+    assert a_state._backend_vars["_be_str"] == "42"
+
     assert count_a.text == "0"
 
     button_a.click()
@@ -98,6 +128,14 @@ async def test_component_state_app(component_state_app: AppHarness):
     button_inc_a.click()
     assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
 
+    root_state = await component_state_app.get_state(root_state_token)
+    a_state = root_state.substates[a_state_name]
+    b_state = root_state.substates[b_state_name]
+    assert a_state._backend_vars != a_state.backend_vars
+    assert a_state._be == a_state._backend_vars["_be"] == 3
+    assert b_state._be is None
+    assert b_state._backend_vars["_be"] is None
+
     assert count_b.text == "0"
 
     button_b.click()
@@ -105,3 +143,9 @@ async def test_component_state_app(component_state_app: AppHarness):
 
     button_b.click()
     assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
+
+    root_state = await component_state_app.get_state(root_state_token)
+    a_state = root_state.substates[a_state_name]
+    b_state = root_state.substates[b_state_name]
+    assert b_state._backend_vars != b_state.backend_vars
+    assert b_state._be == b_state._backend_vars["_be"] == 2