Преглед на файлове

use better typing for on_load (#4274)

* use better typing for on_load

* make app dataclass

* get it right pyright

* make lifespan into a dataclass
Khaleel Al-Adhami преди 6 месеца
родител
ревизия
4254eadce3
променени са 7 файла, в които са добавени 45 реда и са изтрити 43 реда
  1. 30 30
      reflex/app.py
  2. 5 1
      reflex/app_mixins/lifespan.py
  3. 3 1
      reflex/app_mixins/middleware.py
  4. 3 2
      reflex/app_mixins/mixin.py
  5. 2 1
      reflex/page.py
  6. 1 8
      tests/units/test_app.py
  7. 1 0
      tests/units/test_state.py

+ 30 - 30
reflex/app.py

@@ -46,7 +46,6 @@ from starlette_admin.contrib.sqla.view import ModelView
 from reflex import constants
 from reflex import constants
 from reflex.admin import AdminDash
 from reflex.admin import AdminDash
 from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
 from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin
-from reflex.base import Base
 from reflex.compiler import compiler
 from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler.compiler import (
 from reflex.compiler.compiler import (
@@ -70,7 +69,14 @@ from reflex.components.core.client_side_routing import (
 from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.core.upload import Upload, get_upload_dir
 from reflex.components.radix import themes
 from reflex.components.radix import themes
 from reflex.config import environment, get_config
 from reflex.config import environment, get_config
-from reflex.event import Event, EventHandler, EventSpec, window_alert
+from reflex.event import (
+    Event,
+    EventHandler,
+    EventSpec,
+    EventType,
+    IndividualEventType,
+    window_alert,
+)
 from reflex.model import Model, get_db_status
 from reflex.model import Model, get_db_status
 from reflex.page import (
 from reflex.page import (
     DECORATED_PAGES,
     DECORATED_PAGES,
@@ -189,11 +195,12 @@ class UnevaluatedPage:
     title: Union[Var, str, None]
     title: Union[Var, str, None]
     description: Union[Var, str, None]
     description: Union[Var, str, None]
     image: str
     image: str
-    on_load: Union[EventHandler, EventSpec, List[Union[EventHandler, EventSpec]], None]
+    on_load: Union[EventType[[]], None]
     meta: List[Dict[str, str]]
     meta: List[Dict[str, str]]
 
 
 
 
-class App(MiddlewareMixin, LifespanMixin, Base):
+@dataclasses.dataclass()
+class App(MiddlewareMixin, LifespanMixin):
     """The main Reflex app that encapsulates the backend and frontend.
     """The main Reflex app that encapsulates the backend and frontend.
 
 
     Every Reflex app needs an app defined in its main module.
     Every Reflex app needs an app defined in its main module.
@@ -215,24 +222,26 @@ class App(MiddlewareMixin, LifespanMixin, Base):
     """
     """
 
 
     # The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app.
     # The global [theme](https://reflex.dev/docs/styling/theming/#theme) for the entire app.
-    theme: Optional[Component] = themes.theme(accent_color="blue")
+    theme: Optional[Component] = dataclasses.field(
+        default_factory=lambda: themes.theme(accent_color="blue")
+    )
 
 
     # The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app.
     # The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app.
-    style: ComponentStyle = {}
+    style: ComponentStyle = dataclasses.field(default_factory=dict)
 
 
     # A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app.
     # A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app.
-    stylesheets: List[str] = []
+    stylesheets: List[str] = dataclasses.field(default_factory=list)
 
 
     # A component that is present on every page (defaults to the Connection Error banner).
     # A component that is present on every page (defaults to the Connection Error banner).
     overlay_component: Optional[Union[Component, ComponentCallable]] = (
     overlay_component: Optional[Union[Component, ComponentCallable]] = (
-        default_overlay_component()
+        dataclasses.field(default_factory=default_overlay_component)
     )
     )
 
 
     # Error boundary component to wrap the app with.
     # Error boundary component to wrap the app with.
     error_boundary: Optional[ComponentCallable] = default_error_boundary
     error_boundary: Optional[ComponentCallable] = default_error_boundary
 
 
     # Components to add to the head of every page.
     # Components to add to the head of every page.
-    head_components: List[Component] = []
+    head_components: List[Component] = dataclasses.field(default_factory=list)
 
 
     # The Socket.IO AsyncServer instance.
     # The Socket.IO AsyncServer instance.
     sio: Optional[AsyncServer] = None
     sio: Optional[AsyncServer] = None
@@ -244,10 +253,12 @@ class App(MiddlewareMixin, LifespanMixin, Base):
     html_custom_attrs: Optional[Dict[str, str]] = None
     html_custom_attrs: Optional[Dict[str, str]] = None
 
 
     # A map from a route to an unevaluated page. PRIVATE.
     # A map from a route to an unevaluated page. PRIVATE.
-    unevaluated_pages: Dict[str, UnevaluatedPage] = {}
+    unevaluated_pages: Dict[str, UnevaluatedPage] = dataclasses.field(
+        default_factory=dict
+    )
 
 
     # A map from a page route to the component to render. Users should use `add_page`. PRIVATE.
     # A map from a page route to the component to render. Users should use `add_page`. PRIVATE.
-    pages: Dict[str, Component] = {}
+    pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
 
 
     # The backend API object. PRIVATE.
     # The backend API object. PRIVATE.
     api: FastAPI = None  # type: ignore
     api: FastAPI = None  # type: ignore
@@ -259,7 +270,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
     _state_manager: Optional[StateManager] = None
     _state_manager: Optional[StateManager] = None
 
 
     # Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
     # Mapping from a route to event handlers to trigger when the page loads. PRIVATE.
-    load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}
+    load_events: Dict[str, List[IndividualEventType[[]]]] = dataclasses.field(
+        default_factory=dict
+    )
 
 
     # Admin dashboard to view and manage the database. PRIVATE.
     # Admin dashboard to view and manage the database. PRIVATE.
     admin_dash: Optional[AdminDash] = None
     admin_dash: Optional[AdminDash] = None
@@ -268,7 +281,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
     event_namespace: Optional[EventNamespace] = None
     event_namespace: Optional[EventNamespace] = None
 
 
     # Background tasks that are currently running. PRIVATE.
     # Background tasks that are currently running. PRIVATE.
-    background_tasks: Set[asyncio.Task] = set()
+    background_tasks: Set[asyncio.Task] = dataclasses.field(default_factory=set)
 
 
     # Frontend Error Handler Function
     # Frontend Error Handler Function
     frontend_exception_handler: Callable[[Exception], None] = (
     frontend_exception_handler: Callable[[Exception], None] = (
@@ -280,23 +293,14 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         [Exception], Union[EventSpec, List[EventSpec], None]
         [Exception], Union[EventSpec, List[EventSpec], None]
     ] = default_backend_exception_handler
     ] = default_backend_exception_handler
 
 
-    def __init__(self, **kwargs):
+    def __post_init__(self):
         """Initialize the app.
         """Initialize the app.
 
 
-        Args:
-            **kwargs: Kwargs to initialize the app with.
-
         Raises:
         Raises:
             ValueError: If the event namespace is not provided in the config.
             ValueError: If the event namespace is not provided in the config.
                         Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
                         Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
                         of the DefaultState and the client app state).
                         of the DefaultState and the client app state).
         """
         """
-        if "connect_error_component" in kwargs:
-            raise ValueError(
-                "`connect_error_component` is deprecated, use `overlay_component` instead"
-            )
-        super().__init__(**kwargs)
-
         # Special case to allow test cases have multiple subclasses of rx.BaseState.
         # Special case to allow test cases have multiple subclasses of rx.BaseState.
         if not is_testing_env() and BaseState.__subclasses__() != [State]:
         if not is_testing_env() and BaseState.__subclasses__() != [State]:
             # Only rx.State is allowed as Base State subclass.
             # Only rx.State is allowed as Base State subclass.
@@ -471,9 +475,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         title: str | Var | None = None,
         title: str | Var | None = None,
         description: str | Var | None = None,
         description: str | Var | None = None,
         image: str = constants.DefaultPage.IMAGE,
         image: str = constants.DefaultPage.IMAGE,
-        on_load: (
-            EventHandler | EventSpec | list[EventHandler | EventSpec] | None
-        ) = None,
+        on_load: EventType[[]] | None = None,
         meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
         meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
     ):
     ):
         """Add a page to the app.
         """Add a page to the app.
@@ -559,7 +561,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         self._check_routes_conflict(route)
         self._check_routes_conflict(route)
         self.pages[route] = component
         self.pages[route] = component
 
 
-    def get_load_events(self, route: str) -> list[EventHandler | EventSpec]:
+    def get_load_events(self, route: str) -> list[IndividualEventType[[]]]:
         """Get the load events for a route.
         """Get the load events for a route.
 
 
         Args:
         Args:
@@ -618,9 +620,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
         title: str = constants.Page404.TITLE,
         title: str = constants.Page404.TITLE,
         image: str = constants.Page404.IMAGE,
         image: str = constants.Page404.IMAGE,
         description: str = constants.Page404.DESCRIPTION,
         description: str = constants.Page404.DESCRIPTION,
-        on_load: (
-            EventHandler | EventSpec | list[EventHandler | EventSpec] | None
-        ) = None,
+        on_load: EventType[[]] | None = None,
         meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
         meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
     ):
     ):
         """Define a custom 404 page for any url having no match.
         """Define a custom 404 page for any url having no match.

+ 5 - 1
reflex/app_mixins/lifespan.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import contextlib
 import contextlib
+import dataclasses
 import functools
 import functools
 import inspect
 import inspect
 from typing import Callable, Coroutine, Set, Union
 from typing import Callable, Coroutine, Set, Union
@@ -16,11 +17,14 @@ from reflex.utils.exceptions import InvalidLifespanTaskType
 from .mixin import AppMixin
 from .mixin import AppMixin
 
 
 
 
+@dataclasses.dataclass
 class LifespanMixin(AppMixin):
 class LifespanMixin(AppMixin):
     """A Mixin that allow tasks to run during the whole app lifespan."""
     """A Mixin that allow tasks to run during the whole app lifespan."""
 
 
     # Lifespan tasks that are planned to run.
     # Lifespan tasks that are planned to run.
-    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = set()
+    lifespan_tasks: Set[Union[asyncio.Task, Callable]] = dataclasses.field(
+        default_factory=set
+    )
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def _run_lifespan_tasks(self, app: FastAPI):
     async def _run_lifespan_tasks(self, app: FastAPI):

+ 3 - 1
reflex/app_mixins/middleware.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
+import dataclasses
 from typing import List
 from typing import List
 
 
 from reflex.event import Event
 from reflex.event import Event
@@ -12,11 +13,12 @@ from reflex.state import BaseState, StateUpdate
 from .mixin import AppMixin
 from .mixin import AppMixin
 
 
 
 
+@dataclasses.dataclass
 class MiddlewareMixin(AppMixin):
 class MiddlewareMixin(AppMixin):
     """Middleware Mixin that allow to add middleware to the app."""
     """Middleware Mixin that allow to add middleware to the app."""
 
 
     # Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
     # Middleware to add to the app. Users should use `add_middleware`. PRIVATE.
-    middleware: List[Middleware] = []
+    middleware: List[Middleware] = dataclasses.field(default_factory=list)
 
 
     def _init_mixin(self):
     def _init_mixin(self):
         self.middleware.append(HydrateMiddleware())
         self.middleware.append(HydrateMiddleware())

+ 3 - 2
reflex/app_mixins/mixin.py

@@ -1,9 +1,10 @@
 """Default mixin for all app mixins."""
 """Default mixin for all app mixins."""
 
 
-from reflex.base import Base
+import dataclasses
 
 
 
 
-class AppMixin(Base):
+@dataclasses.dataclass
+class AppMixin:
     """Define the base class for all app mixins."""
     """Define the base class for all app mixins."""
 
 
     def _init_mixin(self):
     def _init_mixin(self):

+ 2 - 1
reflex/page.py

@@ -6,6 +6,7 @@ from collections import defaultdict
 from typing import Any, Dict, List
 from typing import Any, Dict, List
 
 
 from reflex.config import get_config
 from reflex.config import get_config
+from reflex.event import EventType
 
 
 DECORATED_PAGES: Dict[str, List] = defaultdict(list)
 DECORATED_PAGES: Dict[str, List] = defaultdict(list)
 
 
@@ -17,7 +18,7 @@ def page(
     description: str | None = None,
     description: str | None = None,
     meta: list[Any] | None = None,
     meta: list[Any] | None = None,
     script_tags: list[Any] | None = None,
     script_tags: list[Any] | None = None,
-    on_load: Any | list[Any] | None = None,
+    on_load: EventType[[]] | None = None,
 ):
 ):
     """Decorate a function as a page.
     """Decorate a function as a page.
 
 

+ 1 - 8
tests/units/test_app.py

@@ -1211,7 +1211,7 @@ async def test_process_events(mocker, token: str):
     ],
     ],
 )
 )
 def test_overlay_component(
 def test_overlay_component(
-    state: State | None,
+    state: Type[State] | None,
     overlay_component: Component | ComponentCallable | None,
     overlay_component: Component | ComponentCallable | None,
     exp_page_child: Type[Component] | None,
     exp_page_child: Type[Component] | None,
 ):
 ):
@@ -1403,13 +1403,6 @@ def test_app_state_determination():
     assert a4.state is not None
     assert a4.state is not None
 
 
 
 
-# for coverage
-def test_raise_on_connect_error():
-    """Test that the connect_error function is called."""
-    with pytest.raises(ValueError):
-        App(connect_error_component="Foo")
-
-
 def test_raise_on_state():
 def test_raise_on_state():
     """Test that the state is set."""
     """Test that the state is set."""
     # state kwargs is deprecated, we just make sure the app is created anyway.
     # state kwargs is deprecated, we just make sure the app is created anyway.

+ 1 - 0
tests/units/test_state.py

@@ -2725,6 +2725,7 @@ class OnLoadState(State):
 
 
     num: int = 0
     num: int = 0
 
 
+    @rx.event
     def test_handler(self):
     def test_handler(self):
         """Test handler."""
         """Test handler."""
         self.num += 1
         self.num += 1