Bläddra i källkod

deprecate get_ methods for router_data, use BaseVars instead (#1967)

Thomas Brandého 1 år sedan
förälder
incheckning
b1bab1206d
7 ändrade filer med 219 tillägg och 16 borttagningar
  1. 2 0
      reflex/app.py
  2. 4 2
      reflex/constants/__init__.py
  3. 1 0
      reflex/constants/route.py
  4. 136 6
      reflex/state.py
  5. 7 6
      tests/test_app.py
  6. 40 2
      tests/test_state.py
  7. 29 0
      tests/utils/test_format.py

+ 2 - 0
reflex/app.py

@@ -54,6 +54,7 @@ from reflex.route import (
 )
 from reflex.state import (
     DefaultState,
+    RouterData,
     State,
     StateManager,
     StateManagerMemory,
@@ -803,6 +804,7 @@ async def process(
             # assignment will recurse into substates and force recalculation of
             # dependent ComputedVar (dynamic route variables)
             state.router_data = router_data
+            state.router = RouterData(router_data)
 
         # Preprocess the event.
         update = await app.preprocess(state, event)

+ 4 - 2
reflex/constants/__init__.py

@@ -39,6 +39,7 @@ from .installer import (
 )
 from .route import (
     ROUTE_NOT_FOUND,
+    ROUTER,
     ROUTER_DATA,
     DefaultPage,
     Page404,
@@ -77,9 +78,10 @@ __ALL__ = [
     PYTEST_CURRENT_TEST,
     PRODUCTION_BACKEND_URL,
     Reflex,
-    RouteVar,
-    RouteRegex,
     RouteArgType,
+    RouteRegex,
+    RouteVar,
+    ROUTER,
     ROUTER_DATA,
     ROUTE_NOT_FOUND,
     SETTER_PREFIX,

+ 1 - 0
reflex/constants/route.py

@@ -13,6 +13,7 @@ class RouteArgType(SimpleNamespace):
 
 
 # the name of the backend var containing path and client information
+ROUTER = "router"
 ROUTER_DATA = "router_data"
 
 

+ 136 - 6
reflex/state.py

@@ -48,6 +48,99 @@ from reflex.vars import BaseVar, ComputedVar, Var
 Delta = Dict[str, Any]
 
 
+class HeaderData(Base):
+    """An object containing headers data."""
+
+    host: str = ""
+    origin: str = ""
+    upgrade: str = ""
+    connection: str = ""
+    pragma: str = ""
+    cache_control: str = ""
+    user_agent: str = ""
+    sec_websocket_version: str = ""
+    sec_websocket_key: str = ""
+    sec_websocket_extensions: str = ""
+    accept_encoding: str = ""
+    accept_language: str = ""
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initalize the HeaderData object based on router_data.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        if router_data:
+            for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items():
+                setattr(self, format.to_snake_case(k), v)
+
+
+class PageData(Base):
+    """An object containing page data."""
+
+    host: str = ""  #  repeated with self.headers.origin (remove or keep the duplicate?)
+    path: str = ""
+    raw_path: str = ""
+    full_path: str = ""
+    full_raw_path: str = ""
+    params: dict = {}
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initalize the PageData object based on router_data.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        if router_data:
+            self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin")
+            self.path = router_data.get(constants.RouteVar.PATH, "")
+            self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "")
+            self.full_path = f"{self.host}{self.path}"
+            self.full_raw_path = f"{self.host}{self.raw_path}"
+            self.params = router_data.get(constants.RouteVar.QUERY, {})
+
+
+class SessionData(Base):
+    """An object containing session data."""
+
+    client_token: str = ""
+    client_ip: str = ""
+    session_id: str = ""
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initalize the SessionData object based on router_data.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        if router_data:
+            self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
+            self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
+            self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
+
+
+class RouterData(Base):
+    """An object containing RouterData."""
+
+    session: SessionData = SessionData()
+    headers: HeaderData = HeaderData()
+    page: PageData = PageData()
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initialize the RouterData object.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        self.session = SessionData(router_data)
+        self.headers = HeaderData(router_data)
+        self.page = PageData(router_data)
+
+
 class State(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
 
@@ -96,6 +189,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # Per-instance copy of backend variable values
     _backend_vars: Dict[str, Any] = {}
 
+    # The router data for the current page
+    router: RouterData = RouterData()
+
     def __init__(self, *args, parent_state: State | None = None, **kwargs):
         """Initialize the state.
 
@@ -494,6 +590,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The token of the client.
         """
+        console.deprecate(
+            feature_name="get_token",
+            reason="replaced by `State.router.session.client_token`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
 
     def get_sid(self) -> str:
@@ -502,6 +604,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The session ID of the client.
         """
+        console.deprecate(
+            feature_name="get_sid",
+            reason="replaced by `State.router.session.session_id`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.SESSION_ID, "")
 
     def get_headers(self) -> Dict:
@@ -510,6 +618,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The headers of the client.
         """
+        console.deprecate(
+            feature_name="get_headers",
+            reason="replaced by `State.router.headers`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.HEADERS, {})
 
     def get_client_ip(self) -> str:
@@ -518,6 +632,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The IP of the client.
         """
+        console.deprecate(
+            feature_name="get_client_ip",
+            reason="replaced by `State.router.session.client_ip`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
 
     def get_current_page(self, origin=False) -> str:
@@ -529,10 +649,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The current page.
         """
-        if origin:
-            return self.router_data.get(constants.RouteVar.ORIGIN, "")
-        else:
-            return self.router_data.get(constants.RouteVar.PATH, "")
+        console.deprecate(
+            feature_name="get_current_page",
+            reason="replaced by State.router.page / self.router.page",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
+
+        return self.router.page.raw_path if origin else self.router.page.path
 
     def get_query_params(self) -> dict[str, str]:
         """Obtain the query parameters for the queried page.
@@ -542,6 +666,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             The dict of query parameters.
         """
+        console.deprecate(
+            feature_name="get_query_params",
+            reason="replaced by `State.router.page.params`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.QUERY, {})
 
     def get_cookies(self) -> dict[str, str]:
@@ -583,14 +713,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         def argsingle_factory(param):
             @ComputedVar
             def inner_func(self) -> str:
-                return self.get_query_params().get(param, "")
+                return self.router.page.params.get(param, "")
 
             return inner_func
 
         def arglist_factory(param):
             @ComputedVar
             def inner_func(self) -> List:
-                return self.get_query_params().get(param, [])
+                return self.router.page.params.get(param, [])
 
             return inner_func
 

+ 7 - 6
tests/test_app.py

@@ -34,7 +34,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
-from reflex.state import State, StateManagerRedis, StateUpdate
+from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.utils import format
 from reflex.vars import ComputedVar
@@ -255,9 +255,9 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool
     assert set(app.pages.keys()) == {"test/[dynamic]"}
     assert "dynamic" in app.state.computed_vars
     assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == {
-        constants.ROUTER_DATA
+        constants.ROUTER
     }
-    assert constants.ROUTER_DATA in app.state().computed_var_dependencies
+    assert constants.ROUTER in app.state().computed_var_dependencies
 
 
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -874,9 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     assert arg_name in app.state.vars
     assert arg_name in app.state.computed_vars
     assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
-        constants.ROUTER_DATA
+        constants.ROUTER
     }
-    assert constants.ROUTER_DATA in app.state().computed_var_dependencies
+    assert constants.ROUTER in app.state().computed_var_dependencies
 
     sid = "mock_sid"
     client_ip = "127.0.0.1"
@@ -912,6 +912,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             "token": token,
             **hydrate_event.router_data,
         }
+        exp_router = RouterData(exp_router_data)
         process_coro = process(
             app,
             event=hydrate_event,
@@ -920,7 +921,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             client_ip=client_ip,
         )
         update = await process_coro.__anext__()  # type: ignore
-
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
             delta={
@@ -930,6 +930,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     constants.CompileVars.IS_HYDRATED: False,
                     "loaded": exp_index,
                     "counter": exp_index,
+                    "router": exp_router,
                     # "side_effect_counter": exp_index,
                 }
             },

+ 40 - 2
tests/test_state.py

@@ -22,6 +22,7 @@ from reflex.state import (
     ImmutableStateError,
     LockExpiredError,
     MutableProxy,
+    RouterData,
     State,
     StateManager,
     StateManagerMemory,
@@ -40,6 +41,33 @@ LOCK_EXPIRATION = 2000 if CI else 100
 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
 
 
+formatted_router = {
+    "session": {"client_token": "", "client_ip": "", "session_id": ""},
+    "headers": {
+        "host": "",
+        "origin": "",
+        "upgrade": "",
+        "connection": "",
+        "pragma": "",
+        "cache_control": "",
+        "user_agent": "",
+        "sec_websocket_version": "",
+        "sec_websocket_key": "",
+        "sec_websocket_extensions": "",
+        "accept_encoding": "",
+        "accept_language": "",
+    },
+    "page": {
+        "host": "",
+        "path": "",
+        "raw_path": "",
+        "full_path": "",
+        "full_raw_path": "",
+        "params": {},
+    },
+}
+
+
 class Object(Base):
     """A test object fixture."""
 
@@ -226,6 +254,7 @@ def test_class_vars(test_state):
     cls = type(test_state)
     assert set(cls.vars.keys()) == {
         CompileVars.IS_HYDRATED,  # added by hydrate_middleware to all State
+        "router",
         "num1",
         "num2",
         "key",
@@ -614,6 +643,7 @@ def test_reset(test_state, child_state):
         "map_key",
         "mapping",
         "dt",
+        "router",
     }
 
     # The dirty vars should be reset.
@@ -787,7 +817,7 @@ def test_get_current_page(test_state):
     assert test_state.get_current_page() == ""
 
     route = "mypage/subpage"
-    test_state.router_data = {RouteVar.PATH: route}
+    test_state.router = RouterData({RouteVar.PATH: route})
 
     assert test_state.get_current_page() == route
 
@@ -1131,16 +1161,19 @@ def test_computed_var_depends_on_parent_non_cached():
         cs.get_name(): {"dep_v": 2},
         "no_cache_v": 1,
         CompileVars.IS_HYDRATED: False,
+        "router": formatted_router,
     }
     assert ps.dict() == {
         cs.get_name(): {"dep_v": 4},
         "no_cache_v": 3,
         CompileVars.IS_HYDRATED: False,
+        "router": formatted_router,
     }
     assert ps.dict() == {
         cs.get_name(): {"dep_v": 6},
         "no_cache_v": 5,
         CompileVars.IS_HYDRATED: False,
+        "router": formatted_router,
     }
     assert counter == 6
 
@@ -2114,7 +2147,12 @@ def test_json_dumps_with_mutables():
     dict_val = MutableContainsBase().dict()
     assert isinstance(dict_val["items"][0], dict)
     val = json_dumps(dict_val)
-    assert val == '{"is_hydrated": false, "items": [{"tags": ["123", "456"]}]}'
+    f_items = '[{"tags": ["123", "456"]}]'
+    f_formatted_router = str(formatted_router).replace("'", '"')
+    assert (
+        val
+        == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
+    )
 
 
 def test_reset_with_mutables():

+ 29 - 0
tests/utils/test_format.py

@@ -446,6 +446,33 @@ def test_format_query_params(input, output):
     assert format.format_query_params(input) == output
 
 
+formatted_router = {
+    "session": {"client_token": "", "client_ip": "", "session_id": ""},
+    "headers": {
+        "host": "",
+        "origin": "",
+        "upgrade": "",
+        "connection": "",
+        "pragma": "",
+        "cache_control": "",
+        "user_agent": "",
+        "sec_websocket_version": "",
+        "sec_websocket_key": "",
+        "sec_websocket_extensions": "",
+        "accept_encoding": "",
+        "accept_language": "",
+    },
+    "page": {
+        "host": "",
+        "path": "",
+        "raw_path": "",
+        "full_path": "",
+        "full_raw_path": "",
+        "params": {},
+    },
+}
+
+
 @pytest.mark.parametrize(
     "input, output",
     [
@@ -474,6 +501,7 @@ def test_format_query_params(input, output):
                 "obj": {"prop1": 42, "prop2": "hello"},
                 "sum": 3.14,
                 "upper": "",
+                "router": formatted_router,
             },
         ),
         (
@@ -484,6 +512,7 @@ def test_format_query_params(input, output):
                 "is_hydrated": False,
                 "t": "18:53:00+01:00",
                 "td": "11 days, 0:11:00",
+                "router": formatted_router,
             },
         ),
     ],