from __future__ import annotations import functools import io import json import os.path import re import unittest.mock import uuid from contextlib import nullcontext as does_not_raise from pathlib import Path from typing import Generator, List, Tuple, Type, cast from unittest.mock import AsyncMock import pytest import sqlmodel from fastapi import FastAPI, UploadFile from starlette_admin.auth import AuthProvider from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView import reflex as rx from reflex import AdminDash, constants from reflex.app import ( App, ComponentCallable, OverlayFragment, default_overlay_component, process, upload, ) from reflex.components import Component from reflex.components.base.fragment import Fragment from reflex.components.radix.themes.typography.text import Text from reflex.event import Event, EventHandler from reflex.middleware import HydrateMiddleware from reflex.model import Model from reflex.state import ( BaseState, OnLoadInternalState, RouterData, State, StateManagerDisk, StateManagerMemory, StateManagerRedis, StateUpdate, _substate_key, ) from reflex.style import Style from reflex.utils import exceptions, format from reflex.vars.base import computed_var from .conftest import chdir from .states import ( ChildFileUploadState, FileStateBase1, FileUploadState, GenState, GrandChildFileUploadState, ) class EmptyState(BaseState): """An empty state.""" pass @pytest.fixture def index_page() -> ComponentCallable: """An index page. Returns: The index page. """ def index(): return rx.box("Index") return index @pytest.fixture def about_page() -> ComponentCallable: """An about page. Returns: The about page. """ def about(): return rx.box("About") return about class ATestState(BaseState): """A simple state for testing.""" var: int @pytest.fixture() def test_state() -> Type[BaseState]: """A default state. Returns: A default state. 
""" return ATestState @pytest.fixture() def redundant_test_state() -> Type[BaseState]: """A default state. Returns: A default state. """ class RedundantTestState(BaseState): var: int return RedundantTestState @pytest.fixture(scope="session") def test_model() -> Type[Model]: """A default model. Returns: A default model. """ class TestModel(Model, table=True): pass return TestModel @pytest.fixture(scope="session") def test_model_auth() -> Type[Model]: """A default model. Returns: A default model. """ class TestModelAuth(Model, table=True): """A test model with auth.""" pass return TestModelAuth @pytest.fixture() def test_get_engine(): """A default database engine. Returns: A default database engine. """ enable_admin = True url = "sqlite:///test.db" return sqlmodel.create_engine( url, echo=False, connect_args={"check_same_thread": False} if enable_admin else {}, ) @pytest.fixture() def test_custom_auth_admin() -> Type[AuthProvider]: """A default auth provider. Returns: A default default auth provider. """ class TestAuthProvider(AuthProvider): """A test auth provider.""" login_path: str = "/login" logout_path: str = "/logout" def login(self): # pyright: ignore [reportIncompatibleMethodOverride] """Login.""" pass def is_authenticated(self): # pyright: ignore [reportIncompatibleMethodOverride] """Is authenticated.""" pass def get_admin_user(self): # pyright: ignore [reportIncompatibleMethodOverride] """Get admin user.""" pass def logout(self): # pyright: ignore [reportIncompatibleMethodOverride] """Logout.""" pass return TestAuthProvider def test_default_app(app: App): """Test creating an app with no args. Args: app: The app to test. """ assert app.middleware == [HydrateMiddleware()] assert app.style == Style() assert app.admin_dash is None def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): """Test that an error is thrown when multiple classes subclass rx.BaseState. Args: monkeypatch: Pytest monkeypatch object. 
test_state: A test state subclassing rx.BaseState. redundant_test_state: Another test state subclassing rx.BaseState. """ monkeypatch.delenv(constants.PYTEST_CURRENT_TEST) with pytest.raises(ValueError): App() def test_add_page_default_route(app: App, index_page, about_page): """Test adding a page to an app. Args: app: The app to test. index_page: The index page. about_page: The about page. """ assert app._pages == {} assert app._unevaluated_pages == {} app.add_page(index_page) app._compile_page("index") assert app._pages.keys() == {"index"} app.add_page(about_page) app._compile_page("about") assert app._pages.keys() == {"index", "about"} def test_add_page_set_route(app: App, index_page, windows_platform: bool): """Test adding a page to an app. Args: app: The app to test. index_page: The index page. windows_platform: Whether the system is windows. """ route = "test" if windows_platform else "/test" assert app._unevaluated_pages == {} app.add_page(index_page, route=route) app._compile_page("test") assert app._pages.keys() == {"test"} def test_add_page_set_route_dynamic(index_page, windows_platform: bool): """Test adding a page with dynamic route variable to an app. Args: index_page: The index page. windows_platform: Whether the system is windows. """ app = App(_state=EmptyState) assert app._state is not None route = "/test/[dynamic]" assert app._unevaluated_pages == {} app.add_page(index_page, route=route) app._compile_page("test/[dynamic]") assert app._pages.keys() == {"test/[dynamic]"} assert "dynamic" in app._state.computed_vars assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { EmptyState.get_full_name(): {constants.ROUTER}, } assert constants.ROUTER in app._state()._var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): """Test adding a page to an app. Args: app: The app to test. index_page: The index page. windows_platform: Whether the system is windows. 
""" route = "test\\nested" if windows_platform else "/test/nested" assert app._unevaluated_pages == {} app.add_page(index_page, route=route) assert app._unevaluated_pages.keys() == {route.strip(os.path.sep)} def test_add_page_invalid_api_route(app: App, index_page): """Test adding a page with an invalid route to an app. Args: app: The app to test. index_page: The index page. """ with pytest.raises(ValueError): app.add_page(index_page, route="api") with pytest.raises(ValueError): app.add_page(index_page, route="/api") with pytest.raises(ValueError): app.add_page(index_page, route="/api/") with pytest.raises(ValueError): app.add_page(index_page, route="api/foo") with pytest.raises(ValueError): app.add_page(index_page, route="/api/foo") # These should be fine app.add_page(index_page, route="api2") app.add_page(index_page, route="/foo/api") def page1(): return rx.fragment() def page2(): return rx.fragment() def index(): return rx.fragment() @pytest.mark.parametrize( "first_page,second_page, route", [ (lambda: rx.fragment(), lambda: rx.fragment(rx.text("second")), "/"), (rx.fragment(rx.text("first")), rx.fragment(rx.text("second")), "/page1"), ( lambda: rx.fragment(rx.text("first")), rx.fragment(rx.text("second")), "page3", ), (page1, page2, "page1"), (index, index, None), (page1, page1, None), ], ) def test_add_duplicate_page_route_error(app, first_page, second_page, route): app.add_page(first_page, route=route) with pytest.raises(ValueError): app.add_page(second_page, route="/" + route.strip("/") if route else None) def test_initialize_with_admin_dashboard(test_model): """Test setting the admin dashboard of an app. Args: test_model: The default model. 
""" app = App(admin_dash=AdminDash(models=[test_model])) assert app.admin_dash is not None assert len(app.admin_dash.models) > 0 assert app.admin_dash.models[0] == test_model def test_initialize_with_custom_admin_dashboard( test_get_engine, test_custom_auth_admin, test_model_auth, ): """Test setting the custom admin dashboard of an app. Args: test_get_engine: The default database engine. test_model_auth: The default model for an auth admin dashboard. test_custom_auth_admin: The custom auth provider. """ custom_auth_provider = test_custom_auth_admin() custom_admin = Admin(engine=test_get_engine, auth_provider=custom_auth_provider) app = App(admin_dash=AdminDash(models=[test_model_auth], admin=custom_admin)) assert app.admin_dash is not None assert app.admin_dash.admin is not None assert len(app.admin_dash.models) > 0 assert app.admin_dash.models[0] == test_model_auth assert app.admin_dash.admin.auth_provider == custom_auth_provider def test_initialize_admin_dashboard_with_view_overrides(test_model): """Test setting the admin dashboard of an app with view class overridden. Args: test_model: The default model. """ class TestModelView(ModelView): pass app = App( admin_dash=AdminDash( models=[test_model], view_overrides={test_model: TestModelView} ) ) assert app.admin_dash is not None assert app.admin_dash.models == [test_model] assert app.admin_dash.view_overrides[test_model] == TestModelView @pytest.mark.asyncio async def test_initialize_with_state(test_state: Type[ATestState], token: str): """Test setting the state of an app. Args: test_state: The default state. token: a Token. """ app = App(_state=test_state) assert app._state == test_state # Get a state for a given token. 
state = await app.state_manager.get_state(_substate_key(token, test_state)) assert isinstance(state, test_state) assert state.var == 0 if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @pytest.mark.asyncio async def test_set_and_get_state(test_state): """Test setting and getting the state of an app with different tokens. Args: test_state: The default state. """ app = App(_state=test_state) # Create two tokens. token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" # Get the default state for each token. state1 = await app.state_manager.get_state(token1) state2 = await app.state_manager.get_state(token2) assert state1.var == 0 assert state2.var == 0 # Set the vars to different values. state1.var = 1 state2.var = 2 await app.state_manager.set_state(token1, state1) await app.state_manager.set_state(token2, state2) # Get the states again and check the values. state1 = await app.state_manager.get_state(token1) state2 = await app.state_manager.get_state(token2) assert state1.var == 1 assert state2.var == 2 if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @pytest.mark.asyncio async def test_dynamic_var_event(test_state: Type[ATestState], token: str): """Test that the default handler of a dynamic generated var works as expected. Args: test_state: State Fixture. token: a Token. 
""" state = test_state() # pyright: ignore [reportCallIssue] state.add_var("int_val", int, 0) async for result in state._process( Event( token=token, name=f"{test_state.get_name()}.set_int_val", router_data={"pathname": "/", "query": {}}, payload={"value": 50}, ) ): assert result.delta == {test_state.get_name(): {"int_val": 50}} @pytest.mark.asyncio @pytest.mark.parametrize( "event_tuples", [ pytest.param( [ ( "make_friend", {"plain_friends": ["Tommy", "another-fd"]}, ), ( "change_first_friend", {"plain_friends": ["Jenny", "another-fd"]}, ), ], id="append then __setitem__", ), pytest.param( [ ( "unfriend_first_friend", {"plain_friends": []}, ), ( "make_friend", {"plain_friends": ["another-fd"]}, ), ], id="delitem then append", ), pytest.param( [ ( "make_friends_with_colleagues", {"plain_friends": ["Tommy", "Peter", "Jimmy"]}, ), ( "remove_tommy", {"plain_friends": ["Peter", "Jimmy"]}, ), ( "remove_last_friend", {"plain_friends": ["Peter"]}, ), ( "unfriend_all_friends", {"plain_friends": []}, ), ], id="extend, remove, pop, clear", ), pytest.param( [ ( "add_jimmy_to_second_group", {"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]}, ), ( "remove_first_person_from_first_group", {"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]}, ), ( "remove_first_group", {"friends_in_nested_list": [["Jenny", "Jimmy"]]}, ), ], id="nested list", ), pytest.param( [ ( "add_jimmy_to_tommy_friends", {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}, ), ( "remove_jenny_from_tommy", {"friends_in_dict": {"Tommy": ["Jimmy"]}}, ), ( "tommy_has_no_fds", {"friends_in_dict": {"Tommy": []}}, ), ], id="list in dict", ), ], ) async def test_list_mutation_detection__plain_list( event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State, token: str, ): """Test list mutation detection when reassignment is not explicitly included in the logic. Args: event_tuples: From parametrization. list_mutation_state: A state with list mutation features. token: a Token. 
""" for event_name, expected_delta in event_tuples: async for result in list_mutation_state._process( Event( token=token, name=f"{list_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) ): # prefix keys in expected_delta with the state name expected_delta = {list_mutation_state.get_name(): expected_delta} assert result.delta == expected_delta @pytest.mark.asyncio @pytest.mark.parametrize( "event_tuples", [ pytest.param( [ ( "add_age", {"details": {"name": "Tommy", "age": 20}}, ), ( "change_name", {"details": {"name": "Jenny", "age": 20}}, ), ( "remove_last_detail", {"details": {"name": "Jenny"}}, ), ], id="update then __setitem__", ), pytest.param( [ ( "clear_details", {"details": {}}, ), ( "add_age", {"details": {"age": 20}}, ), ], id="delitem then update", ), pytest.param( [ ( "add_age", {"details": {"name": "Tommy", "age": 20}}, ), ( "remove_name", {"details": {"age": 20}}, ), ( "pop_out_age", {"details": {}}, ), ], id="add, remove, pop", ), pytest.param( [ ( "remove_home_address", {"address": [{}, {"work": "work address"}]}, ), ( "add_street_to_home_address", { "address": [ {"street": "street address"}, {"work": "work address"}, ] }, ), ], id="dict in list", ), pytest.param( [ ( "change_friend_name", { "friend_in_nested_dict": { "name": "Nikhil", "friend": {"name": "Tommy"}, } }, ), ( "add_friend_age", { "friend_in_nested_dict": { "name": "Nikhil", "friend": {"name": "Tommy", "age": 30}, } }, ), ( "remove_friend", {"friend_in_nested_dict": {"name": "Nikhil"}}, ), ], id="nested dict", ), ], ) async def test_dict_mutation_detection__plain_list( event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State, token: str, ): """Test dict mutation detection when reassignment is not explicitly included in the logic. Args: event_tuples: From parametrization. dict_mutation_state: A state with dict mutation features. token: a Token. 
""" for event_name, expected_delta in event_tuples: async for result in dict_mutation_state._process( Event( token=token, name=f"{dict_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) ): # prefix keys in expected_delta with the state name expected_delta = {dict_mutation_state.get_name(): expected_delta} assert result.delta == expected_delta @pytest.mark.asyncio @pytest.mark.parametrize( ("state", "delta"), [ ( FileUploadState, { FileUploadState.get_full_name(): { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ( ChildFileUploadState, { ChildFileUploadState.get_full_name(): { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ( GrandChildFileUploadState, { GrandChildFileUploadState.get_full_name(): { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ], ) async def test_upload_file(tmp_path, state, delta, token: str, mocker): """Test that file upload works correctly. Args: tmp_path: Temporary path. state: The state class. delta: Expected delta token: a Token. mocker: pytest mocker object. 
""" mocker.patch( "reflex.state.State.class_subclasses", {state if state is FileUploadState else FileStateBase1}, ) state._tmp_path = tmp_path # The App state must be the "root" of the state tree app = App() app._enable_state() app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] current_state = await app.state_manager.get_state(_substate_key(token, state)) data = b"This is binary data" # Create a binary IO object and write data to it bio = io.BytesIO() bio.write(data) request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, "reflex-event-handler": f"{state.get_full_name()}.multi_handle_upload", } file1 = UploadFile( filename="image1.jpg", file=bio, ) file2 = UploadFile( filename="image2.jpg", file=bio, ) upload_fn = upload(app) streaming_response = await upload_fn(request_mock, [file1, file2]) # pyright: ignore [reportFunctionMemberAccess] async for state_update in streaming_response.body_iterator: assert ( state_update == StateUpdate(delta=delta, events=[], final=True).json() + "\n" ) current_state = await app.state_manager.get_state(_substate_key(token, state)) state_dict = current_state.dict()[state.get_full_name()] assert state_dict["img_list"] == [ "image1.jpg", "image2.jpg", ] if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @pytest.mark.asyncio @pytest.mark.parametrize( "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) async def test_upload_file_without_annotation(state, tmp_path, token): """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile]. Args: state: The state class. tmp_path: Temporary path. token: a Token. 
""" state._tmp_path = tmp_path app = App(_state=State) request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, "reflex-event-handler": f"{state.get_full_name()}.handle_upload2", } file_mock = unittest.mock.Mock(filename="image1.jpg") fn = upload(app) with pytest.raises(ValueError) as err: await fn(request_mock, [file_mock]) assert ( err.value.args[0] == f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" ) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @pytest.mark.asyncio @pytest.mark.parametrize( "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) async def test_upload_file_background(state, tmp_path, token): """Test that an error is thrown handler is a background task. Args: state: The state class. tmp_path: Temporary path. token: a Token. """ state._tmp_path = tmp_path app = App(_state=State) request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, "reflex-event-handler": f"{state.get_full_name()}.bg_upload", } file_mock = unittest.mock.Mock(filename="image1.jpg") fn = upload(app) with pytest.raises(TypeError) as err: await fn(request_mock, [file_mock]) assert ( err.value.args[0] == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`." ) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() class DynamicState(BaseState): """State class for testing dynamic route var. This is defined at module level because event handlers cannot be addressed correctly when the class is defined as a local. 
There are several counters: * loaded: counts how many times `on_load` was triggered by the hydrate middleware * counter: counts how many times `on_counter` was triggered by a non-navigational event -> these events should NOT trigger reload or recalculation of router_data dependent vars * side_effect_counter: counts how many times a computed var was recalculated when the dynamic route var was dirty """ is_hydrated: bool = False loaded: int = 0 counter: int = 0 @rx.event def on_load(self): """Event handler for page on_load, should trigger for all navigation events.""" self.loaded = self.loaded + 1 @rx.event def on_counter(self): """Increment the counter var.""" self.counter = self.counter + 1 @computed_var def comp_dynamic(self) -> str: """A computed var that depends on the dynamic var. Returns: same as self.dynamic """ return self.dynamic on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn def test_dynamic_arg_shadow( index_page: ComponentCallable, windows_platform: bool, token: str, app_module_mock: unittest.mock.Mock, mocker, ): """Create app with dynamic route var and try to add a page with a dynamic arg that shadows a state var. Args: index_page: The index page. windows_platform: Whether the system is windows. token: a Token. app_module_mock: Mocked app module. mocker: pytest mocker object. """ arg_name = "counter" route = f"/test/[{arg_name}]" app = app_module_mock.app = App(_state=DynamicState) assert app._state is not None with pytest.raises(NameError): app.add_page(index_page, route=route, on_load=DynamicState.on_load) def test_multiple_dynamic_args( index_page: ComponentCallable, windows_platform: bool, token: str, app_module_mock: unittest.mock.Mock, mocker, ): """Create app with multiple dynamic route vars with the same name. Args: index_page: The index page. windows_platform: Whether the system is windows. token: a Token. app_module_mock: Mocked app module. mocker: pytest mocker object. 
""" arg_name = "my_arg" route = f"/test/[{arg_name}]" route2 = f"/test2/[{arg_name}]" app = app_module_mock.app = App(_state=EmptyState) app.add_page(index_page, route=route) app.add_page(index_page, route=route2) @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( index_page: ComponentCallable, windows_platform: bool, token: str, app_module_mock: unittest.mock.Mock, mocker, ): """Create app with dynamic route var, and simulate navigation. on_load should fire, allowing any additional vars to be updated before the initial page hydrate. Args: index_page: The index page. windows_platform: Whether the system is windows. token: a Token. app_module_mock: Mocked app module. mocker: pytest mocker object. """ arg_name = "dynamic" route = f"/test/[{arg_name}]" app = app_module_mock.app = App(_state=DynamicState) assert app._state is not None assert arg_name not in app._state.vars app.add_page(index_page, route=route, on_load=DynamicState.on_load) assert arg_name in app._state.vars assert arg_name in app._state.computed_vars assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == { DynamicState.get_full_name(): {constants.ROUTER}, } assert constants.ROUTER in app._state()._var_dependencies substate_token = _substate_key(token, DynamicState) sid = "mock_sid" client_ip = "127.0.0.1" async with app.state_manager.modify_state(substate_token) as state: state.router_data = {"simulate": "hydrated"} assert state.dynamic == "" exp_vals = ["foo", "foobar", "baz"] def _event(name, val, **kwargs): return Event( token=kwargs.pop("token", token), name=name, router_data=kwargs.pop( "router_data", {"pathname": route, "query": {arg_name: val}} ), payload=kwargs.pop("payload", {}), **kwargs, ) def _dynamic_state_event(name, val, **kwargs): return _event( name=format.format_event_handler(getattr(DynamicState, name)), val=val, **kwargs, ) prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): on_load_internal = _event( 
name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", val=exp_val, ) exp_router_data = { "headers": {}, "ip": client_ip, "sid": sid, "token": token, **on_load_internal.router_data, } exp_router = RouterData(exp_router_data) process_coro = process( app, event=on_load_internal, sid=sid, headers={}, client_ip=client_ip, ) update = await process_coro.__anext__() # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)] assert update == StateUpdate( delta={ state.get_name(): { arg_name: exp_val, f"comp_{arg_name}": exp_val, constants.CompileVars.IS_HYDRATED: False, "router": exp_router, } }, events=[ _dynamic_state_event( name="on_load", val=exp_val, ), _event( name=f"{State.get_name()}.set_is_hydrated", payload={"value": True}, val=exp_val, router_data={}, ), ], ) if isinstance(app.state_manager, StateManagerRedis): # When redis is used, the state is not updated until the processing is complete state = await app.state_manager.get_state(substate_token) assert state.dynamic == prev_exp_val # complete the processing with pytest.raises(StopAsyncIteration): await process_coro.__anext__() # check that router data was written to the state_manager store state = await app.state_manager.get_state(substate_token) assert state.dynamic == exp_val process_coro = process( app, event=_dynamic_state_event(name="on_load", val=exp_val), sid=sid, headers={}, client_ip=client_ip, ) on_load_update = await process_coro.__anext__() assert on_load_update == StateUpdate( delta={ state.get_name(): { "loaded": exp_index + 1, }, }, events=[], ) # complete the processing with pytest.raises(StopAsyncIteration): await process_coro.__anext__() process_coro = process( app, event=_dynamic_state_event( name="set_is_hydrated", payload={"value": True}, val=exp_val ), sid=sid, headers={}, client_ip=client_ip, ) on_set_is_hydrated_update = await process_coro.__anext__() assert on_set_is_hydrated_update == StateUpdate( delta={ 
state.get_name(): { "is_hydrated": True, }, }, events=[], ) # complete the processing with pytest.raises(StopAsyncIteration): await process_coro.__anext__() # a simple state update event should NOT trigger on_load or route var side effects process_coro = process( app, event=_dynamic_state_event(name="on_counter", val=exp_val), sid=sid, headers={}, client_ip=client_ip, ) update = await process_coro.__anext__() assert update == StateUpdate( delta={ state.get_name(): { "counter": exp_index + 1, } }, events=[], ) # complete the processing with pytest.raises(StopAsyncIteration): await process_coro.__anext__() prev_exp_val = exp_val state = await app.state_manager.get_state(substate_token) assert state.loaded == len(exp_vals) assert state.counter == len(exp_vals) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @pytest.mark.asyncio async def test_process_events(mocker, token: str): """Test that an event is processed properly and that it is postprocessed n+1 times. Also check that the processing flag of the last stateupdate is set to False. Args: mocker: mocker object. token: a Token. 
""" router_data = { "pathname": "/", "query": {}, "token": token, "sid": "mock_sid", "headers": {}, "ip": "127.0.0.1", } app = App(_state=GenState) mocker.patch.object(app, "_postprocess", AsyncMock()) event = Event( token=token, name=f"{GenState.get_name()}.go", payload={"c": 5}, router_data=router_data, ) async with app.state_manager.modify_state(event.substate_token) as state: state.router_data = {"simulate": "hydrated"} async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): pass assert (await app.state_manager.get_state(event.substate_token)).value == 5 assert getattr(app._postprocess, "call_count", None) == 6 if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @pytest.mark.parametrize( ("state", "overlay_component", "exp_page_child"), [ (None, default_overlay_component, None), (None, None, None), (None, Text.create("foo"), Text), (State, default_overlay_component, Fragment), (State, None, None), (State, Text.create("foo"), Text), (State, lambda: Text.create("foo"), Text), ], ) def test_overlay_component( state: Type[State] | None, overlay_component: Component | ComponentCallable | None, exp_page_child: Type[Component] | None, ): """Test that the overlay component is set correctly. Args: state: The state class to pass to App. overlay_component: The overlay_component to pass to App. exp_page_child: The type of the expected child in the page fragment. 
""" app = App(_state=state, overlay_component=overlay_component) app._setup_overlay_component() if exp_page_child is None: assert app.overlay_component is None elif isinstance(exp_page_child, OverlayFragment): assert app.overlay_component is not None generated_component = app._generate_component(app.overlay_component) assert isinstance(generated_component, OverlayFragment) else: assert app.overlay_component is not None assert isinstance( app._generate_component(app.overlay_component), exp_page_child, ) app.add_page(rx.box("Index"), route="/test") # overlay components are wrapped during compile only app._compile_page("test") app._setup_overlay_component() page = app._pages["test"] if exp_page_child is not None: assert len(page.children) == 3 children_types = [type(child) for child in page.children] assert exp_page_child in children_types else: assert len(page.children) == 2 @pytest.fixture def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: """Fixture for an app that can be compiled. Args: tmp_path: Temporary path. Yields: Tuple containing (app instance, Path to ".web" directory) The working directory is set to the app dir (parent of .web), allowing app.compile() to be called. """ app_path = tmp_path / "app" web_dir = app_path / ".web" web_dir.mkdir(parents=True) (web_dir / constants.PackageJson.PATH).touch() app = App(theme=None) app._get_frontend_packages = unittest.mock.Mock() with chdir(app_path): yield app, web_dir @pytest.mark.parametrize( "react_strict_mode", [True, False], ) def test_app_wrap_compile_theme( react_strict_mode: bool, compilable_app: tuple[App, Path], mocker ): """Test that the radix theme component wraps the app. Args: react_strict_mode: Whether to use React Strict Mode. compilable_app: compilable_app fixture. mocker: pytest mocker object. 
""" conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode) mocker.patch("reflex.config._get_config", return_value=conf) app, web_dir = compilable_app app.theme = rx.theme(accent_color="plum") app._compile() app_js_contents = (web_dir / "pages" / "_app.js").read_text() app_js_lines = [ line.strip() for line in app_js_contents.splitlines() if line.strip() ] lines = "".join(app_js_lines) assert ( "function AppWrap({children}) {" "return (" + ("" if react_strict_mode else "") + "" "" "" "" "" "{children}" "" "" "" "" + ("" if react_strict_mode else "") + ")" "}" ) in lines @pytest.mark.parametrize( "react_strict_mode", [True, False], ) def test_app_wrap_priority( react_strict_mode: bool, compilable_app: tuple[App, Path], mocker ): """Test that the app wrap components are wrapped in the correct order. Args: react_strict_mode: Whether to use React Strict Mode. compilable_app: compilable_app fixture. mocker: pytest mocker object. """ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode) mocker.patch("reflex.config._get_config", return_value=conf) app, web_dir = compilable_app class Fragment1(Component): tag = "Fragment1" def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: # pyright: ignore [reportIncompatibleMethodOverride] return {(99, "Box"): rx.box()} class Fragment2(Component): tag = "Fragment2" def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: # pyright: ignore [reportIncompatibleMethodOverride] return {(50, "Text"): rx.text()} class Fragment3(Component): tag = "Fragment3" def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: # pyright: ignore [reportIncompatibleMethodOverride] return {(10, "Fragment2"): Fragment2.create()} def page(): return Fragment1.create(Fragment3.create()) app.add_page(page) app._compile() app_js_contents = (web_dir / "pages" / "_app.js").read_text() app_js_lines = [ line.strip() for line in app_js_contents.splitlines() if line.strip() ] lines = 
"".join(app_js_lines) assert ( "function AppWrap({children}) {" "return (" + ("" if react_strict_mode else "") + "" '' "" "" "" "" "" "{children}" "" "" "" "" "" "" + ("" if react_strict_mode else "") ) in lines def test_app_state_determination(): """Test that the stateless status of an app is determined correctly.""" a1 = App() assert a1._state is None # No state, no router, no event handlers. a1.add_page(rx.box("Index"), route="/") assert a1._state is None # Add a page with `on_load` enables state. a1.add_page(rx.box("About"), route="/about", on_load=rx.console_log("")) a1._compile_page("about") assert a1._state is not None a2 = App() assert a2._state is None # Referencing a state Var enables state. a2.add_page(rx.box(rx.text(GenState.value)), route="/") a2._compile_page("index") assert a2._state is not None a3 = App() assert a3._state is None # Referencing router enables state. a3.add_page(rx.box(rx.text(State.router.page.full_path)), route="/") a3._compile_page("index") assert a3._state is not None a4 = App() assert a4._state is None a4.add_page(rx.box(rx.button("Click", on_click=rx.console_log(""))), route="/") assert a4._state is None a4.add_page( rx.box(rx.button("Click", on_click=DynamicState.on_counter)), route="/page2" ) a4._compile_page("page2") assert a4._state is not None def test_raise_on_state(): """Test that the state is set.""" # state kwargs is deprecated, we just make sure the app is created anyway. 
_app = App(_state=State) assert _app._state is not None assert issubclass(_app._state, State) def test_call_app(): """Test that the app can be called.""" app = App() api = app() assert isinstance(api, FastAPI) def test_app_with_optional_endpoints(): from reflex.components.core.upload import Upload app = App() Upload.is_used = True app._add_optional_endpoints() # TODO: verify the availability of the endpoints in app.api def test_app_state_manager(): app = App() with pytest.raises(ValueError): app.state_manager app._enable_state() assert app.state_manager is not None assert isinstance( app.state_manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis) ) def test_generate_component(): def index(): return rx.box("Index") def index_mismatch(): return rx.match( 1, (1, rx.box("Index")), (2, "About"), "Bar", ) comp = App._generate_component(index) assert isinstance(comp, Component) with pytest.raises(exceptions.MatchTypeError): App._generate_component(index_mismatch) # pyright: ignore [reportArgumentType] def test_add_page_component_returning_tuple(): """Test that a component or render method returning a tuple is unpacked in a Fragment. """ app = App() def index(): return rx.text("first"), rx.text("second") def page2(): return (rx.text("third"),) app.add_page(index) # pyright: ignore [reportArgumentType] app.add_page(page2) # pyright: ignore [reportArgumentType] app._compile_page("index") app._compile_page("page2") fragment_wrapper = app._pages["index"].children[0] assert isinstance(fragment_wrapper, Fragment) first_text = fragment_wrapper.children[0] assert isinstance(first_text, Text) assert str(first_text.children[0].contents) == '"first"' # pyright: ignore [reportAttributeAccessIssue] second_text = fragment_wrapper.children[1] assert isinstance(second_text, Text) assert str(second_text.children[0].contents) == '"second"' # pyright: ignore [reportAttributeAccessIssue] # Test page with trailing comma. 
page2_fragment_wrapper = app._pages["page2"].children[0] assert isinstance(page2_fragment_wrapper, Fragment) third_text = page2_fragment_wrapper.children[0] assert isinstance(third_text, Text) assert str(third_text.children[0].contents) == '"third"' # pyright: ignore [reportAttributeAccessIssue] @pytest.mark.parametrize("export", (True, False)) def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool): class C1(rx.Component): library = "foo@1.2.3" tag = "Foo" transpile_packages: List[str] = ["foo"] class C2(rx.Component): library = "bar@4.5.6" tag = "Bar" transpile_packages: List[str] = ["bar@4.5.6"] class C3(rx.NoSSRComponent): library = "baz@7.8.10" tag = "Baz" transpile_packages: List[str] = ["baz@7.8.9"] class C4(rx.NoSSRComponent): library = "quuc@2.3.4" tag = "Quuc" transpile_packages: List[str] = ["quuc"] class C5(rx.Component): library = "quuc" tag = "Quuc" app, web_dir = compilable_app page = Fragment.create( C1.create(), C2.create(), C3.create(), C4.create(), C5.create() ) app.add_page(page, route="/") app._compile(export=export) next_config = (web_dir / "next.config.js").read_text() transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", next_config) transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess] transpile_packages = sorted(json.loads(transpile_packages_json)) assert transpile_packages == [ "bar", "foo", "quuc", ] if export: assert 'output: "export"' in next_config assert f'distDir: "{constants.Dirs.STATIC}"' in next_config else: assert 'output: "export"' not in next_config assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]): app, _ = compilable_app class ValidDepState(BaseState): base: int = 0 _backend: int = 0 @computed_var() def foo(self) -> str: return "foo" @computed_var(deps=["_backend", "base", foo]) def bar(self) -> str: return "bar" class Child1(ValidDepState): 
# Body of Child1 inside test_app_with_valid_var_dependencies; the class header sits on the preceding source line.
        @computed_var(deps=["base", ValidDepState.bar])
        def other(self) -> str:
            return "other"

    class Child2(ValidDepState):
        # Depends on a computed var that is defined on the sibling class Child1.
        @computed_var(deps=["base", Child1.other])
        def other(self) -> str:
            return "other"

    app._state = ValidDepState
    app._compile()


def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
    """Test that compiling raises VarDependencyError for an unknown dep name.

    Args:
        compilable_app: compilable_app fixture.
    """
    app, _ = compilable_app

    class InvalidDepState(BaseState):
        # "foolksjdf" does not name any attribute of this state class.
        @computed_var(deps=["foolksjdf"])
        def bar(self) -> str:
            return "bar"

    app._state = InvalidDepState
    with pytest.raises(exceptions.VarDependencyError):
        app._compile()


# Test custom exception handlers
#
# The signatures and annotations of the handlers below are themselves the test
# inputs: the parametrized validation tests further down expect some of them
# to be rejected with ValueError. Do not "fix" them.


def valid_custom_handler(exception: Exception, logger: str = "test"):
    """Handler fixture stored under the "valid" key; prints the exception.

    Args:
        exception: The exception.
        logger: Unused by the body; defaults to "test".
    """
    print("Custom Backend Exception")
    print(exception)


def custom_exception_handler_with_wrong_arg_order(
    logger: str,
    exception: Exception,  # Should be first
):
    """Handler fixture stored under the "wrong_arg_order" key.

    Args:
        logger: Unused by the body.
        exception: The exception.
    """
    print("Custom Backend Exception")
    print(exception)


def custom_exception_handler_with_wrong_argspec(
    exception: str,  # Should be Exception
):
    """Handler fixture stored under the "wrong_argspec" key.

    Args:
        exception: Annotated as str on purpose.
    """
    print("Custom Backend Exception")
    print(exception)


class DummyExceptionHandler:
    """Dummy exception handler class."""

    def handle(self, exception: Exception):
        """Handle the exception.

        Args:
            exception: The exception.
        """
        print("Custom Backend Exception")
        print(exception)


# Handler fixtures keyed by case name; the parametrized tests below look them
# up by key.
custom_exception_handlers = {
    "lambda": lambda exception: print("Custom Exception Handler", exception),
    "wrong_argspec": custom_exception_handler_with_wrong_argspec,
    "wrong_arg_order": custom_exception_handler_with_wrong_arg_order,
    "valid": valid_custom_handler,
    "partial": functools.partial(valid_custom_handler, logger="test"),
    "method": DummyExceptionHandler().handle,
}


@pytest.mark.parametrize(
    "handler_fn, expected",
    [
        pytest.param(
            custom_exception_handlers["partial"],
            pytest.raises(ValueError),
            id="partial",
        ),
        pytest.param(
            custom_exception_handlers["lambda"],
            pytest.raises(ValueError),
            id="lambda",
        ),
        pytest.param(
            custom_exception_handlers["wrong_argspec"],
            pytest.raises(ValueError),
            id="wrong_argspec",
        ),
        pytest.param(
            custom_exception_handlers["wrong_arg_order"],
            pytest.raises(ValueError),
            id="wrong_arg_order",
        ),
        pytest.param(
            custom_exception_handlers["valid"],
            does_not_raise(),
            id="valid_handler",
        ),
        pytest.param(
            custom_exception_handlers["method"],
            does_not_raise(),
            id="valid_class_method",
        ),
    ],
)
def test_frontend_exception_handler_validation(handler_fn, expected):
    """Test that the custom frontend exception handler is properly validated.

    Args:
        handler_fn: The handler function.
        expected: The expected result.
    """
    # The App is built inside the context so an error raised at construction
    # time is captured as well as one raised by the explicit validation call.
    with expected:
        rx.App(frontend_exception_handler=handler_fn)._validate_exception_handlers()


def backend_exception_handler_with_wrong_return_type(exception: Exception) -> int:
    """Custom backend exception handler with wrong return type.

    Args:
        exception: The exception.

    Returns:
        int: The wrong return type.
    """
    print("Custom Backend Exception")
    print(exception)

    return 5


@pytest.mark.parametrize(
    "handler_fn, expected",
    [
        pytest.param(
            backend_exception_handler_with_wrong_return_type,
            pytest.raises(ValueError),
            id="wrong_return_type",
        ),
        pytest.param(
            custom_exception_handlers["partial"],
            pytest.raises(ValueError),
            id="partial",
        ),
        pytest.param(
            custom_exception_handlers["lambda"],
            pytest.raises(ValueError),
            id="lambda",
        ),
        pytest.param(
            custom_exception_handlers["wrong_argspec"],
            pytest.raises(ValueError),
            id="wrong_argspec",
        ),
        pytest.param(
            custom_exception_handlers["wrong_arg_order"],
            pytest.raises(ValueError),
            id="wrong_arg_order",
        ),
        pytest.param(
            custom_exception_handlers["valid"],
            does_not_raise(),
            id="valid_handler",
        ),
        pytest.param(
            custom_exception_handlers["method"],
            does_not_raise(),
            id="valid_class_method",
        ),
    ],
)
def test_backend_exception_handler_validation(handler_fn, expected):
    """Test that the custom backend exception handler is properly validated.

    Args:
        handler_fn: The handler function.
        expected: The expected result.
    """
    # The App is built inside the context so an error raised at construction
    # time is captured as well as one raised by the explicit validation call.
    with expected:
        rx.App(backend_exception_handler=handler_fn)._validate_exception_handlers()