浏览代码

deprecate get_ methods for router_data, use BaseVars instead (#1967)

Thomas Brandého 1 年之前
父节点
当前提交
b1bab1206d
共有 7 个文件被更改,包括 219 次插入16 次删除
  1. 2 0
      reflex/app.py
  2. 4 2
      reflex/constants/__init__.py
  3. 1 0
      reflex/constants/route.py
  4. 136 6
      reflex/state.py
  5. 7 6
      tests/test_app.py
  6. 40 2
      tests/test_state.py
  7. 29 0
      tests/utils/test_format.py

+ 2 - 0
reflex/app.py

@@ -54,6 +54,7 @@ from reflex.route import (
 )
 )
 from reflex.state import (
 from reflex.state import (
     DefaultState,
     DefaultState,
+    RouterData,
     State,
     State,
     StateManager,
     StateManager,
     StateManagerMemory,
     StateManagerMemory,
@@ -803,6 +804,7 @@ async def process(
             # assignment will recurse into substates and force recalculation of
             # assignment will recurse into substates and force recalculation of
             # dependent ComputedVar (dynamic route variables)
             # dependent ComputedVar (dynamic route variables)
             state.router_data = router_data
             state.router_data = router_data
+            state.router = RouterData(router_data)
 
 
         # Preprocess the event.
         # Preprocess the event.
         update = await app.preprocess(state, event)
         update = await app.preprocess(state, event)

+ 4 - 2
reflex/constants/__init__.py

@@ -39,6 +39,7 @@ from .installer import (
 )
 )
 from .route import (
 from .route import (
     ROUTE_NOT_FOUND,
     ROUTE_NOT_FOUND,
+    ROUTER,
     ROUTER_DATA,
     ROUTER_DATA,
     DefaultPage,
     DefaultPage,
     Page404,
     Page404,
@@ -77,9 +78,10 @@ __ALL__ = [
     PYTEST_CURRENT_TEST,
     PYTEST_CURRENT_TEST,
     PRODUCTION_BACKEND_URL,
     PRODUCTION_BACKEND_URL,
     Reflex,
     Reflex,
-    RouteVar,
-    RouteRegex,
     RouteArgType,
     RouteArgType,
+    RouteRegex,
+    RouteVar,
+    ROUTER,
     ROUTER_DATA,
     ROUTER_DATA,
     ROUTE_NOT_FOUND,
     ROUTE_NOT_FOUND,
     SETTER_PREFIX,
     SETTER_PREFIX,

+ 1 - 0
reflex/constants/route.py

@@ -13,6 +13,7 @@ class RouteArgType(SimpleNamespace):
 
 
 
 
 # the name of the backend var containing path and client information
 # the name of the backend var containing path and client information
+ROUTER = "router"
 ROUTER_DATA = "router_data"
 ROUTER_DATA = "router_data"
 
 
 
 

+ 136 - 6
reflex/state.py

@@ -48,6 +48,99 @@ from reflex.vars import BaseVar, ComputedVar, Var
 Delta = Dict[str, Any]
 Delta = Dict[str, Any]
 
 
 
 
+class HeaderData(Base):
+    """An object containing headers data."""
+
+    host: str = ""
+    origin: str = ""
+    upgrade: str = ""
+    connection: str = ""
+    pragma: str = ""
+    cache_control: str = ""
+    user_agent: str = ""
+    sec_websocket_version: str = ""
+    sec_websocket_key: str = ""
+    sec_websocket_extensions: str = ""
+    accept_encoding: str = ""
+    accept_language: str = ""
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initalize the HeaderData object based on router_data.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        if router_data:
+            for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items():
+                setattr(self, format.to_snake_case(k), v)
+
+
+class PageData(Base):
+    """An object containing page data."""
+
+    host: str = ""  #  repeated with self.headers.origin (remove or keep the duplicate?)
+    path: str = ""
+    raw_path: str = ""
+    full_path: str = ""
+    full_raw_path: str = ""
+    params: dict = {}
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initalize the PageData object based on router_data.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        if router_data:
+            self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin")
+            self.path = router_data.get(constants.RouteVar.PATH, "")
+            self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "")
+            self.full_path = f"{self.host}{self.path}"
+            self.full_raw_path = f"{self.host}{self.raw_path}"
+            self.params = router_data.get(constants.RouteVar.QUERY, {})
+
+
+class SessionData(Base):
+    """An object containing session data."""
+
+    client_token: str = ""
+    client_ip: str = ""
+    session_id: str = ""
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initalize the SessionData object based on router_data.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        if router_data:
+            self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
+            self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
+            self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
+
+
+class RouterData(Base):
+    """An object containing RouterData."""
+
+    session: SessionData = SessionData()
+    headers: HeaderData = HeaderData()
+    page: PageData = PageData()
+
+    def __init__(self, router_data: Optional[dict] = None):
+        """Initialize the RouterData object.
+
+        Args:
+            router_data: the router_data dict.
+        """
+        super().__init__()
+        self.session = SessionData(router_data)
+        self.headers = HeaderData(router_data)
+        self.page = PageData(router_data)
+
+
 class State(Base, ABC, extra=pydantic.Extra.allow):
 class State(Base, ABC, extra=pydantic.Extra.allow):
     """The state of the app."""
     """The state of the app."""
 
 
@@ -96,6 +189,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
     # Per-instance copy of backend variable values
     # Per-instance copy of backend variable values
     _backend_vars: Dict[str, Any] = {}
     _backend_vars: Dict[str, Any] = {}
 
 
+    # The router data for the current page
+    router: RouterData = RouterData()
+
     def __init__(self, *args, parent_state: State | None = None, **kwargs):
     def __init__(self, *args, parent_state: State | None = None, **kwargs):
         """Initialize the state.
         """Initialize the state.
 
 
@@ -494,6 +590,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
         Returns:
             The token of the client.
             The token of the client.
         """
         """
+        console.deprecate(
+            feature_name="get_token",
+            reason="replaced by `State.router.session.client_token`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
         return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
 
 
     def get_sid(self) -> str:
     def get_sid(self) -> str:
@@ -502,6 +604,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
         Returns:
             The session ID of the client.
             The session ID of the client.
         """
         """
+        console.deprecate(
+            feature_name="get_sid",
+            reason="replaced by `State.router.session.session_id`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.SESSION_ID, "")
         return self.router_data.get(constants.RouteVar.SESSION_ID, "")
 
 
     def get_headers(self) -> Dict:
     def get_headers(self) -> Dict:
@@ -510,6 +618,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
         Returns:
             The headers of the client.
             The headers of the client.
         """
         """
+        console.deprecate(
+            feature_name="get_headers",
+            reason="replaced by `State.router.headers`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.HEADERS, {})
         return self.router_data.get(constants.RouteVar.HEADERS, {})
 
 
     def get_client_ip(self) -> str:
     def get_client_ip(self) -> str:
@@ -518,6 +632,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
         Returns:
             The IP of the client.
             The IP of the client.
         """
         """
+        console.deprecate(
+            feature_name="get_client_ip",
+            reason="replaced by `State.router.session.client_ip`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
         return self.router_data.get(constants.RouteVar.CLIENT_IP, "")
 
 
     def get_current_page(self, origin=False) -> str:
     def get_current_page(self, origin=False) -> str:
@@ -529,10 +649,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
         Returns:
             The current page.
             The current page.
         """
         """
-        if origin:
-            return self.router_data.get(constants.RouteVar.ORIGIN, "")
-        else:
-            return self.router_data.get(constants.RouteVar.PATH, "")
+        console.deprecate(
+            feature_name="get_current_page",
+            reason="replaced by State.router.page / self.router.page",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
+
+        return self.router.page.raw_path if origin else self.router.page.path
 
 
     def get_query_params(self) -> dict[str, str]:
     def get_query_params(self) -> dict[str, str]:
         """Obtain the query parameters for the queried page.
         """Obtain the query parameters for the queried page.
@@ -542,6 +666,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
         Returns:
             The dict of query parameters.
             The dict of query parameters.
         """
         """
+        console.deprecate(
+            feature_name="get_query_params",
+            reason="replaced by `State.router.page.params`",
+            deprecation_version="0.3.0",
+            removal_version="0.3.1",
+        )
         return self.router_data.get(constants.RouteVar.QUERY, {})
         return self.router_data.get(constants.RouteVar.QUERY, {})
 
 
     def get_cookies(self) -> dict[str, str]:
     def get_cookies(self) -> dict[str, str]:
@@ -583,14 +713,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
         def argsingle_factory(param):
         def argsingle_factory(param):
             @ComputedVar
             @ComputedVar
             def inner_func(self) -> str:
             def inner_func(self) -> str:
-                return self.get_query_params().get(param, "")
+                return self.router.page.params.get(param, "")
 
 
             return inner_func
             return inner_func
 
 
         def arglist_factory(param):
         def arglist_factory(param):
             @ComputedVar
             @ComputedVar
             def inner_func(self) -> List:
             def inner_func(self) -> List:
-                return self.get_query_params().get(param, [])
+                return self.router.page.params.get(param, [])
 
 
             return inner_func
             return inner_func
 
 

+ 7 - 6
tests/test_app.py

@@ -34,7 +34,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.model import Model
-from reflex.state import State, StateManagerRedis, StateUpdate
+from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
 from reflex.style import Style
 from reflex.style import Style
 from reflex.utils import format
 from reflex.utils import format
 from reflex.vars import ComputedVar
 from reflex.vars import ComputedVar
@@ -255,9 +255,9 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool
     assert set(app.pages.keys()) == {"test/[dynamic]"}
     assert set(app.pages.keys()) == {"test/[dynamic]"}
     assert "dynamic" in app.state.computed_vars
     assert "dynamic" in app.state.computed_vars
     assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == {
     assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == {
-        constants.ROUTER_DATA
+        constants.ROUTER
     }
     }
-    assert constants.ROUTER_DATA in app.state().computed_var_dependencies
+    assert constants.ROUTER in app.state().computed_var_dependencies
 
 
 
 
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
 def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -874,9 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
     assert arg_name in app.state.vars
     assert arg_name in app.state.vars
     assert arg_name in app.state.computed_vars
     assert arg_name in app.state.computed_vars
     assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
     assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
-        constants.ROUTER_DATA
+        constants.ROUTER
     }
     }
-    assert constants.ROUTER_DATA in app.state().computed_var_dependencies
+    assert constants.ROUTER in app.state().computed_var_dependencies
 
 
     sid = "mock_sid"
     sid = "mock_sid"
     client_ip = "127.0.0.1"
     client_ip = "127.0.0.1"
@@ -912,6 +912,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             "token": token,
             "token": token,
             **hydrate_event.router_data,
             **hydrate_event.router_data,
         }
         }
+        exp_router = RouterData(exp_router_data)
         process_coro = process(
         process_coro = process(
             app,
             app,
             event=hydrate_event,
             event=hydrate_event,
@@ -920,7 +921,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
             client_ip=client_ip,
             client_ip=client_ip,
         )
         )
         update = await process_coro.__anext__()  # type: ignore
         update = await process_coro.__anext__()  # type: ignore
-
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
         assert update == StateUpdate(
         assert update == StateUpdate(
             delta={
             delta={
@@ -930,6 +930,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     constants.CompileVars.IS_HYDRATED: False,
                     constants.CompileVars.IS_HYDRATED: False,
                     "loaded": exp_index,
                     "loaded": exp_index,
                     "counter": exp_index,
                     "counter": exp_index,
+                    "router": exp_router,
                     # "side_effect_counter": exp_index,
                     # "side_effect_counter": exp_index,
                 }
                 }
             },
             },

+ 40 - 2
tests/test_state.py

@@ -22,6 +22,7 @@ from reflex.state import (
     ImmutableStateError,
     ImmutableStateError,
     LockExpiredError,
     LockExpiredError,
     MutableProxy,
     MutableProxy,
+    RouterData,
     State,
     State,
     StateManager,
     StateManager,
     StateManagerMemory,
     StateManagerMemory,
@@ -40,6 +41,33 @@ LOCK_EXPIRATION = 2000 if CI else 100
 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
 
 
 
 
+formatted_router = {
+    "session": {"client_token": "", "client_ip": "", "session_id": ""},
+    "headers": {
+        "host": "",
+        "origin": "",
+        "upgrade": "",
+        "connection": "",
+        "pragma": "",
+        "cache_control": "",
+        "user_agent": "",
+        "sec_websocket_version": "",
+        "sec_websocket_key": "",
+        "sec_websocket_extensions": "",
+        "accept_encoding": "",
+        "accept_language": "",
+    },
+    "page": {
+        "host": "",
+        "path": "",
+        "raw_path": "",
+        "full_path": "",
+        "full_raw_path": "",
+        "params": {},
+    },
+}
+
+
 class Object(Base):
 class Object(Base):
     """A test object fixture."""
     """A test object fixture."""
 
 
@@ -226,6 +254,7 @@ def test_class_vars(test_state):
     cls = type(test_state)
     cls = type(test_state)
     assert set(cls.vars.keys()) == {
     assert set(cls.vars.keys()) == {
         CompileVars.IS_HYDRATED,  # added by hydrate_middleware to all State
         CompileVars.IS_HYDRATED,  # added by hydrate_middleware to all State
+        "router",
         "num1",
         "num1",
         "num2",
         "num2",
         "key",
         "key",
@@ -614,6 +643,7 @@ def test_reset(test_state, child_state):
         "map_key",
         "map_key",
         "mapping",
         "mapping",
         "dt",
         "dt",
+        "router",
     }
     }
 
 
     # The dirty vars should be reset.
     # The dirty vars should be reset.
@@ -787,7 +817,7 @@ def test_get_current_page(test_state):
     assert test_state.get_current_page() == ""
     assert test_state.get_current_page() == ""
 
 
     route = "mypage/subpage"
     route = "mypage/subpage"
-    test_state.router_data = {RouteVar.PATH: route}
+    test_state.router = RouterData({RouteVar.PATH: route})
 
 
     assert test_state.get_current_page() == route
     assert test_state.get_current_page() == route
 
 
@@ -1131,16 +1161,19 @@ def test_computed_var_depends_on_parent_non_cached():
         cs.get_name(): {"dep_v": 2},
         cs.get_name(): {"dep_v": 2},
         "no_cache_v": 1,
         "no_cache_v": 1,
         CompileVars.IS_HYDRATED: False,
         CompileVars.IS_HYDRATED: False,
+        "router": formatted_router,
     }
     }
     assert ps.dict() == {
     assert ps.dict() == {
         cs.get_name(): {"dep_v": 4},
         cs.get_name(): {"dep_v": 4},
         "no_cache_v": 3,
         "no_cache_v": 3,
         CompileVars.IS_HYDRATED: False,
         CompileVars.IS_HYDRATED: False,
+        "router": formatted_router,
     }
     }
     assert ps.dict() == {
     assert ps.dict() == {
         cs.get_name(): {"dep_v": 6},
         cs.get_name(): {"dep_v": 6},
         "no_cache_v": 5,
         "no_cache_v": 5,
         CompileVars.IS_HYDRATED: False,
         CompileVars.IS_HYDRATED: False,
+        "router": formatted_router,
     }
     }
     assert counter == 6
     assert counter == 6
 
 
@@ -2114,7 +2147,12 @@ def test_json_dumps_with_mutables():
     dict_val = MutableContainsBase().dict()
     dict_val = MutableContainsBase().dict()
     assert isinstance(dict_val["items"][0], dict)
     assert isinstance(dict_val["items"][0], dict)
     val = json_dumps(dict_val)
     val = json_dumps(dict_val)
-    assert val == '{"is_hydrated": false, "items": [{"tags": ["123", "456"]}]}'
+    f_items = '[{"tags": ["123", "456"]}]'
+    f_formatted_router = str(formatted_router).replace("'", '"')
+    assert (
+        val
+        == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
+    )
 
 
 
 
 def test_reset_with_mutables():
 def test_reset_with_mutables():

+ 29 - 0
tests/utils/test_format.py

@@ -446,6 +446,33 @@ def test_format_query_params(input, output):
     assert format.format_query_params(input) == output
     assert format.format_query_params(input) == output
 
 
 
 
+formatted_router = {
+    "session": {"client_token": "", "client_ip": "", "session_id": ""},
+    "headers": {
+        "host": "",
+        "origin": "",
+        "upgrade": "",
+        "connection": "",
+        "pragma": "",
+        "cache_control": "",
+        "user_agent": "",
+        "sec_websocket_version": "",
+        "sec_websocket_key": "",
+        "sec_websocket_extensions": "",
+        "accept_encoding": "",
+        "accept_language": "",
+    },
+    "page": {
+        "host": "",
+        "path": "",
+        "raw_path": "",
+        "full_path": "",
+        "full_raw_path": "",
+        "params": {},
+    },
+}
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "input, output",
     "input, output",
     [
     [
@@ -474,6 +501,7 @@ def test_format_query_params(input, output):
                 "obj": {"prop1": 42, "prop2": "hello"},
                 "obj": {"prop1": 42, "prop2": "hello"},
                 "sum": 3.14,
                 "sum": 3.14,
                 "upper": "",
                 "upper": "",
+                "router": formatted_router,
             },
             },
         ),
         ),
         (
         (
@@ -484,6 +512,7 @@ def test_format_query_params(input, output):
                 "is_hydrated": False,
                 "is_hydrated": False,
                 "t": "18:53:00+01:00",
                 "t": "18:53:00+01:00",
                 "td": "11 days, 0:11:00",
                 "td": "11 days, 0:11:00",
+                "router": formatted_router,
             },
             },
         ),
         ),
     ],
     ],