瀏覽代碼

fix various places that got the wrong name

Khaleel Al-Adhami 3 周之前
父節點
當前提交
9473762f0d
共有 6 個文件被更改,包括 46 次插入74 次删除
  1. 4 4
      reflex/base.py
  2. 24 19
      reflex/state.py
  3. 14 26
      reflex/vars/base.py
  4. 1 0
      tests/integration/test_background_task.py
  5. 3 1
      tests/units/test_state.py
  6. 0 24
      tests/units/test_var.py

+ 4 - 4
reflex/base.py

@@ -97,24 +97,24 @@ class Base(BaseModel):
         return cls.__fields__
         return cls.__fields__
 
 
     @classmethod
     @classmethod
-    def add_field(cls, var: Var, default_value: Any):
+    def add_field(cls, name: str, var: Var, default_value: Any):
         """Add a pydantic field after class definition.
         """Add a pydantic field after class definition.
 
 
         Used by State.add_var() to correctly handle the new variable.
         Used by State.add_var() to correctly handle the new variable.
 
 
         Args:
         Args:
+            name: The name of the field.
             var: The variable to add a pydantic field for.
             var: The variable to add a pydantic field for.
             default_value: The default value of the field
             default_value: The default value of the field
         """
         """
-        var_name = var._var_field_name
         new_field = ModelField.infer(
         new_field = ModelField.infer(
-            name=var_name,
+            name=name,
             value=default_value,
             value=default_value,
             annotation=var._var_type,
             annotation=var._var_type,
             class_validators=None,
             class_validators=None,
             config=cls.__config__,
             config=cls.__config__,
         )
         )
-        cls.__fields__.update({var_name: new_field})
+        cls.__fields__.update({name: new_field})
 
 
     def get_value(self, key: str) -> Any:
     def get_value(self, key: str) -> Any:
         """Get the value of a field.
         """Get the value of a field.

+ 24 - 19
reflex/state.py

@@ -283,9 +283,8 @@ def get_var_for_field(cls: type[BaseState], f: ModelField) -> Var:
     Returns:
     Returns:
         The Var instance.
         The Var instance.
     """
     """
-    field_name = (
-        format.format_state_name(cls.get_full_name()) + "." + f.name + "_rx_state_"
-    )
+    name = f.name + "_rx_state_"
+    field_name = format.format_state_name(cls.get_full_name()) + "." + name
 
 
     return dispatch(
     return dispatch(
         field_name=field_name,
         field_name=field_name,
@@ -566,8 +565,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         cls.event_handlers = {}
         cls.event_handlers = {}
 
 
         # Setup the base vars at the class level.
         # Setup the base vars at the class level.
-        for prop in cls.base_vars.values():
-            cls._init_var(prop)
+        for name, prop in cls.base_vars.items():
+            cls._init_var(name, prop)
 
 
         # Set up the event handlers.
         # Set up the event handlers.
         events = {
         events = {
@@ -1006,10 +1005,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return getattr(substate, name)
         return getattr(substate, name)
 
 
     @classmethod
     @classmethod
-    def _init_var(cls, prop: Var):
+    def _init_var(cls, name: str, prop: Var):
         """Initialize a variable.
         """Initialize a variable.
 
 
         Args:
         Args:
+            name: The name of the variable
             prop: The variable to initialize
             prop: The variable to initialize
 
 
         Raises:
         Raises:
@@ -1024,9 +1024,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 "dictionaries, dataclasses, datetime objects, and pydantic models. "
                 "dictionaries, dataclasses, datetime objects, and pydantic models. "
                 f'Found var "{prop._js_expr}" with type {prop._var_type}.'
                 f'Found var "{prop._js_expr}" with type {prop._var_type}.'
             )
             )
-        cls._set_var(prop)
-        cls._create_setter(prop)
-        cls._set_default_value(prop)
+        cls._set_var(name, prop)
+        cls._create_setter(name, prop)
+        cls._set_default_value(name, prop)
 
 
     @classmethod
     @classmethod
     def add_var(cls, name: str, type_: Any, default_value: Any = None):
     def add_var(cls, name: str, type_: Any, default_value: Any = None):
@@ -1059,9 +1059,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         ).guess_type()
         ).guess_type()
 
 
         # add the pydantic field dynamically (must be done before _init_var)
         # add the pydantic field dynamically (must be done before _init_var)
-        cls.add_field(var, default_value)
+        cls.add_field(name, var, default_value)
 
 
-        cls._init_var(var)
+        cls._init_var(name, var)
 
 
         # update the internal dicts so the new variable is correctly handled
         # update the internal dicts so the new variable is correctly handled
         cls.base_vars.update({name: var})
         cls.base_vars.update({name: var})
@@ -1075,13 +1075,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         cls._init_var_dependency_dicts()
         cls._init_var_dependency_dicts()
 
 
     @classmethod
     @classmethod
-    def _set_var(cls, prop: Var):
+    def _set_var(cls, name: str, prop: Var):
         """Set the var as a class member.
         """Set the var as a class member.
 
 
         Args:
         Args:
+            name: The name of the var.
             prop: The var instance to set.
             prop: The var instance to set.
         """
         """
-        setattr(cls, prop._var_field_name, prop)
+        setattr(cls, name, prop)
 
 
     @classmethod
     @classmethod
     def _create_event_handler(cls, fn: Any):
     def _create_event_handler(cls, fn: Any):
@@ -1101,27 +1102,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls)
         cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls)
 
 
     @classmethod
     @classmethod
-    def _create_setter(cls, prop: Var):
+    def _create_setter(cls, name: str, prop: Var):
         """Create a setter for the var.
         """Create a setter for the var.
 
 
         Args:
         Args:
+            name: The name of the var.
             prop: The var to create a setter for.
             prop: The var to create a setter for.
         """
         """
-        setter_name = prop._get_setter_name(include_state=False)
+        setter_name = Var._get_setter_name_for_name(name)
         if setter_name not in cls.__dict__:
         if setter_name not in cls.__dict__:
-            event_handler = cls._create_event_handler(prop._get_setter())
+            event_handler = cls._create_event_handler(prop._get_setter(name))
             cls.event_handlers[setter_name] = event_handler
             cls.event_handlers[setter_name] = event_handler
             setattr(cls, setter_name, event_handler)
             setattr(cls, setter_name, event_handler)
 
 
     @classmethod
     @classmethod
-    def _set_default_value(cls, prop: Var):
+    def _set_default_value(cls, name: str, prop: Var):
         """Set the default value for the var.
         """Set the default value for the var.
 
 
         Args:
         Args:
+            name: The name of the var.
             prop: The var to set the default value for.
             prop: The var to set the default value for.
         """
         """
         # Get the pydantic field for the var.
         # Get the pydantic field for the var.
-        field = cls.get_fields()[prop._var_field_name]
+        field = cls.get_fields()[name]
         if field.required:
         if field.required:
             default_value = prop._get_default_value()
             default_value = prop._get_default_value()
             if default_value is not None:
             if default_value is not None:
@@ -2075,7 +2078,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             computed_vars = {}
             computed_vars = {}
         variables = {**base_vars, **computed_vars}
         variables = {**base_vars, **computed_vars}
         d = {
         d = {
-            self.get_full_name(): {k: variables[k] for k in sorted(variables)},
+            self.get_full_name(): {
+                k + "_rx_state_": variables[k] for k in sorted(variables)
+            },
         }
         }
         for substate_d in [
         for substate_d in [
             v.dict(include_computed=include_computed, initial=initial, **kwargs)
             v.dict(include_computed=include_computed, initial=initial, **kwargs)

+ 14 - 26
reflex/vars/base.py

@@ -406,17 +406,6 @@ class Var(Generic[VAR_TYPE]):
         """
         """
         return self._js_expr
         return self._js_expr
 
 
-    @property
-    def _var_field_name(self) -> str:
-        """The name of the field.
-
-        Returns:
-            The name of the field.
-        """
-        var_data = self._get_all_var_data()
-        field_name = var_data.field_name if var_data else None
-        return field_name or self._js_expr
-
     @property
     @property
     @deprecated("Use `_js_expr` instead.")
     @deprecated("Use `_js_expr` instead.")
     def _var_name_unwrapped(self) -> str:
     def _var_name_unwrapped(self) -> str:
@@ -976,30 +965,29 @@ class Var(Generic[VAR_TYPE]):
                 ) from e
                 ) from e
         return set() if safe_issubclass(type_, set) else None
         return set() if safe_issubclass(type_, set) else None
 
 
-    def _get_setter_name(self, include_state: bool = True) -> str:
+    @staticmethod
+    def _get_setter_name_for_name(
+        name: str,
+    ) -> str:
         """Get the name of the var's generated setter function.
         """Get the name of the var's generated setter function.
 
 
         Args:
         Args:
-            include_state: Whether to include the state name in the setter name.
+            name: The name of the var.
 
 
         Returns:
         Returns:
             The name of the setter function.
             The name of the setter function.
         """
         """
-        setter = constants.SETTER_PREFIX + self._var_field_name
-        var_data = self._get_all_var_data()
-        if var_data is None:
-            return setter
-        if not include_state or var_data.state == "":
-            return setter
-        return ".".join((var_data.state, setter))
+        return constants.SETTER_PREFIX + name
 
 
-    def _get_setter(self) -> Callable[[BaseState, Any], None]:
+    def _get_setter(self, name: str) -> Callable[[BaseState, Any], None]:
         """Get the var's setter function.
         """Get the var's setter function.
 
 
+        Args:
+            name: The name of the var.
+
         Returns:
         Returns:
             A function that that creates a setter for the var.
             A function that that creates a setter for the var.
         """
         """
-        actual_name = self._var_field_name
 
 
         def setter(state: Any, value: Any):
         def setter(state: Any, value: Any):
             """Get the setter for the var.
             """Get the setter for the var.
@@ -1011,17 +999,17 @@ class Var(Generic[VAR_TYPE]):
             if self._var_type in [int, float]:
             if self._var_type in [int, float]:
                 try:
                 try:
                     value = self._var_type(value)
                     value = self._var_type(value)
-                    setattr(state, actual_name, value)
+                    setattr(state, name, value)
                 except ValueError:
                 except ValueError:
                     console.debug(
                     console.debug(
                         f"{type(state).__name__}.{self._js_expr}: Failed conversion of {value!s} to '{self._var_type.__name__}'. Value not set.",
                         f"{type(state).__name__}.{self._js_expr}: Failed conversion of {value!s} to '{self._var_type.__name__}'. Value not set.",
                     )
                     )
             else:
             else:
-                setattr(state, actual_name, value)
+                setattr(state, name, value)
 
 
         setter.__annotations__["value"] = self._var_type
         setter.__annotations__["value"] = self._var_type
 
 
-        setter.__qualname__ = self._get_setter_name()
+        setter.__qualname__ = Var._get_setter_name_for_name(name)
 
 
         return setter
         return setter
 
 
@@ -2132,7 +2120,7 @@ class ComputedVar(Var[RETURN_TYPE]):
 
 
         if hint is Any:
         if hint is Any:
             raise UntypedComputedVarError(var_name=fget.__name__)
             raise UntypedComputedVarError(var_name=fget.__name__)
-        kwargs.setdefault("_js_expr", fget.__name__)
+        kwargs.setdefault("_js_expr", fget.__name__ + "_rx_state_")
         kwargs.setdefault("_var_type", hint)
         kwargs.setdefault("_var_type", hint)
 
 
         Var.__init__(
         Var.__init__(

+ 1 - 0
tests/integration/test_background_task.py

@@ -241,6 +241,7 @@ def token(background_task: AppHarness, driver: WebDriver) -> str:
         The token for the connected client
         The token for the connected client
     """
     """
     assert background_task.app_instance is not None
     assert background_task.app_instance is not None
+
     token_input = driver.find_element(By.ID, "token")
     token_input = driver.find_element(By.ID, "token")
     assert token_input
     assert token_input
 
 

+ 3 - 1
tests/units/test_state.py

@@ -552,7 +552,9 @@ def test_set_class_var():
     """Test setting the var of a class."""
     """Test setting the var of a class."""
     with pytest.raises(AttributeError):
     with pytest.raises(AttributeError):
         TestState.num3  # pyright: ignore [reportAttributeAccessIssue]
         TestState.num3  # pyright: ignore [reportAttributeAccessIssue]
-    TestState._set_var(Var(_js_expr="num3", _var_type=int)._var_set_state(TestState))
+    TestState._set_var(
+        "num3", Var(_js_expr="num3", _var_type=int)._var_set_state(TestState)
+    )
     var = TestState.num3  # pyright: ignore [reportAttributeAccessIssue]
     var = TestState.num3  # pyright: ignore [reportAttributeAccessIssue]
     assert var._js_expr == TestState.get_full_name() + ".num3"
     assert var._js_expr == TestState.get_full_name() + ".num3"
     assert var._var_type is int
     assert var._var_type is int

+ 0 - 24
tests/units/test_var.py

@@ -244,30 +244,6 @@ def test_default_value(prop: Var, expected):
     assert prop._get_default_value() == expected
     assert prop._get_default_value() == expected
 
 
 
 
-@pytest.mark.parametrize(
-    "prop,expected",
-    zip(
-        test_vars,
-        [
-            "set_prop1",
-            "set_key",
-            "state.set_value",
-            "state.set_local",
-            "set_local2",
-        ],
-        strict=True,
-    ),
-)
-def test_get_setter(prop: Var, expected):
-    """Test that the name of the setter function of a var is correct.
-
-    Args:
-        prop: The var to test.
-        expected: The expected name of the setter function.
-    """
-    assert prop._get_setter_name() == expected
-
-
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "value,expected",
     "value,expected",
     [
     [