Bläddra i källkod

Dynamic route vars silently shadow all other vars (#3805)

* fix dynamic route vars silently shadow all other vars

* add test

* fix: allow multiple dynamic routes with the same arg

* add test for multiple dynamic args with the same name

* avoid side-effects with DynamicState tests

* fix dynamic route integration test which shadowed a var

* fix darglint

* refactor to DynamicRouteVar

* old typing stuff again

* from typing_extensions import Self

try to keep typing backward compatible with older releases we support

* Raise a specific exception when encountering dynamic route arg shadowing

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
benedikt-bartscher 8 månader sedan
förälder
incheckning
63bf1b8817
5 ändrade filer med 110 tillägg och 11 borttagningar
  1. 9 3
      reflex/ivars/base.py
  2. 36 5
      reflex/state.py
  3. 4 0
      reflex/utils/exceptions.py
  4. 5 0
      reflex/utils/types.py
  5. 56 3
      tests/test_app.py

+ 9 - 3
reflex/ivars/base.py

@@ -43,7 +43,7 @@ from reflex.utils.exceptions import (
     VarValueError,
 )
 from reflex.utils.format import format_state_name
-from reflex.utils.types import GenericType, get_origin
+from reflex.utils.types import GenericType, Self, get_origin
 from reflex.vars import (
     REPLACED_NAMES,
     Var,
@@ -1467,7 +1467,7 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]):
         object.__setattr__(self, "_fget", fget)
 
     @override
-    def _replace(self, merge_var_data=None, **kwargs: Any) -> ImmutableComputedVar:
+    def _replace(self, merge_var_data=None, **kwargs: Any) -> Self:
         """Replace the attributes of the ComputedVar.
 
         Args:
@@ -1499,7 +1499,7 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]):
             unexpected_kwargs = ", ".join(kwargs.keys())
             raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
 
-        return ImmutableComputedVar(**field_values)
+        return type(self)(**field_values)
 
     @property
     def _cache_attr(self) -> str:
@@ -1773,6 +1773,12 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]):
         return self._fget
 
 
+class DynamicRouteVar(ImmutableComputedVar[Union[str, List[str]]]):
+    """A ComputedVar that represents a dynamic route."""
+
+    pass
+
+
 if TYPE_CHECKING:
     BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
 

+ 36 - 5
reflex/state.py

@@ -35,6 +35,7 @@ from sqlalchemy.orm import DeclarativeBase
 
 from reflex.config import get_config
 from reflex.ivars.base import (
+    DynamicRouteVar,
     ImmutableComputedVar,
     ImmutableVar,
     immutable_computed_var,
@@ -60,7 +61,11 @@ from reflex.event import (
     fix_events,
 )
 from reflex.utils import console, format, path_ops, prerequisites, types
-from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
+from reflex.utils.exceptions import (
+    DynamicRouteArgShadowsStateVar,
+    ImmutableStateError,
+    LockExpiredError,
+)
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import SerializedType, serialize, serializer
 from reflex.utils.types import override
@@ -1023,17 +1028,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if not args:
             return
 
+        cls._check_overwritten_dynamic_args(list(args.keys()))
+
         def argsingle_factory(param):
             def inner_func(self) -> str:
                 return self.router.page.params.get(param, "")
 
-            return ImmutableComputedVar(fget=inner_func, cache=True)
+            return DynamicRouteVar(fget=inner_func, cache=True)
 
         def arglist_factory(param):
-            def inner_func(self) -> List:
+            def inner_func(self) -> List[str]:
                 return self.router.page.params.get(param, [])
 
-            return ImmutableComputedVar(fget=inner_func, cache=True)
+            return DynamicRouteVar(fget=inner_func, cache=True)
 
         for param, value in args.items():
             if value == constants.RouteArgType.SINGLE:
@@ -1044,12 +1051,36 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 continue
             # to allow passing as a prop, evade python frozen rules (bad practice)
             object.__setattr__(func, "_var_name", param)
-            cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls)  # type: ignore
+            # cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls)  # type: ignore
+            cls.vars[param] = cls.computed_vars[param] = func._replace(
+                _var_data=VarData.from_state(cls)
+            )
             setattr(cls, param, func)
 
         # Reinitialize dependency tracking dicts.
         cls._init_var_dependency_dicts()
 
+    @classmethod
+    def _check_overwritten_dynamic_args(cls, args: list[str]):
+        """Check if dynamic args are shadowing existing vars. Recursively checks all child states.
+
+        Args:
+            args: a dict of args
+
+        Raises:
+            DynamicRouteArgShadowsStateVar: If a dynamic arg is shadowing an existing var.
+        """
+        for arg in args:
+            if (
+                arg in cls.computed_vars
+                and not isinstance(cls.computed_vars[arg], DynamicRouteVar)
+            ) or arg in cls.base_vars:
+                raise DynamicRouteArgShadowsStateVar(
+                    f"Dynamic route arg '{arg}' is shadowing an existing var in {cls.__module__}.{cls.__name__}"
+                )
+        for substate in cls.get_substates():
+            substate._check_overwritten_dynamic_args(args)
+
     def __getattribute__(self, name: str) -> Any:
         """Get the state var.
 

+ 4 - 0
reflex/utils/exceptions.py

@@ -87,3 +87,7 @@ class EventHandlerArgMismatch(ReflexError, TypeError):
 
 class EventFnArgMismatch(ReflexError, TypeError):
     """Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""
+
+
+class DynamicRouteArgShadowsStateVar(ReflexError, NameError):
+    """Raised when a dynamic route arg shadows a state var."""

+ 5 - 0
reflex/utils/types.py

@@ -111,6 +111,11 @@ RESERVED_BACKEND_VAR_NAMES = {
     "_was_touched",
 }
 
+if sys.version_info >= (3, 11):
+    from typing import Self as Self
+else:
+    from typing_extensions import Self as Self
+
 
 class Unset:
     """A class to represent an unset value.

+ 56 - 3
tests/test_app.py

@@ -69,7 +69,7 @@ class EmptyState(BaseState):
 
 
 @pytest.fixture
-def index_page():
+def index_page() -> ComponentCallable:
     """An index page.
 
     Returns:
@@ -83,7 +83,7 @@ def index_page():
 
 
 @pytest.fixture
-def about_page():
+def about_page() -> ComponentCallable:
     """An about page.
 
     Returns:
@@ -919,9 +919,62 @@ class DynamicState(BaseState):
     on_load_internal = OnLoadInternalState.on_load_internal.fn
 
 
+def test_dynamic_arg_shadow(
+    index_page: ComponentCallable,
+    windows_platform: bool,
+    token: str,
+    app_module_mock: unittest.mock.Mock,
+    mocker,
+):
+    """Create app with dynamic route var and try to add a page with a dynamic arg that shadows a state var.
+
+    Args:
+        index_page: The index page.
+        windows_platform: Whether the system is windows.
+        token: a Token.
+        app_module_mock: Mocked app module.
+        mocker: pytest mocker object.
+    """
+    arg_name = "counter"
+    route = f"/test/[{arg_name}]"
+    if windows_platform:
+        route.lstrip("/").replace("/", "\\")
+    app = app_module_mock.app = App(state=DynamicState)
+    assert app.state is not None
+    with pytest.raises(NameError):
+        app.add_page(index_page, route=route, on_load=DynamicState.on_load)  # type: ignore
+
+
+def test_multiple_dynamic_args(
+    index_page: ComponentCallable,
+    windows_platform: bool,
+    token: str,
+    app_module_mock: unittest.mock.Mock,
+    mocker,
+):
+    """Create app with multiple dynamic route vars with the same name.
+
+    Args:
+        index_page: The index page.
+        windows_platform: Whether the system is windows.
+        token: a Token.
+        app_module_mock: Mocked app module.
+        mocker: pytest mocker object.
+    """
+    arg_name = "my_arg"
+    route = f"/test/[{arg_name}]"
+    route2 = f"/test2/[{arg_name}]"
+    if windows_platform:
+        route = route.lstrip("/").replace("/", "\\")
+        route2 = route2.lstrip("/").replace("/", "\\")
+    app = app_module_mock.app = App(state=EmptyState)
+    app.add_page(index_page, route=route)
+    app.add_page(index_page, route=route2)
+
+
 @pytest.mark.asyncio
 async def test_dynamic_route_var_route_change_completed_on_load(
-    index_page,
+    index_page: ComponentCallable,
     windows_platform: bool,
     token: str,
     app_module_mock: unittest.mock.Mock,