12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090 |
- """The main Reflex app."""
- from __future__ import annotations
- import asyncio
- import concurrent.futures
- import contextlib
- import copy
- import functools
- import os
- from typing import (
- Any,
- AsyncIterator,
- Callable,
- Coroutine,
- Dict,
- List,
- Optional,
- Set,
- Type,
- Union,
- get_args,
- get_type_hints,
- )
- from fastapi import FastAPI, HTTPException, Request, UploadFile
- from fastapi.middleware import cors
- from fastapi.responses import StreamingResponse
- from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
- from socketio import ASGIApp, AsyncNamespace, AsyncServer
- from starlette_admin.contrib.sqla.admin import Admin
- from starlette_admin.contrib.sqla.view import ModelView
- from reflex import constants
- from reflex.admin import AdminDash
- from reflex.base import Base
- from reflex.compiler import compiler
- from reflex.compiler import utils as compiler_utils
- from reflex.components import connection_modal
- from reflex.components.base.app_wrap import AppWrap
- from reflex.components.component import Component, ComponentStyle
- from reflex.components.layout.fragment import Fragment
- from reflex.components.navigation.client_side_routing import (
- Default404Page,
- wait_for_client_redirect,
- )
- from reflex.config import get_config
- from reflex.event import Event, EventHandler, EventSpec
- from reflex.middleware import HydrateMiddleware, Middleware
- from reflex.model import Model
- from reflex.page import (
- DECORATED_PAGES,
- )
- from reflex.route import (
- catchall_in_route,
- catchall_prefix,
- get_route_args,
- verify_route_validity,
- )
- from reflex.state import (
- BaseState,
- RouterData,
- State,
- StateManager,
- StateUpdate,
- )
- from reflex.utils import console, format, prerequisites, types
- from reflex.utils.imports import ImportVar
# Define custom types.
# A zero-argument callable that builds a component; add_page accepts either a
# Component or one of these.
ComponentCallable = Callable[[], Component]
# An async callable taking an Event and producing a StateUpdate.
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
def default_overlay_component() -> Component:
    """Build the default value for ``App.overlay_component``.

    Returns:
        The component produced by ``connection_modal()``.
    """
    modal = connection_modal()
    return modal
class App(Base):
    """A Reflex application.

    Holds the registered pages, the FastAPI backend, the Socket.IO server and
    event namespace (only when a state class is set), the state manager and the
    middleware chain. Calling the instance returns the FastAPI app.
    """

    # A map from a page route to the component to render.
    pages: Dict[str, Component] = {}

    # A list of URLs to stylesheets to include in the app.
    stylesheets: List[str] = []

    # The backend API object.
    api: FastAPI = None  # type: ignore

    # The Socket.IO AsyncServer.
    sio: Optional[AsyncServer] = None

    # The socket app.
    socket_app: Optional[ASGIApp] = None

    # The state class to use for the app.
    state: Optional[Type[BaseState]] = None

    # Class to manage many client states.
    _state_manager: Optional[StateManager] = None

    # The styling to apply to each component.
    style: ComponentStyle = {}

    # Middleware to add to the app.
    middleware: List[Middleware] = []

    # List of event handlers to trigger when a page loads.
    load_events: Dict[str, List[Union[EventHandler, EventSpec]]] = {}

    # Admin dashboard
    admin_dash: Optional[AdminDash] = None

    # The async server name space
    event_namespace: Optional[EventNamespace] = None

    # Components to add to the head of every page.
    head_components: List[Component] = []

    # A component that is present on every page.
    overlay_component: Optional[
        Union[Component, ComponentCallable]
    ] = default_overlay_component

    # Background tasks that are currently running
    background_tasks: Set[asyncio.Task] = set()

    # The radix theme for the entire app
    theme: Optional[Component] = None

    def __init__(self, *args, **kwargs):
        """Initialize the app.

        Args:
            *args: Args to initialize the app with.
            **kwargs: Kwargs to initialize the app with.

        Raises:
            ValueError: If the event namespace is not provided in the config.
                Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
                of the DefaultState and the client app state).
        """
        if "connect_error_component" in kwargs:
            raise ValueError(
                "`connect_error_component` is deprecated, use `overlay_component` instead"
            )
        super().__init__(*args, **kwargs)
        state_subclasses = BaseState.__subclasses__()
        is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ

        # Special case to allow test cases have multiple subclasses of rx.BaseState.
        if not is_testing_env:
            # Only one Base State class is allowed.
            if len(state_subclasses) > 1:
                raise ValueError(
                    "rx.BaseState cannot be subclassed multiple times. use rx.State instead"
                )

            if "state" in kwargs:
                console.deprecate(
                    feature_name="`state` argument for App()",
                    reason="due to all `rx.State` subclasses being inferred.",
                    deprecation_version="0.3.5",
                    removal_version="0.4.0",
                )
            # Outside of tests the app always uses the global State class.
            self.state = State

        # Get the config
        config = get_config()

        # Add middleware.
        self.middleware.append(HydrateMiddleware())

        # Set up the API.
        self.api = FastAPI()
        self.add_cors()
        self.add_default_endpoints()

        if self.state:
            # Set up the state manager.
            self._state_manager = StateManager.create(state=self.state)

            # Set up the Socket.IO AsyncServer.
            self.sio = AsyncServer(
                async_mode="asgi",
                cors_allowed_origins="*"
                if config.cors_allowed_origins == ["*"]
                else config.cors_allowed_origins,
                cors_credentials=True,
                max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
                ping_interval=constants.Ping.INTERVAL,
                ping_timeout=constants.Ping.TIMEOUT,
            )

            # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
            self.socket_app = ASGIApp(self.sio, socketio_path="")
            namespace = config.get_event_namespace()

            if not namespace:
                raise ValueError("event namespace must be provided in the config.")

            # Create the event namespace and attach the main app. Not related to any paths.
            self.event_namespace = EventNamespace(namespace, self)

            # Register the event namespace with the socket.
            self.sio.register_namespace(self.event_namespace)
            # Mount the socket app with the API.
            self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)

        # Set up the admin dash.
        self.setup_admin_dash()

        # If a State is not used and no overlay_component is specified, do not render the connection modal
        if self.state is None and self.overlay_component is default_overlay_component:
            self.overlay_component = None

    def __repr__(self) -> str:
        """Get the string representation of the app.

        Returns:
            The string representation of the app.
        """
        return f"<App state={self.state.__name__ if self.state else None}>"

    def __call__(self) -> FastAPI:
        """Run the backend api instance.

        Returns:
            The backend api.
        """
        return self.api

    def add_default_endpoints(self):
        """Add the default endpoints."""
        # To test the server.
        self.api.get(str(constants.Endpoint.PING))(ping)

        # To upload files.
        self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))

    def add_cors(self):
        """Add CORS middleware to the app."""
        self.api.add_middleware(
            cors.CORSMiddleware,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
            allow_origins=["*"],
        )

    @property
    def state_manager(self) -> StateManager:
        """Get the state manager.

        Returns:
            The initialized state manager.

        Raises:
            ValueError: if the state has not been initialized.
        """
        if self._state_manager is None:
            raise ValueError("The state manager has not been initialized.")
        return self._state_manager

    async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
        """Preprocess the event.

        This is where middleware can modify the event before it is processed.
        Each middleware is called in the order it was added to the app.

        If a middleware returns an update, the event is not processed and the
        update is returned.

        Args:
            state: The state to preprocess.
            event: The event to preprocess.

        Returns:
            The first non-None update returned by a middleware, else None.
        """
        for middleware in self.middleware:
            # Middleware hooks may be either plain functions or coroutines.
            if asyncio.iscoroutinefunction(middleware.preprocess):
                out = await middleware.preprocess(app=self, state=state, event=event)  # type: ignore
            else:
                out = middleware.preprocess(app=self, state=state, event=event)  # type: ignore
            if out is not None:
                return out  # type: ignore

    async def postprocess(
        self, state: BaseState, event: Event, update: StateUpdate
    ) -> StateUpdate:
        """Postprocess the event.

        This is where middleware can modify the delta after it is processed.
        Each middleware is called in the order it was added to the app.

        Args:
            state: The state to postprocess.
            event: The event to postprocess.
            update: The current state update.

        Returns:
            The first non-None update returned by a middleware, else ``update``.
        """
        for middleware in self.middleware:
            # Middleware hooks may be either plain functions or coroutines.
            if asyncio.iscoroutinefunction(middleware.postprocess):
                out = await middleware.postprocess(
                    app=self, state=state, event=event, update=update  # type: ignore
                )
            else:
                out = middleware.postprocess(
                    app=self, state=state, event=event, update=update  # type: ignore
                )
            if out is not None:
                return out  # type: ignore
        return update

    def add_middleware(self, middleware: Middleware, index: int | None = None):
        """Add middleware to the app.

        Args:
            middleware: The middleware to add.
            index: The index to add the middleware at (appended when None).
        """
        if index is None:
            self.middleware.append(middleware)
        else:
            self.middleware.insert(index, middleware)

    @staticmethod
    def _generate_component(component: Component | ComponentCallable) -> Component:
        """Generate a component from a callable.

        Args:
            component: The component function to call or Component to return as-is.

        Returns:
            The generated component.

        Raises:
            TypeError: When an invalid component function is passed.
        """
        try:
            return component if isinstance(component, Component) else component()
        except TypeError as e:
            message = str(e)
            # Re-raise with guidance when the failure came from operating on a state var.
            if "BaseVar" in message or "ComputedVar" in message:
                raise TypeError(
                    "You may be trying to use an invalid Python function on a state var. "
                    "When referencing a var inside your render code, only limited var operations are supported. "
                    "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
                ) from e
            raise e

    def add_page(
        self,
        component: Component | ComponentCallable,
        route: str | None = None,
        title: str = constants.DefaultPage.TITLE,
        description: str = constants.DefaultPage.DESCRIPTION,
        image: str = constants.DefaultPage.IMAGE,
        on_load: EventHandler
        | EventSpec
        | list[EventHandler | EventSpec]
        | None = None,
        meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
        script_tags: list[Component] | None = None,
    ):
        """Add a page to the app.

        If the component is a callable, by default the route is the name of the
        function. Otherwise, a route must be provided.

        Args:
            component: The component to display at the page.
            route: The route to display the component at.
            title: The title of the page.
            description: The description of the page.
            image: The image to display on the page.
            on_load: The event handler(s) that will be called each time the page load.
            meta: The metadata of the page.
            script_tags: List of script tags to be added to component
        """
        # If the route is not set, get it from the callable.
        if route is None:
            assert isinstance(
                component, Callable
            ), "Route must be set if component is not a callable."
            # Format the route.
            route = format.format_route(component.__name__)
        else:
            route = format.format_route(route, format_case=False)

        # Check if the route given is valid
        verify_route_validity(route)

        # Apply dynamic args to the route.
        if self.state:
            self.state.setup_dynamic_args(get_route_args(route))

        # Generate the component if it is a callable.
        component = self._generate_component(component)

        # Wrap the component in a fragment with optional overlay.
        if self.overlay_component is not None:
            component = Fragment.create(
                self._generate_component(self.overlay_component),
                component,
            )
        else:
            component = Fragment.create(component)

        # Add meta information to the component.
        compiler_utils.add_meta(
            component,
            title=title,
            image=image,
            description=description,
            meta=meta,
        )

        # Add script tags if given
        if script_tags:
            console.deprecate(
                feature_name="Passing script tags to add_page",
                reason="Add script components as children to the page component instead",
                deprecation_version="0.2.9",
                removal_version="0.4.0",
            )
            component.children.extend(script_tags)

        # Add the page.
        self._check_routes_conflict(route)
        self.pages[route] = component

        # Add the load events.
        if on_load:
            if not isinstance(on_load, list):
                on_load = [on_load]
            self.load_events[route] = on_load

    def get_load_events(self, route: str) -> list[EventHandler | EventSpec]:
        """Get the load events for a route.

        Args:
            route: The route to get the load events for.

        Returns:
            The load events for the route (empty list when none are registered).
        """
        route = route.lstrip("/")
        # The root path is stored under the index route name.
        if route == "":
            route = constants.PageNames.INDEX_ROUTE
        return self.load_events.get(route, [])

    def _check_routes_conflict(self, new_route: str):
        """Verify if there is any conflict between the new route and any existing route.

        Based on conflicts that NextJS would throw if not intercepted.

        Args:
            new_route: the route being newly added.

        Raises:
            ValueError: exception showing which conflict exist with the route to be added
        """
        newroute_catchall = catchall_in_route(new_route)
        # Only catch-all routes can conflict here.
        if not newroute_catchall:
            return

        for route in self.pages:
            route = "" if route == "index" else route

            if new_route.startswith(f"{route}/[[..."):
                raise ValueError(
                    f"You cannot define a route with the same specificity as a optional catch-all route ('{route}' and '{new_route}')"
                )

            route_catchall = catchall_in_route(route)
            if (
                route_catchall
                and newroute_catchall
                and catchall_prefix(route) == catchall_prefix(new_route)
            ):
                raise ValueError(
                    f"You cannot use multiple catchall for the same dynamic route ({route} !== {new_route})"
                )

    def add_custom_404_page(
        self,
        component: Component | ComponentCallable | None = None,
        title: str = constants.Page404.TITLE,
        image: str = constants.Page404.IMAGE,
        description: str = constants.Page404.DESCRIPTION,
        on_load: EventHandler
        | EventSpec
        | list[EventHandler | EventSpec]
        | None = None,
        meta: list[dict[str, str]] = constants.DefaultPage.META_LIST,
    ):
        """Define a custom 404 page for any url having no match.

        If there is no page defined on 'index' route, add the 404 page to it.
        If there is no global catchall defined, add the 404 page with a catchall

        Args:
            component: The component to display at the page.
            title: The title of the page.
            description: The description of the page.
            image: The image to display on the page.
            on_load: The event handler(s) that will be called each time the page load.
            meta: The metadata of the page.
        """
        if component is None:
            component = Default404Page.create()
        self.add_page(
            component=wait_for_client_redirect(self._generate_component(component)),
            route=constants.Page404.SLUG,
            title=title or constants.Page404.TITLE,
            image=image or constants.Page404.IMAGE,
            description=description or constants.Page404.DESCRIPTION,
            on_load=on_load,
            meta=meta,
        )

    def setup_admin_dash(self):
        """Setup the admin dash (only when an admin dash with models is configured)."""
        # Get the admin dash.
        admin_dash = self.admin_dash

        if admin_dash and admin_dash.models:
            # Build the admin dashboard
            admin = (
                admin_dash.admin
                if admin_dash.admin
                else Admin(
                    engine=Model.get_db_engine(),
                    title="Reflex Admin Dashboard",
                    logo_url="https://reflex.dev/Reflex.svg",
                )
            )

            for model in admin_dash.models:
                view = admin_dash.view_overrides.get(model, ModelView)
                admin.add_view(view(model))

            admin.mount_to(self.api)

    def get_frontend_packages(self, imports: Dict[str, set[ImportVar]]):
        """Gets the frontend packages to be installed and filters out the unnecessary ones.

        A library key from ``imports`` is installed only if it is not already a
        package.json (dev) dependency, does not start with "/", "." or "next/",
        is non-empty, and at least one of its ImportVars has ``install`` set.
        Packages from the config's ``frontend_packages`` are added as well,
        except Tailwind plugins and packages already inferred from the imports
        (a warning is printed for each skipped package).

        Args:
            imports: A dictionary containing the imports used in the current page.
        """
        config = get_config()

        # Build the exclusion set once instead of rebuilding a list per import.
        default_packages = {
            *constants.PackageJson.DEPENDENCIES.keys(),
            *constants.PackageJson.DEV_DEPENDENCIES.keys(),
        }
        page_imports = {
            i
            for i, tags in imports.items()
            if i not in default_packages
            and not i.startswith(("/", ".", "next/"))
            and i != ""
            and any(tag.install for tag in tags)
        }

        # Tailwind plugins are inferred from the tailwind config's "plugins".
        tailwind_plugins = (config.tailwind or {}).get("plugins", [])  # type: ignore
        _frontend_packages = []
        for package in config.frontend_packages:
            if package in tailwind_plugins:
                console.warn(
                    f"Tailwind packages are inferred from 'plugins', remove `{package}` from `frontend_packages`"
                )
                continue
            if package in page_imports:
                console.warn(
                    f"React packages and their dependencies are inferred from Component.library and Component.lib_dependencies, remove `{package}` from `frontend_packages`"
                )
                continue
            _frontend_packages.append(package)
        page_imports.update(_frontend_packages)
        prerequisites.install_frontend_packages(page_imports)

    def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component:
        """Nest all app wrapper components into a single root component.

        The wrappers contributed by each given wrapper (``get_app_wrap_components``)
        are merged in first. Wrappers are then ordered by the integer in their
        key, highest first; each one is deep-copied and appended as a child of
        the previous one.

        Args:
            app_wrappers: Map from ``(int, name)`` keys to wrapper components.
                Updated in place with the wrappers' own wrappers.

        Returns:
            The outermost wrapper component.
        """
        # Iterate a snapshot because the dict is updated inside the loop.
        for component in tuple(app_wrappers.values()):
            app_wrappers.update(component.get_app_wrap_components())
        order = sorted(app_wrappers, key=lambda k: k[0], reverse=True)
        root = parent = copy.deepcopy(app_wrappers[order[0]])
        for key in order[1:]:
            child = copy.deepcopy(app_wrappers[key])
            parent.children.append(child)
            parent = child
        return root

    def _should_compile(self) -> bool:
        """Check if the app should be compiled.

        Returns:
            Whether the app should be compiled.
        """
        # Check the environment variable.
        if os.environ.get(constants.SKIP_COMPILE_ENV_VAR) == "yes":
            return False

        # Check the nocompile file.
        if os.path.exists(constants.NOCOMPILE_FILE):
            # Delete the nocompile file
            os.remove(constants.NOCOMPILE_FILE)
            return False

        # By default, compile the app.
        return True

    def compile(self):
        """Compile the app and output it to the pages folder.

        Registers decorated pages and a default 404 page (if none was added).
        Unless ``_should_compile`` says otherwise, it then compiles the pages and
        shared files in a thread pool, installs frontend packages, and writes all
        outputs at the end.
        """
        # add the pages before the compile check so App know onload methods
        for render, kwargs in DECORATED_PAGES:
            self.add_page(render, **kwargs)

        # Render a default 404 page if the user didn't supply one
        if constants.Page404.SLUG not in self.pages:
            self.add_custom_404_page()

        if not self._should_compile():
            return

        # Create a progress bar.
        progress = Progress(
            *Progress.get_default_columns()[:-1],
            MofNCompleteColumn(),
            TimeElapsedColumn(),
        )

        # Get the config.
        config = get_config()

        # Store the compile results.
        compile_results = []

        # Compile the pages in parallel.
        custom_components = set()
        # TODO Anecdotally, processes=2 works 10% faster (cpu_count=12)
        all_imports = {}
        app_wrappers: Dict[tuple[int, str], Component] = {
            # Default app wrap component renders {children}
            (0, "AppWrap"): AppWrap.create()
        }
        if self.theme is not None:
            # If a theme component was provided, wrap the app with it
            app_wrappers[(20, "Theme")] = self.theme

        with progress, concurrent.futures.ThreadPoolExecutor() as thread_pool:
            # Non-page tasks submitted below: app root, custom components, root
            # stylesheet, document root, theme, contexts (+ Tailwind when enabled).
            # Each submitted task advances the bar once, so the total must match
            # or the bar never completes.
            fixed_pages = 6 + (1 if config.tailwind is not None else 0)
            task = progress.add_task("Compiling:", total=len(self.pages) + fixed_pages)

            def mark_complete(_=None):
                """Advance the progress bar by one (used as a Future done-callback)."""
                progress.advance(task)

            for _route, component in self.pages.items():
                # Merge the component style with the app style.
                component.add_style(self.style)

                if self.theme is not None:
                    component.apply_theme(self.theme)

                # Add component.get_imports() to all_imports.
                all_imports.update(component.get_imports())

                # Add the app wrappers from this component.
                app_wrappers.update(component.get_app_wrap_components())

                # Add the custom components from the page to the set.
                custom_components |= component.get_custom_components()

            # Perform auto-memoization of stateful components.
            (
                stateful_components_path,
                stateful_components_code,
                page_components,
            ) = compiler.compile_stateful_components(self.pages.values())
            compile_results.append((stateful_components_path, stateful_components_code))

            result_futures = []

            def submit_work(fn, *args, **kwargs):
                """Submit work to the thread pool and add a callback to mark the task as complete.

                The Future will be added to the `result_futures` list.

                Args:
                    fn: The function to submit.
                    *args: The args to submit.
                    **kwargs: The kwargs to submit.
                """
                f = thread_pool.submit(fn, *args, **kwargs)
                f.add_done_callback(mark_complete)
                result_futures.append(f)

            # Compile all page components.
            for route, component in zip(self.pages, page_components):
                submit_work(
                    compiler.compile_page,
                    route,
                    component,
                    self.state,
                )

            # Compile the app wrapper.
            app_root = self._app_root(app_wrappers=app_wrappers)
            submit_work(compiler.compile_app, app_root)

            # Compile the custom components.
            submit_work(compiler.compile_components, custom_components)

            # Compile the root stylesheet with base styles.
            submit_work(compiler.compile_root_stylesheet, self.stylesheets)

            # Compile the root document.
            submit_work(compiler.compile_document_root, self.head_components)

            # Compile the theme.
            submit_work(compiler.compile_theme, style=self.style)

            # Compile the contexts.
            submit_work(compiler.compile_contexts, self.state)

            # Compile the Tailwind config.
            if config.tailwind is not None:
                config.tailwind["content"] = config.tailwind.get(
                    "content", constants.Tailwind.CONTENT
                )
                submit_work(compiler.compile_tailwind, config.tailwind)

            # Get imports from AppWrap components.
            all_imports.update(app_root.get_imports())

            # Iterate through all the custom components and add their imports to the all_imports.
            for component in custom_components:
                all_imports.update(component.get_imports())

            # Wait for all compilation tasks to complete.
            for future in concurrent.futures.as_completed(result_futures):
                compile_results.append(future.result())

            # Empty the .web pages directory.
            compiler.purge_web_pages_dir()

            # Avoid flickering when installing frontend packages
            progress.stop()

            # Install frontend packages.
            self.get_frontend_packages(all_imports)

            # Write the pages at the end to trigger the NextJS hot reload only once.
            write_page_futures = []
            for output_path, code in compile_results:
                write_page_futures.append(
                    thread_pool.submit(compiler_utils.write_page, output_path, code)
                )
            for future in concurrent.futures.as_completed(write_page_futures):
                future.result()

    @contextlib.asynccontextmanager
    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
        """Modify the state out of band.

        Any delta produced inside the context is emitted to the client's
        session after the caller's block finishes.

        Args:
            token: The token to modify the state for.

        Yields:
            The state to modify.

        Raises:
            RuntimeError: If the app has not been initialized yet.
        """
        if self.event_namespace is None:
            raise RuntimeError("App has not been initialized yet.")
        # Get exclusive access to the state.
        async with self.state_manager.modify_state(token) as state:
            # No other event handler can modify the state while in this context.
            yield state
            delta = state.get_delta()
            if delta:
                # When the state is modified reset dirty status and emit the delta to the frontend.
                state._clean()
                await self.event_namespace.emit_update(
                    update=StateUpdate(delta=delta),
                    sid=state.router.session.session_id,
                )

    def _process_background(
        self, state: BaseState, event: Event
    ) -> asyncio.Task | None:
        """Process an event in the background and emit updates as they arrive.

        Args:
            state: The state to process the event for.
            event: The event to process.

        Returns:
            Task if the event was backgroundable, otherwise None
        """
        substate, handler = state._get_event_handler(event)
        if not handler.is_background:
            return None

        async def _coro():
            """Coroutine to process the event and emit updates inside an asyncio.Task.

            Raises:
                RuntimeError: If the app has not been initialized yet.
            """
            if self.event_namespace is None:
                raise RuntimeError("App has not been initialized yet.")

            # Process the event.
            async for update in state._process_event(
                handler=handler, state=substate, payload=event.payload
            ):
                # Postprocess the event.
                update = await self.postprocess(state, event, update)

                # Send the update to the client.
                await self.event_namespace.emit_update(
                    update=update,
                    sid=state.router.session.session_id,
                )

        task = asyncio.create_task(_coro())
        # Keep a strong reference so the running task is not discarded early.
        self.background_tasks.add(task)
        # Clean up task from background_tasks set when complete.
        task.add_done_callback(self.background_tasks.discard)
        return task
async def process(
    app: App, event: Event, sid: str, headers: Dict, client_ip: str
) -> AsyncIterator[StateUpdate]:
    """Run one frontend event through the app and stream the resulting updates.

    The event's router data is augmented with query params, client token,
    session id, headers and client IP before the session state is locked.

    Args:
        app: The app to process the event for.
        event: The event to process.
        sid: The Socket.IO session id.
        headers: The client headers.
        client_ip: The client_ip.

    Yields:
        The state updates after processing the event.
    """
    request_info = {
        constants.RouteVar.QUERY: format.format_query_params(event.router_data),
        constants.RouteVar.CLIENT_TOKEN: event.token,
        constants.RouteVar.SESSION_ID: sid,
        constants.RouteVar.HEADERS: headers,
        constants.RouteVar.CLIENT_IP: client_ip,
    }
    route_data = event.router_data
    route_data.update(request_info)

    # Hold the session's state exclusively for the whole event.
    async with app.state_manager.modify_state(event.token) as session_state:
        # Assigning recurses into substates and recomputes dependent vars,
        # so only do it when something actually changed.
        if session_state.router_data != route_data:
            session_state.router_data = route_data
            session_state.router = RouterData(route_data)

        # Middleware may short-circuit the event entirely.
        early_update = await app.preprocess(session_state, event)
        if early_update is not None:
            yield early_update
            return

        # Background handlers run in their own task; tell the frontend it may
        # send more events right away.
        if app._process_background(session_state, event) is not None:
            yield StateUpdate(final=True)
            return

        # Otherwise run the handler inline, postprocessing each update.
        async for raw_update in session_state._process(event):
            yield await app.postprocess(session_state, event, raw_update)
async def ping() -> str:
    """Health-check endpoint handler.

    Returns:
        The constant reply "pong".
    """
    reply = "pong"
    return reply
def upload(app: App):
    """Upload a file.

    Args:
        app: The app to upload the file for.

    Returns:
        The upload function.
    """

    async def upload_file(request: Request, files: List[UploadFile]):
        """Upload a file.

        Args:
            request: The FastAPI request object.
            files: The file(s) to upload.

        Returns:
            StreamingResponse yielding newline-delimited JSON of StateUpdate
            emitted by the upload handler.

        Raises:
            ValueError: if there are no args with supported annotation.
            TypeError: if a background task is used as the handler.
            HTTPException: when the request does not include token / handler headers.
        """
        token = request.headers.get("reflex-client-token")
        handler = request.headers.get("reflex-event-handler")

        if not token or not handler:
            raise HTTPException(
                status_code=400,
                detail="Missing reflex-client-token or reflex-event-handler header.",
            )

        # Get the state for the session.
        state = await app.state_manager.get_state(token)

        # The handler header is a dotted path: substate names, then the handler name.
        *substate_path, handler_name = handler.split(".")
        current_state = state.get_substate(substate_path)
        handler_upload_param = ()

        # get handler function
        func = getattr(type(current_state), handler_name)

        # check if there exists any handler args with annotation, List[UploadFile]
        if isinstance(func, EventHandler):
            if func.is_background:
                raise TypeError(
                    f"@rx.background is not supported for upload handler `{handler}`.",
                )
            func = func.fn
        if isinstance(func, functools.partial):
            func = func.func
        for k, v in get_type_hints(func).items():
            if not types.is_generic_alias(v):
                continue
            # A generic alias may carry no type arguments; skip it instead of
            # raising IndexError on get_args(v)[0].
            type_args = get_args(v)
            if type_args and types._issubclass(type_args[0], UploadFile):
                handler_upload_param = (k, v)
                break

        if not handler_upload_param:
            raise ValueError(
                f"`{handler}` handler should have a parameter annotated as "
                "List[rx.UploadFile]"
            )

        event = Event(
            token=token,
            name=handler,
            payload={handler_upload_param[0]: files},
        )

        async def _ndjson_updates():
            """Process the upload event, generating ndjson updates.

            Yields:
                Each state update as JSON followed by a new line.
            """
            # Process the event.
            async with app.state_manager.modify_state(token) as state:
                async for update in state._process(event):
                    # Postprocess the event.
                    update = await app.postprocess(state, event, update)
                    yield update.json() + "\n"

        # Stream updates to client
        return StreamingResponse(
            _ndjson_updates(),
            media_type="application/x-ndjson",
        )

    return upload_file
class EventNamespace(AsyncNamespace):
    """Socket.IO namespace that feeds frontend events into the Reflex app.

    Each incoming ``event`` message is parsed, run through ``process`` and every
    resulting update is emitted back to the sending session.
    """

    # The application object.
    app: App

    def __init__(self, namespace: str, app: App):
        """Create the namespace and remember the owning app.

        Args:
            namespace: The namespace.
            app: The application object.
        """
        super().__init__(namespace)
        self.app = app

    def on_connect(self, sid, environ):
        """Handle a websocket connection; intentionally does nothing.

        Args:
            sid: The Socket.IO session id.
            environ: The request information, including HTTP headers.
        """

    def on_disconnect(self, sid):
        """Handle a websocket disconnection; intentionally does nothing.

        Args:
            sid: The Socket.IO session id.
        """

    async def emit_update(self, update: StateUpdate, sid: str) -> None:
        """Send a state update to a single client session.

        Args:
            update: The state update to send.
            sid: The Socket.IO session id.
        """
        event_name = str(constants.SocketEvent.EVENT)
        emission = self.emit(event_name, update.json(), to=sid)
        # Running the emit as its own task keeps it from queueing behind other coroutines.
        await asyncio.create_task(emission)

    async def on_event(self, sid, data):
        """Handle a raw event message sent by the frontend.

        Args:
            sid: The Socket.IO session id.
            data: The event data.
        """
        incoming = Event.parse_raw(data)

        # Look up the connection environment for this session.
        server = self.app.sio
        assert server is not None
        environ = server.get_environ(sid, self.namespace)
        assert environ is not None

        # ASGI headers arrive as byte pairs.
        headers = {
            name.decode("utf-8"): value.decode("utf-8")
            for (name, value) in environ["asgi.scope"]["headers"]
        }
        remote_ip = environ["REMOTE_ADDR"]

        # Forward every produced update to the originating session.
        async for produced in process(self.app, incoming, sid, headers, remote_ip):
            await self.emit_update(update=produced, sid=sid)

    async def on_ping(self, sid):
        """Reply to a ping message with "pong".

        Args:
            sid: The Socket.IO session id.
        """
        await self.emit(str(constants.SocketEvent.PING), "pong", to=sid)
|