from __future__ import annotations
import functools
import io
import json
import os.path
import re
import unittest.mock
import uuid
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Generator, List, Tuple, Type, cast
from unittest.mock import AsyncMock
import pytest
import sqlmodel
from fastapi import FastAPI, UploadFile
from starlette_admin.auth import AuthProvider
from starlette_admin.contrib.sqla.admin import Admin
from starlette_admin.contrib.sqla.view import ModelView
import reflex as rx
from reflex import AdminDash, constants
from reflex.app import (
App,
ComponentCallable,
OverlayFragment,
default_overlay_component,
process,
upload,
)
from reflex.components import Component
from reflex.components.base.fragment import Fragment
from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event, EventHandler
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import (
BaseState,
OnLoadInternalState,
RouterData,
State,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
StateUpdate,
_substate_key,
)
from reflex.style import Style
from reflex.utils import exceptions, format
from reflex.vars.base import computed_var
from .conftest import chdir
from .states import (
ChildFileUploadState,
FileStateBase1,
FileUploadState,
GenState,
GrandChildFileUploadState,
)
class EmptyState(BaseState):
    """A state class that declares no vars and no event handlers."""
@pytest.fixture
def index_page() -> ComponentCallable:
    """Provide a page callable that renders a box reading "Index".

    The inner function keeps the name ``index`` because, when no explicit
    route is given, the page is registered under the callable's name.

    Returns:
        The index page callable.
    """

    def index():
        rendered = rx.box("Index")
        return rendered

    return index
@pytest.fixture
def about_page() -> ComponentCallable:
    """Provide a page callable that renders a box reading "About".

    The inner function keeps the name ``about`` because, when no explicit
    route is given, the page is registered under the callable's name.

    Returns:
        The about page callable.
    """

    def about():
        rendered = rx.box("About")
        return rendered

    return about
class ATestState(BaseState):
    """A simple state for testing."""

    # Integer state var; the tests in this module expect a fresh instance to read 0.
    var: int
@pytest.fixture()
def test_state() -> Type[BaseState]:
    """Supply the module-level ``ATestState`` class.

    Returns:
        The ``ATestState`` class itself (not an instance).
    """
    state_cls: Type[BaseState] = ATestState
    return state_cls
@pytest.fixture()
def redundant_test_state() -> Type[BaseState]:
    """Define an additional ``BaseState`` subclass when requested.

    Requesting this fixture creates a second state class alongside
    ``ATestState``; ``test_multiple_states_error`` relies on that.

    Returns:
        The freshly defined ``RedundantTestState`` class.
    """

    class RedundantTestState(BaseState):
        var: int

    state_cls = RedundantTestState
    return state_cls
@pytest.fixture(scope="session")
def test_model() -> Type[Model]:
    """Define an empty table model once for the whole test session.

    Returns:
        The ``TestModel`` class.
    """

    class TestModel(Model, table=True):
        """An empty table model."""

    return TestModel
@pytest.fixture(scope="session")
def test_model_auth() -> Type[Model]:
    """Define the table model used by the auth admin-dashboard test, once per session.

    Returns:
        The ``TestModelAuth`` class.
    """

    class TestModelAuth(Model, table=True):
        """A test model with auth."""

    return TestModelAuth
@pytest.fixture()
def test_get_engine():
    """Provide a SQLite engine for the admin-dashboard tests.

    The engine is disposed after the test, so pooled connections are not left
    open between tests.

    Yields:
        A SQLAlchemy engine bound to ``sqlite:///test.db``.
    """
    engine = sqlmodel.create_engine(
        "sqlite:///test.db",
        echo=False,
        # The previous code gated this on a local ``enable_admin = True``
        # constant, so the empty-dict branch could never be taken.
        connect_args={"check_same_thread": False},
    )
    try:
        yield engine
    finally:
        engine.dispose()
@pytest.fixture()
def test_custom_auth_admin() -> Type[AuthProvider]:
    """Build a stub ``AuthProvider`` subclass whose hooks all do nothing.

    Returns:
        The ``TestAuthProvider`` class (not instantiated).
    """

    class TestAuthProvider(AuthProvider):
        """A test auth provider."""

        login_path: str = "/login"
        logout_path: str = "/logout"

        def login(self):  # pyright: ignore [reportIncompatibleMethodOverride]
            """Login."""

        def is_authenticated(self):  # pyright: ignore [reportIncompatibleMethodOverride]
            """Is authenticated."""

        def get_admin_user(self):  # pyright: ignore [reportIncompatibleMethodOverride]
            """Get admin user."""

        def logout(self):  # pyright: ignore [reportIncompatibleMethodOverride]
            """Logout."""

    return TestAuthProvider
def test_default_app(app: App):
    """Check the defaults of an App constructed without arguments.

    Args:
        app: The app to test.
    """
    expected_middleware = [HydrateMiddleware()]
    expected_style = Style()

    assert app.middleware == expected_middleware
    assert app.style == expected_style
    assert app.admin_dash is None
def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
    """Constructing an App must fail when two classes subclass rx.BaseState.

    Args:
        monkeypatch: Pytest monkeypatch object.
        test_state: A test state subclassing rx.BaseState.
        redundant_test_state: Another test state subclassing rx.BaseState.
    """
    # NOTE(review): presumably App relaxes the multiple-states check while
    # running under pytest; removing this env var re-enables it. Confirm in App.
    env_var = constants.PYTEST_CURRENT_TEST
    monkeypatch.delenv(env_var)

    with pytest.raises(ValueError):
        App()
def test_add_page_default_route(app: App, index_page, about_page):
    """Pages added without a route are compiled under their callable's name.

    Args:
        app: The app to test.
        index_page: The index page.
        about_page: The about page.
    """
    assert app._pages == {}
    assert app._unevaluated_pages == {}

    registered = set()
    for page_name, page_fn in (("index", index_page), ("about", about_page)):
        app.add_page(page_fn)
        app._compile_page(page_name)
        registered.add(page_name)
        assert app._pages.keys() == registered
def test_add_page_set_route(app: App, index_page, windows_platform: bool):
    """A page added with an explicit route is compiled under that route.

    Args:
        app: The app to test.
        index_page: The index page.
        windows_platform: Whether the system is windows.
    """
    if windows_platform:
        route = "test"
    else:
        route = "/test"

    assert app._unevaluated_pages == {}
    app.add_page(index_page, route=route)
    app._compile_page("test")
    assert app._pages.keys() == {"test"}
def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
    """A dynamic route segment becomes a router-dependent computed var.

    Args:
        index_page: The index page.
        windows_platform: Whether the system is windows.
    """
    app = App(_state=EmptyState)
    assert app._state is not None
    assert app._unevaluated_pages == {}

    app.add_page(index_page, route="/test/[dynamic]")
    app._compile_page("test/[dynamic]")
    assert app._pages.keys() == {"test/[dynamic]"}

    computed = app._state.computed_vars
    assert "dynamic" in computed
    expected_deps = {EmptyState.get_full_name(): {constants.ROUTER}}
    assert computed["dynamic"]._deps(objclass=EmptyState) == expected_deps
    assert constants.ROUTER in app._state()._var_dependencies
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
    """A nested route is stored with surrounding path separators stripped.

    Args:
        app: The app to test.
        index_page: The index page.
        windows_platform: Whether the system is windows.
    """
    if windows_platform:
        route = "test\\nested"
    else:
        route = "/test/nested"

    assert app._unevaluated_pages == {}
    app.add_page(index_page, route=route)

    expected_key = route.strip(os.path.sep)
    assert app._unevaluated_pages.keys() == {expected_key}
def test_add_page_invalid_api_route(app: App, index_page):
    """Routes under the reserved ``api`` prefix are rejected.

    Args:
        app: The app to test.
        index_page: The index page.
    """
    reserved_routes = ("api", "/api", "/api/", "api/foo", "/api/foo")
    for bad_route in reserved_routes:
        with pytest.raises(ValueError):
            app.add_page(index_page, route=bad_route)

    # Routes that only contain "api" elsewhere are accepted.
    for ok_route in ("api2", "/foo/api"):
        app.add_page(index_page, route=ok_route)
def page1():
    """Render an empty fragment (page used by the duplicate-route tests)."""
    component = rx.fragment()
    return component
def page2():
    """Render an empty fragment (page used by the duplicate-route tests)."""
    component = rx.fragment()
    return component
def index():
    """Render an empty fragment; named ``index`` so it maps to the default route."""
    component = rx.fragment()
    return component
@pytest.mark.parametrize(
    "first_page,second_page, route",
    [
        (lambda: rx.fragment(), lambda: rx.fragment(rx.text("second")), "/"),
        (rx.fragment(rx.text("first")), rx.fragment(rx.text("second")), "/page1"),
        (
            lambda: rx.fragment(rx.text("first")),
            rx.fragment(rx.text("second")),
            "page3",
        ),
        (page1, page2, "page1"),
        (index, index, None),
        (page1, page1, None),
    ],
)
def test_add_duplicate_page_route_error(app, first_page, second_page, route):
    """Registering a second page at an already-used route raises ValueError.

    The second registration uses the leading-slash form of the route, or
    ``None`` (the name-derived default) when the first had no explicit route.

    Args:
        app: The app to test.
        first_page: The page registered first.
        second_page: The page whose registration must fail.
        route: The route used for the first registration.
    """
    app.add_page(first_page, route=route)

    duplicate_route = None if not route else "/" + route.strip("/")
    with pytest.raises(ValueError):
        app.add_page(second_page, route=duplicate_route)
def test_initialize_with_admin_dashboard(test_model):
    """An App built with an AdminDash exposes the configured models.

    Args:
        test_model: The default model.
    """
    dash = AdminDash(models=[test_model])
    app = App(admin_dash=dash)

    configured = app.admin_dash
    assert configured is not None
    assert len(configured.models) > 0
    assert configured.models[0] == test_model
def test_initialize_with_custom_admin_dashboard(
    test_get_engine,
    test_custom_auth_admin,
    test_model_auth,
):
    """An App built with a custom ``Admin`` keeps its auth provider and models.

    Args:
        test_get_engine: The default database engine.
        test_model_auth: The default model for an auth admin dashboard.
        test_custom_auth_admin: The custom auth provider.
    """
    auth_provider = test_custom_auth_admin()
    admin = Admin(engine=test_get_engine, auth_provider=auth_provider)
    app = App(admin_dash=AdminDash(models=[test_model_auth], admin=admin))

    configured = app.admin_dash
    assert configured is not None
    assert configured.admin is not None
    assert len(configured.models) > 0
    assert configured.models[0] == test_model_auth
    assert configured.admin.auth_provider == auth_provider
def test_initialize_admin_dashboard_with_view_overrides(test_model):
    """An AdminDash keeps the view class registered for a model.

    Args:
        test_model: The default model.
    """

    class TestModelView(ModelView):
        """A no-op ModelView subclass used as the override."""

    overrides = {test_model: TestModelView}
    app = App(admin_dash=AdminDash(models=[test_model], view_overrides=overrides))

    configured = app.admin_dash
    assert configured is not None
    assert configured.models == [test_model]
    assert configured.view_overrides[test_model] == TestModelView
@pytest.mark.asyncio
async def test_initialize_with_state(test_state: Type[ATestState], token: str):
    """Test setting the state of an app.

    The Redis state manager (when in use) is closed in ``finally``, so a failing
    assertion does not leak the connection.

    Args:
        test_state: The default state.
        token: a Token.
    """
    app = App(_state=test_state)
    try:
        assert app._state == test_state

        # Get a state for a given token.
        state = await app.state_manager.get_state(_substate_key(token, test_state))
        assert isinstance(state, test_state)
        assert state.var == 0
    finally:
        if isinstance(app.state_manager, StateManagerRedis):
            await app.state_manager.close()
@pytest.mark.asyncio
async def test_set_and_get_state(test_state):
    """Test setting and getting the state of an app with different tokens.

    Each token must map to its own, independently stored state. The Redis
    state manager (when in use) is closed in ``finally``, so a failing
    assertion does not leak the connection.

    Args:
        test_state: The default state.
    """
    app = App(_state=test_state)
    try:
        # Create two tokens.
        token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
        token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"

        # Get the default state for each token.
        state1 = await app.state_manager.get_state(token1)
        state2 = await app.state_manager.get_state(token2)
        assert state1.var == 0
        assert state2.var == 0

        # Set the vars to different values.
        state1.var = 1
        state2.var = 2
        await app.state_manager.set_state(token1, state1)
        await app.state_manager.set_state(token2, state2)

        # Get the states again and check the values.
        state1 = await app.state_manager.get_state(token1)
        state2 = await app.state_manager.get_state(token2)
        assert state1.var == 1
        assert state2.var == 2
    finally:
        if isinstance(app.state_manager, StateManagerRedis):
            await app.state_manager.close()
@pytest.mark.asyncio
async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
    """The auto-generated setter of a dynamically added var updates it.

    Args:
        test_state: State Fixture.
        token: a Token.
    """
    state = test_state()  # pyright: ignore [reportCallIssue]
    state.add_var("int_val", int, 0)

    setter_event = Event(
        token=token,
        name=f"{test_state.get_name()}.set_int_val",
        router_data={"pathname": "/", "query": {}},
        payload={"value": 50},
    )
    async for update in state._process(setter_event):
        assert update.delta == {test_state.get_name(): {"int_val": 50}}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "event_tuples",
    [
        pytest.param(
            [
                (
                    "make_friend",
                    {"plain_friends": ["Tommy", "another-fd"]},
                ),
                (
                    "change_first_friend",
                    {"plain_friends": ["Jenny", "another-fd"]},
                ),
            ],
            id="append then __setitem__",
        ),
        pytest.param(
            [
                (
                    "unfriend_first_friend",
                    {"plain_friends": []},
                ),
                (
                    "make_friend",
                    {"plain_friends": ["another-fd"]},
                ),
            ],
            id="delitem then append",
        ),
        pytest.param(
            [
                (
                    "make_friends_with_colleagues",
                    {"plain_friends": ["Tommy", "Peter", "Jimmy"]},
                ),
                (
                    "remove_tommy",
                    {"plain_friends": ["Peter", "Jimmy"]},
                ),
                (
                    "remove_last_friend",
                    {"plain_friends": ["Peter"]},
                ),
                (
                    "unfriend_all_friends",
                    {"plain_friends": []},
                ),
            ],
            id="extend, remove, pop, clear",
        ),
        pytest.param(
            [
                (
                    "add_jimmy_to_second_group",
                    {"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]},
                ),
                (
                    "remove_first_person_from_first_group",
                    {"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]},
                ),
                (
                    "remove_first_group",
                    {"friends_in_nested_list": [["Jenny", "Jimmy"]]},
                ),
            ],
            id="nested list",
        ),
        pytest.param(
            [
                (
                    "add_jimmy_to_tommy_friends",
                    {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}},
                ),
                (
                    "remove_jenny_from_tommy",
                    {"friends_in_dict": {"Tommy": ["Jimmy"]}},
                ),
                (
                    "tommy_has_no_fds",
                    {"friends_in_dict": {"Tommy": []}},
                ),
            ],
            id="list in dict",
        ),
    ],
)
async def test_list_mutation_detection__plain_list(
    event_tuples: List[Tuple[str, dict]],
    list_mutation_state: State,
    token: str,
):
    """Test list mutation detection
    when reassignment is not explicitly included in the logic.

    Args:
        event_tuples: From parametrization; pairs of (event handler name,
            expected delta for the state's vars).
        list_mutation_state: A state with list mutation features.
        token: a Token.
    """
    state_name = list_mutation_state.get_name()
    for event_name, expected_delta in event_tuples:
        # Prefix the expected delta with the state name exactly once. Rebinding
        # it inside the async loop below would nest it again on every
        # additional update the handler yields.
        expected = {state_name: expected_delta}
        event = Event(
            token=token,
            name=f"{state_name}.{event_name}",
            router_data={"pathname": "/", "query": {}},
            payload={},
        )
        async for result in list_mutation_state._process(event):
            assert result.delta == expected
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "event_tuples",
    [
        pytest.param(
            [
                (
                    "add_age",
                    {"details": {"name": "Tommy", "age": 20}},
                ),
                (
                    "change_name",
                    {"details": {"name": "Jenny", "age": 20}},
                ),
                (
                    "remove_last_detail",
                    {"details": {"name": "Jenny"}},
                ),
            ],
            id="update then __setitem__",
        ),
        pytest.param(
            [
                (
                    "clear_details",
                    {"details": {}},
                ),
                (
                    "add_age",
                    {"details": {"age": 20}},
                ),
            ],
            id="delitem then update",
        ),
        pytest.param(
            [
                (
                    "add_age",
                    {"details": {"name": "Tommy", "age": 20}},
                ),
                (
                    "remove_name",
                    {"details": {"age": 20}},
                ),
                (
                    "pop_out_age",
                    {"details": {}},
                ),
            ],
            id="add, remove, pop",
        ),
        pytest.param(
            [
                (
                    "remove_home_address",
                    {"address": [{}, {"work": "work address"}]},
                ),
                (
                    "add_street_to_home_address",
                    {
                        "address": [
                            {"street": "street address"},
                            {"work": "work address"},
                        ]
                    },
                ),
            ],
            id="dict in list",
        ),
        pytest.param(
            [
                (
                    "change_friend_name",
                    {
                        "friend_in_nested_dict": {
                            "name": "Nikhil",
                            "friend": {"name": "Tommy"},
                        }
                    },
                ),
                (
                    "add_friend_age",
                    {
                        "friend_in_nested_dict": {
                            "name": "Nikhil",
                            "friend": {"name": "Tommy", "age": 30},
                        }
                    },
                ),
                (
                    "remove_friend",
                    {"friend_in_nested_dict": {"name": "Nikhil"}},
                ),
            ],
            id="nested dict",
        ),
    ],
)
async def test_dict_mutation_detection__plain_list(
    event_tuples: List[Tuple[str, dict]],
    dict_mutation_state: State,
    token: str,
):
    """Test dict mutation detection
    when reassignment is not explicitly included in the logic.

    Args:
        event_tuples: From parametrization; pairs of (event handler name,
            expected delta for the state's vars).
        dict_mutation_state: A state with dict mutation features.
        token: a Token.
    """
    state_name = dict_mutation_state.get_name()
    for event_name, expected_delta in event_tuples:
        # Prefix the expected delta with the state name exactly once. Rebinding
        # it inside the async loop below would nest it again on every
        # additional update the handler yields.
        expected = {state_name: expected_delta}
        event = Event(
            token=token,
            name=f"{state_name}.{event_name}",
            router_data={"pathname": "/", "query": {}},
            payload={},
        )
        async for result in dict_mutation_state._process(event):
            assert result.delta == expected
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("state", "delta"),
    [
        (
            FileUploadState,
            {
                FileUploadState.get_full_name(): {
                    "img_list": ["image1.jpg", "image2.jpg"]
                }
            },
        ),
        (
            ChildFileUploadState,
            {
                ChildFileUploadState.get_full_name(): {
                    "img_list": ["image1.jpg", "image2.jpg"]
                }
            },
        ),
        (
            GrandChildFileUploadState,
            {
                GrandChildFileUploadState.get_full_name(): {
                    "img_list": ["image1.jpg", "image2.jpg"]
                }
            },
        ),
    ],
)
async def test_upload_file(tmp_path, state, delta, token: str, mocker):
    """Test that file upload works correctly.

    The Redis state manager (when in use) is closed in ``finally``, so a failing
    assertion does not leak the connection.

    Args:
        tmp_path: Temporary path.
        state: The state class.
        delta: Expected delta
        token: a Token.
        mocker: pytest mocker object.
    """
    mocker.patch(
        "reflex.state.State.class_subclasses",
        {state if state is FileUploadState else FileStateBase1},
    )
    state._tmp_path = tmp_path
    # The App state must be the "root" of the state tree
    app = App()
    app._enable_state()
    try:
        app.event_namespace.emit = AsyncMock()  # pyright: ignore [reportOptionalMemberAccess]
        # Fetch the state once before uploading; the result itself is not used.
        await app.state_manager.get_state(_substate_key(token, state))

        data = b"This is binary data"

        # Create a binary IO object and write data to it
        bio = io.BytesIO()
        bio.write(data)

        request_mock = unittest.mock.Mock()
        request_mock.headers = {
            "reflex-client-token": token,
            "reflex-event-handler": f"{state.get_full_name()}.multi_handle_upload",
        }

        # NOTE(review): both files wrap the same BytesIO, left positioned after
        # the write; only the filenames are asserted below.
        file1 = UploadFile(
            filename="image1.jpg",
            file=bio,
        )
        file2 = UploadFile(
            filename="image2.jpg",
            file=bio,
        )
        upload_fn = upload(app)
        streaming_response = await upload_fn(request_mock, [file1, file2])  # pyright: ignore [reportFunctionMemberAccess]
        async for state_update in streaming_response.body_iterator:
            assert (
                state_update
                == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
            )

        current_state = await app.state_manager.get_state(_substate_key(token, state))
        state_dict = current_state.dict()[state.get_full_name()]
        assert state_dict["img_list"] == [
            "image1.jpg",
            "image2.jpg",
        ]
    finally:
        if isinstance(app.state_manager, StateManagerRedis):
            await app.state_manager.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "state",
    [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
)
async def test_upload_file_without_annotation(state, tmp_path, token):
    """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].

    The Redis state manager (when in use) is closed in ``finally``, so a failing
    assertion does not leak the connection.

    Args:
        state: The state class.
        tmp_path: Temporary path.
        token: a Token.
    """
    state._tmp_path = tmp_path
    app = App(_state=State)
    try:
        request_mock = unittest.mock.Mock()
        request_mock.headers = {
            "reflex-client-token": token,
            "reflex-event-handler": f"{state.get_full_name()}.handle_upload2",
        }
        file_mock = unittest.mock.Mock(filename="image1.jpg")
        fn = upload(app)
        with pytest.raises(ValueError) as err:
            await fn(request_mock, [file_mock])
        assert (
            err.value.args[0]
            == f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
        )
    finally:
        if isinstance(app.state_manager, StateManagerRedis):
            await app.state_manager.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "state",
    [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
)
async def test_upload_file_background(state, tmp_path, token):
    """Test that an error is thrown when the upload handler is a background task.

    The Redis state manager (when in use) is closed in ``finally``, so a failing
    assertion does not leak the connection.

    Args:
        state: The state class.
        tmp_path: Temporary path.
        token: a Token.
    """
    state._tmp_path = tmp_path
    app = App(_state=State)
    try:
        request_mock = unittest.mock.Mock()
        request_mock.headers = {
            "reflex-client-token": token,
            "reflex-event-handler": f"{state.get_full_name()}.bg_upload",
        }
        file_mock = unittest.mock.Mock(filename="image1.jpg")
        fn = upload(app)
        with pytest.raises(TypeError) as err:
            await fn(request_mock, [file_mock])
        assert (
            err.value.args[0]
            == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`."
        )
    finally:
        if isinstance(app.state_manager, StateManagerRedis):
            await app.state_manager.close()
class DynamicState(BaseState):
    """State class for testing dynamic route var.

    This is defined at module level because event handlers cannot be addressed
    correctly when the class is defined as a local.

    Vars:
     * is_hydrated: hydration flag; the route-change tests check it flipping to True
     * loaded: counts how many times `on_load` ran (it is registered as the page's on_load handler)
     * counter: counts how many times `on_counter` was triggered by a non-navigational event
          -> these events should NOT trigger reload or recalculation of router_data dependent vars

    `comp_dynamic` mirrors the `dynamic` route var, which is not declared here;
    it is expected to be added when a page with a `[dynamic]` route segment is
    registered via `App.add_page`.
    """

    is_hydrated: bool = False
    loaded: int = 0
    counter: int = 0

    @rx.event
    def on_load(self):
        """Event handler for page on_load, should trigger for all navigation events."""
        self.loaded = self.loaded + 1

    @rx.event
    def on_counter(self):
        """Increment the counter var."""
        self.counter = self.counter + 1

    @computed_var
    def comp_dynamic(self) -> str:
        """A computed var that depends on the dynamic var.

        Returns:
            same as self.dynamic
        """
        # `dynamic` only exists once a page with a `[dynamic]` route arg has
        # been added; reading it earlier would fail.
        return self.dynamic
# Plain function wrapped by the OnLoadInternalState.on_load_internal event handler.
# NOTE(review): not referenced in the visible part of this module (the dynamic-route
# test shadows this name with a local Event); confirm it is still used before removing.
on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
def test_dynamic_arg_shadow(
    index_page: ComponentCallable,
    windows_platform: bool,
    token: str,
    app_module_mock: unittest.mock.Mock,
    mocker,
):
    """A dynamic route arg that shadows an existing state var is rejected.

    Args:
        index_page: The index page.
        windows_platform: Whether the system is windows.
        token: a Token.
        app_module_mock: Mocked app module.
        mocker: pytest mocker object.
    """
    # `counter` is already a declared var on DynamicState.
    shadowing_route = "/test/[counter]"

    app = App(_state=DynamicState)
    app_module_mock.app = app
    assert app._state is not None

    with pytest.raises(NameError):
        app.add_page(index_page, route=shadowing_route, on_load=DynamicState.on_load)
def test_multiple_dynamic_args(
    index_page: ComponentCallable,
    windows_platform: bool,
    token: str,
    app_module_mock: unittest.mock.Mock,
    mocker,
):
    """Two pages may reuse the same dynamic route arg name without error.

    Args:
        index_page: The index page.
        windows_platform: Whether the system is windows.
        token: a Token.
        app_module_mock: Mocked app module.
        mocker: pytest mocker object.
    """
    app = App(_state=EmptyState)
    app_module_mock.app = app

    for prefix in ("test", "test2"):
        app.add_page(index_page, route=f"/{prefix}/[my_arg]")
@pytest.mark.asyncio
async def test_dynamic_route_var_route_change_completed_on_load(
    index_page: ComponentCallable,
    windows_platform: bool,
    token: str,
    app_module_mock: unittest.mock.Mock,
    mocker,
):
    """Create app with dynamic route var, and simulate navigation.

    on_load should fire, allowing any additional vars to be updated before the
    initial page hydrate.

    For each value in ``exp_vals`` the test pushes four events through
    ``process`` and checks the update each produces:
      1. the internal on_load event (route change), which sets the route var,
         its computed mirror and the router, and schedules on_load and
         set_is_hydrated(True);
      2. the page's ``on_load`` handler (increments ``loaded``);
      3. ``set_is_hydrated(True)``;
      4. a plain ``on_counter`` event, which must not touch route-derived vars.

    Args:
        index_page: The index page.
        windows_platform: Whether the system is windows.
        token: a Token.
        app_module_mock: Mocked app module.
        mocker: pytest mocker object.
    """
    arg_name = "dynamic"
    route = f"/test/[{arg_name}]"
    app = app_module_mock.app = App(_state=DynamicState)
    assert app._state is not None
    # The route var does not exist until a page declaring it is added.
    assert arg_name not in app._state.vars
    app.add_page(index_page, route=route, on_load=DynamicState.on_load)
    # Adding the page injects `dynamic` as a computed var depending on the router.
    assert arg_name in app._state.vars
    assert arg_name in app._state.computed_vars
    assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
        DynamicState.get_full_name(): {constants.ROUTER},
    }
    assert constants.ROUTER in app._state()._var_dependencies

    substate_token = _substate_key(token, DynamicState)
    sid = "mock_sid"
    client_ip = "127.0.0.1"
    async with app.state_manager.modify_state(substate_token) as state:
        state.router_data = {"simulate": "hydrated"}
        assert state.dynamic == ""
    exp_vals = ["foo", "foobar", "baz"]

    # Build an Event for `name`, defaulting token/router_data/payload so the
    # route query carries `val` for the dynamic arg.
    def _event(name, val, **kwargs):
        return Event(
            token=kwargs.pop("token", token),
            name=name,
            router_data=kwargs.pop(
                "router_data", {"pathname": route, "query": {arg_name: val}}
            ),
            payload=kwargs.pop("payload", {}),
            **kwargs,
        )

    # Same as _event, but addresses a DynamicState event handler by attribute name.
    def _dynamic_state_event(name, val, **kwargs):
        return _event(
            name=format.format_event_handler(getattr(DynamicState, name)),
            val=val,
            **kwargs,
        )

    prev_exp_val = ""
    for exp_index, exp_val in enumerate(exp_vals):
        on_load_internal = _event(
            name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}",
            val=exp_val,
        )
        exp_router_data = {
            "headers": {},
            "ip": client_ip,
            "sid": sid,
            "token": token,
            **on_load_internal.router_data,
        }
        exp_router = RouterData(exp_router_data)
        process_coro = process(
            app,
            event=on_load_internal,
            sid=sid,
            headers={},
            client_ip=client_ip,
        )
        update = await process_coro.__anext__()
        # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
        assert update == StateUpdate(
            delta={
                state.get_name(): {
                    arg_name: exp_val,
                    f"comp_{arg_name}": exp_val,
                    constants.CompileVars.IS_HYDRATED: False,
                    "router": exp_router,
                }
            },
            events=[
                _dynamic_state_event(
                    name="on_load",
                    val=exp_val,
                ),
                _event(
                    name=f"{State.get_name()}.set_is_hydrated",
                    payload={"value": True},
                    val=exp_val,
                    router_data={},
                ),
            ],
        )
        if isinstance(app.state_manager, StateManagerRedis):
            # When redis is used, the state is not updated until the processing is complete
            state = await app.state_manager.get_state(substate_token)
            assert state.dynamic == prev_exp_val

        # complete the processing
        with pytest.raises(StopAsyncIteration):
            await process_coro.__anext__()

        # check that router data was written to the state_manager store
        state = await app.state_manager.get_state(substate_token)
        assert state.dynamic == exp_val

        # The scheduled on_load handler increments `loaded` once per navigation.
        process_coro = process(
            app,
            event=_dynamic_state_event(name="on_load", val=exp_val),
            sid=sid,
            headers={},
            client_ip=client_ip,
        )
        on_load_update = await process_coro.__anext__()
        assert on_load_update == StateUpdate(
            delta={
                state.get_name(): {
                    "loaded": exp_index + 1,
                },
            },
            events=[],
        )
        # complete the processing
        with pytest.raises(StopAsyncIteration):
            await process_coro.__anext__()

        # The scheduled set_is_hydrated(True) event.
        process_coro = process(
            app,
            event=_dynamic_state_event(
                name="set_is_hydrated", payload={"value": True}, val=exp_val
            ),
            sid=sid,
            headers={},
            client_ip=client_ip,
        )
        on_set_is_hydrated_update = await process_coro.__anext__()
        assert on_set_is_hydrated_update == StateUpdate(
            delta={
                state.get_name(): {
                    "is_hydrated": True,
                },
            },
            events=[],
        )
        # complete the processing
        with pytest.raises(StopAsyncIteration):
            await process_coro.__anext__()

        # a simple state update event should NOT trigger on_load or route var side effects
        process_coro = process(
            app,
            event=_dynamic_state_event(name="on_counter", val=exp_val),
            sid=sid,
            headers={},
            client_ip=client_ip,
        )
        update = await process_coro.__anext__()
        assert update == StateUpdate(
            delta={
                state.get_name(): {
                    "counter": exp_index + 1,
                }
            },
            events=[],
        )
        # complete the processing
        with pytest.raises(StopAsyncIteration):
            await process_coro.__anext__()

        prev_exp_val = exp_val
    # One on_load and one on_counter per navigated value.
    state = await app.state_manager.get_state(substate_token)
    assert state.loaded == len(exp_vals)
    assert state.counter == len(exp_vals)

    # NOTE(review): this cleanup is skipped if any assertion above fails;
    # consider moving it into a try/finally.
    if isinstance(app.state_manager, StateManagerRedis):
        await app.state_manager.close()
@pytest.mark.asyncio
async def test_process_events(mocker, token: str):
    """Test that an event is processed properly and that it is postprocessed
    n+1 times. Also check that the processing flag of the last stateupdate is set to
    False.

    The Redis state manager (when in use) is closed in ``finally``, so a failing
    assertion does not leak the connection.

    Args:
        mocker: mocker object.
        token: a Token.
    """
    router_data = {
        "pathname": "/",
        "query": {},
        "token": token,
        "sid": "mock_sid",
        "headers": {},
        "ip": "127.0.0.1",
    }
    app = App(_state=GenState)
    try:
        mocker.patch.object(app, "_postprocess", AsyncMock())
        event = Event(
            token=token,
            name=f"{GenState.get_name()}.go",
            payload={"c": 5},
            router_data=router_data,
        )
        async with app.state_manager.modify_state(event.substate_token) as state:
            state.router_data = {"simulate": "hydrated"}

        async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
            pass

        assert (await app.state_manager.get_state(event.substate_token)).value == 5
        # n + 1 postprocess calls for payload c=5, per the docstring above.
        assert getattr(app._postprocess, "call_count", None) == 6
    finally:
        if isinstance(app.state_manager, StateManagerRedis):
            await app.state_manager.close()
@pytest.mark.parametrize(
    ("state", "overlay_component", "exp_page_child"),
    [
        (None, default_overlay_component, None),
        (None, None, None),
        (None, Text.create("foo"), Text),
        (State, default_overlay_component, Fragment),
        (State, None, None),
        (State, Text.create("foo"), Text),
        (State, lambda: Text.create("foo"), Text),
    ],
)
def test_overlay_component(
    state: Type[State] | None,
    overlay_component: Component | ComponentCallable | None,
    exp_page_child: Type[Component] | None,
):
    """Test that the overlay component is set correctly.

    Args:
        state: The state class to pass to App.
        overlay_component: The overlay_component to pass to App.
        exp_page_child: The type of the expected child in the page fragment.
    """
    app = App(_state=state, overlay_component=overlay_component)
    app._setup_overlay_component()
    if exp_page_child is None:
        assert app.overlay_component is None
    elif issubclass(exp_page_child, OverlayFragment):
        # exp_page_child is a class, so it must be checked with issubclass;
        # isinstance(<class>, OverlayFragment) is always False and left this
        # branch unreachable.
        assert app.overlay_component is not None
        generated_component = app._generate_component(app.overlay_component)
        assert isinstance(generated_component, OverlayFragment)
    else:
        assert app.overlay_component is not None
        assert isinstance(
            app._generate_component(app.overlay_component),
            exp_page_child,
        )

    app.add_page(rx.box("Index"), route="/test")
    # overlay components are wrapped during compile only
    app._compile_page("test")
    app._setup_overlay_component()
    page = app._pages["test"]

    if exp_page_child is not None:
        assert len(page.children) == 3
        children_types = [type(child) for child in page.children]
        assert exp_page_child in children_types
    else:
        assert len(page.children) == 2
@pytest.fixture
def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
    """Fixture for an app that can be compiled.

    Lays out ``<tmp>/app/.web`` with an empty package.json and runs the test
    with the working directory set to ``<tmp>/app``, so ``app._compile()``
    writes into that ``.web`` directory.

    Args:
        tmp_path: Temporary path.

    Yields:
        Tuple containing (app instance, Path to ".web" directory)
    """
    app_dir = tmp_path / "app"
    web_dir = app_dir / ".web"
    web_dir.mkdir(parents=True)
    package_json = web_dir / constants.PackageJson.PATH
    package_json.touch()

    app = App(theme=None)
    # NOTE(review): presumably mocked so compiling does not resolve/install
    # frontend packages; confirm against App._get_frontend_packages.
    app._get_frontend_packages = unittest.mock.Mock()

    with chdir(app_dir):
        yield app, web_dir
@pytest.mark.parametrize(
    "react_strict_mode",
    [True, False],
)
def test_app_wrap_compile_theme(
    react_strict_mode: bool, compilable_app: tuple[App, Path], mocker
):
    """Test that the radix theme component wraps the app.

    Args:
        react_strict_mode: Whether to use React Strict Mode.
        compilable_app: compilable_app fixture.
        mocker: pytest mocker object.
    """
    conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode)
    mocker.patch("reflex.config._get_config", return_value=conf)
    app, web_dir = compilable_app
    app.theme = rx.theme(accent_color="plum")
    app._compile()
    app_js_contents = (web_dir / "pages" / "_app.js").read_text()
    app_js_lines = [
        line.strip() for line in app_js_contents.splitlines() if line.strip()
    ]
    lines = "".join(app_js_lines)

    strict_open = "<StrictMode>" if react_strict_mode else ""
    strict_close = "</StrictMode>" if react_strict_mode else ""
    # The expected JSX tags had been reduced to empty strings, so the
    # assertion could never match the compiled output.
    # NOTE(review): the tags were reconstructed; verify the exact
    # RadixThemesTheme attributes against the compiled _app.js for the reflex
    # version under test.
    assert (
        "function AppWrap({children}) {"
        "return ("
        + strict_open
        + "<RadixThemesColorModeProvider>"
        "<RadixThemesTheme accentColor={\"plum\"} css={{...theme.styles.global[':root'], ...theme.styles.global.body}}>"
        "<Fragment>"
        "<MemoizedToastProvider/>"
        "<Fragment>"
        "{children}"
        "</Fragment>"
        "</Fragment>"
        "</RadixThemesTheme>"
        "</RadixThemesColorModeProvider>"
        + strict_close
        + ")"
        "}"
    ) in lines
@pytest.mark.parametrize(
    "react_strict_mode",
    [True, False],
)
def test_app_wrap_priority(
    react_strict_mode: bool, compilable_app: tuple[App, Path], mocker
):
    """Test that the app wrap components are wrapped in the correct order.

    Higher priority wrap components end up outermost: Box (99) wraps
    Text (50), which wraps the remaining wrappers including Fragment2 (10).

    Args:
        react_strict_mode: Whether to use React Strict Mode.
        compilable_app: compilable_app fixture.
        mocker: pytest mocker object.
    """
    conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode)
    mocker.patch("reflex.config._get_config", return_value=conf)
    app, web_dir = compilable_app

    class Fragment1(Component):
        tag = "Fragment1"

        def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:  # pyright: ignore [reportIncompatibleMethodOverride]
            return {(99, "Box"): rx.box()}

    class Fragment2(Component):
        tag = "Fragment2"

        def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:  # pyright: ignore [reportIncompatibleMethodOverride]
            return {(50, "Text"): rx.text()}

    class Fragment3(Component):
        tag = "Fragment3"

        def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:  # pyright: ignore [reportIncompatibleMethodOverride]
            return {(10, "Fragment2"): Fragment2.create()}

    def page():
        return Fragment1.create(Fragment3.create())

    app.add_page(page)
    app._compile()
    app_js_contents = (web_dir / "pages" / "_app.js").read_text()
    app_js_lines = [
        line.strip() for line in app_js_contents.splitlines() if line.strip()
    ]
    lines = "".join(app_js_lines)

    strict_open = "<StrictMode>" if react_strict_mode else ""
    strict_close = "</StrictMode>" if react_strict_mode else ""
    # The expected JSX tags had been reduced to empty strings, so the
    # assertion could never match the compiled output.
    # NOTE(review): the tags were reconstructed; verify against the compiled
    # _app.js for the reflex version under test.
    assert (
        "function AppWrap({children}) {"
        "return (" + strict_open + "<RadixThemesBox>"
        '<RadixThemesText as={"p"}>'
        "<RadixThemesColorModeProvider>"
        "<Fragment2>"
        "<Fragment>"
        "<MemoizedToastProvider/>"
        "<Fragment>"
        "{children}"
        "</Fragment>"
        "</Fragment>"
        "</Fragment2>"
        "</RadixThemesColorModeProvider>"
        "</RadixThemesText>"
        "</RadixThemesBox>" + strict_close
    ) in lines
def test_app_state_determination():
    """Check when an App switches from stateless to stateful.

    A fresh app has no state. Stateless pages keep it that way. An ``on_load``
    event, a state Var, a router reference, or a state event handler on a
    compiled page enables state.
    """

    def add_and_compile(app: App, component, route: str, page_name: str, **page_kwargs):
        """Add ``component`` at ``route``, compile ``page_name``, return ``app._state``."""
        app.add_page(component, route=route, **page_kwargs)
        app._compile_page(page_name)
        return app._state

    # No state, no router, no event handlers.
    on_load_app = App()
    assert on_load_app._state is None
    on_load_app.add_page(rx.box("Index"), route="/")
    assert on_load_app._state is None
    # Adding a page with `on_load` enables state.
    assert (
        add_and_compile(
            on_load_app, rx.box("About"), "/about", "about", on_load=rx.console_log("")
        )
        is not None
    )

    # Referencing a state Var enables state.
    var_app = App()
    assert var_app._state is None
    assert (
        add_and_compile(var_app, rx.box(rx.text(GenState.value)), "/", "index")
        is not None
    )

    # Referencing the router enables state.
    router_app = App()
    assert router_app._state is None
    assert (
        add_and_compile(
            router_app, rx.box(rx.text(State.router.page.full_path)), "/", "index"
        )
        is not None
    )

    # A console_log click handler alone keeps the app stateless; a state
    # event handler enables state.
    handler_app = App()
    assert handler_app._state is None
    handler_app.add_page(
        rx.box(rx.button("Click", on_click=rx.console_log(""))), route="/"
    )
    assert handler_app._state is None
    assert (
        add_and_compile(
            handler_app,
            rx.box(rx.button("Click", on_click=DynamicState.on_counter)),
            "/page2",
            "page2",
        )
        is not None
    )
def test_raise_on_state():
    """Check that the deprecated ``_state`` kwarg still yields a State-backed app."""
    # Passing `_state` is deprecated; the app must still be constructed with it.
    app = App(_state=State)
    state_cls = app._state
    assert state_cls is not None
    assert issubclass(state_cls, State)
def test_call_app():
    """Calling an App instance must produce a FastAPI application."""
    assert isinstance(App()(), FastAPI)
def test_app_with_optional_endpoints():
    """Test that optional endpoints can be added when the upload component is used.

    ``Upload.is_used`` is a class-level flag. It is patched only for the
    duration of this test so it does not leak into other tests that build
    an App (the original assignment was never undone).
    """
    from reflex.components.core.upload import Upload

    app = App()
    with unittest.mock.patch.object(Upload, "is_used", True):
        app._add_optional_endpoints()
    # TODO: verify the availability of the endpoints in app.api
def test_app_state_manager():
    """The state manager is only available once state has been enabled."""
    app = App()
    # Before `_enable_state`, accessing the manager must raise.
    with pytest.raises(ValueError):
        _ = app.state_manager

    app._enable_state()
    manager = app.state_manager
    assert manager is not None
    assert isinstance(
        manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis)
    )
def test_generate_component():
    """``App._generate_component`` builds valid pages and rejects bad matches."""

    def index():
        return rx.box("Index")

    def index_mismatch():
        # The match cases return a component and plain strings; the test
        # expects this to be rejected with MatchTypeError.
        return rx.match(1, (1, rx.box("Index")), (2, "About"), "Bar")

    assert isinstance(App._generate_component(index), Component)

    with pytest.raises(exceptions.MatchTypeError):
        App._generate_component(index_mismatch)  # pyright: ignore [reportArgumentType]
def test_add_page_component_returning_tuple():
    """Test that a component or render method returning a
    tuple is unpacked in a Fragment.

    ``page2`` also covers a single-element tuple (trailing comma).
    """
    app = App()

    def index():
        return rx.text("first"), rx.text("second")

    def page2():
        return (rx.text("third"),)

    for page_fn in (index, page2):
        app.add_page(page_fn)  # pyright: ignore [reportArgumentType]
    for page_name in ("index", "page2"):
        app._compile_page(page_name)

    # Each page's first child must be a Fragment whose children are the
    # returned Text components, in order.
    expected_texts = {
        "index": ['"first"', '"second"'],
        "page2": ['"third"'],
    }
    for page_name, texts in expected_texts.items():
        wrapper = app._pages[page_name].children[0]
        assert isinstance(wrapper, Fragment)
        for position, text in enumerate(texts):
            child = wrapper.children[position]
            assert isinstance(child, Text)
            assert str(child.children[0].contents) == text  # pyright: ignore [reportAttributeAccessIssue]
@pytest.mark.parametrize("export", (True, False))
def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: bool):
    """Test that component transpile packages are written to next.config.js.

    Also checks that the static-export settings appear only when compiling
    with ``export=True``.

    Args:
        compilable_app: compilable_app fixture.
        export: Whether the app is compiled for static export.
    """

    class C1(rx.Component):
        library = "foo@1.2.3"
        tag = "Foo"
        transpile_packages: List[str] = ["foo"]

    class C2(rx.Component):
        library = "bar@4.5.6"
        tag = "Bar"
        transpile_packages: List[str] = ["bar@4.5.6"]

    class C3(rx.NoSSRComponent):
        library = "baz@7.8.10"
        tag = "Baz"
        transpile_packages: List[str] = ["baz@7.8.9"]

    class C4(rx.NoSSRComponent):
        library = "quuc@2.3.4"
        tag = "Quuc"
        transpile_packages: List[str] = ["quuc"]

    class C5(rx.Component):
        library = "quuc"
        tag = "Quuc"

    app, web_dir = compilable_app
    page = Fragment.create(
        C1.create(), C2.create(), C3.create(), C4.create(), C5.create()
    )
    app.add_page(page, route="/")
    app._compile(export=export)

    next_config = (web_dir / "next.config.js").read_text()
    transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", next_config)
    # Fail with the generated config rather than an AttributeError on None.
    assert transpile_packages_match is not None, next_config
    transpile_packages = sorted(json.loads(transpile_packages_match.group(1)))
    # Version suffixes are expected to be stripped ("bar@4.5.6" -> "bar").
    # NOTE(review): "baz" (whose transpile version differs from its library
    # version) is expected to be absent; confirm the intended rule.
    assert transpile_packages == [
        "bar",
        "foo",
        "quuc",
    ]

    if export:
        assert 'output: "export"' in next_config
        assert f'distDir: "{constants.Dirs.STATIC}"' in next_config
    else:
        assert 'output: "export"' not in next_config
        assert f'distDir: "{constants.Dirs.STATIC}"' not in next_config
def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
    """Test that explicit computed-var dependencies in several forms compile.

    Dependencies are declared as var-name strings (including an
    underscore-prefixed var), as a computed var object from the same class,
    and as computed vars of a parent or sibling state. Compiling the app
    must not raise.

    Args:
        compilable_app: compilable_app fixture.
    """
    app, _ = compilable_app

    class ValidDepState(BaseState):
        base: int = 0
        _backend: int = 0

        @computed_var()
        def foo(self) -> str:
            return "foo"

        # Mixes string names with the `foo` computed var object defined above.
        @computed_var(deps=["_backend", "base", foo])
        def bar(self) -> str:
            return "bar"

    class Child1(ValidDepState):
        # Depends on a computed var declared on the parent state.
        @computed_var(deps=["base", ValidDepState.bar])
        def other(self) -> str:
            return "other"

    class Child2(ValidDepState):
        # Depends on a computed var of a sibling substate.
        @computed_var(deps=["base", Child1.other])
        def other(self) -> str:
            return "other"

    app._state = ValidDepState
    # The test passes if compilation completes without raising.
    app._compile()
def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
    """Compiling must fail when a computed var depends on an unknown var name."""
    app = compilable_app[0]

    class InvalidDepState(BaseState):
        # "foolksjdf" is not a var on this state.
        @computed_var(deps=["foolksjdf"])
        def bar(self) -> str:
            return "bar"

    app._state = InvalidDepState
    with pytest.raises(exceptions.VarDependencyError):
        app._compile()
# Custom exception handlers used by the exception-handler validation tests.
def valid_custom_handler(exception: Exception, logger: str = "test"):
    """Well-formed handler: print a banner line, then the exception itself."""
    for line in ("Custom Backend Exception", exception):
        print(line)
def custom_exception_handler_with_wrong_arg_order(
    logger: str,
    exception: Exception,  # Should be first
):
    """Invalid handler on purpose: ``exception`` is not the first parameter."""
    print("Custom Backend Exception", exception, sep="\n")
def custom_exception_handler_with_wrong_argspec(
    exception: str,  # Should be Exception
):
    """Invalid handler on purpose: its parameter is annotated ``str``."""
    print("\n".join(("Custom Backend Exception", str(exception))))
class DummyExceptionHandler:
    """Dummy exception handler class whose bound method serves as a handler."""

    def handle(self, exception: Exception):
        """Print a banner line followed by the exception.

        Args:
            exception: The exception.
        """
        banner = "Custom Backend Exception"
        print(banner)
        print(exception)
# Candidate handlers, looked up by key in the parametrized validation tests.
custom_exception_handlers = {
    # Lambdas are expected to be rejected by validation (ValueError).
    "lambda": lambda exception: print("Custom Exception Handler", exception),
    # Parameter annotated `str` instead of `Exception`: expected to be rejected.
    "wrong_argspec": custom_exception_handler_with_wrong_argspec,
    # `exception` is not the first parameter: expected to be rejected.
    "wrong_arg_order": custom_exception_handler_with_wrong_arg_order,
    # Well-formed plain function: expected to be accepted.
    "valid": valid_custom_handler,
    # A partial of an otherwise valid handler is still expected to be rejected.
    "partial": functools.partial(valid_custom_handler, logger="test"),
    # Bound method with a valid signature: expected to be accepted.
    "method": DummyExceptionHandler().handle,
}
@pytest.mark.parametrize(
    "handler_fn, expected",
    [
        pytest.param(
            custom_exception_handlers["partial"],
            pytest.raises(ValueError),
            id="partial",
        ),
        pytest.param(
            custom_exception_handlers["lambda"],
            pytest.raises(ValueError),
            id="lambda",
        ),
        pytest.param(
            custom_exception_handlers["wrong_argspec"],
            pytest.raises(ValueError),
            id="wrong_argspec",
        ),
        pytest.param(
            custom_exception_handlers["wrong_arg_order"],
            pytest.raises(ValueError),
            id="wrong_arg_order",
        ),
        pytest.param(
            custom_exception_handlers["valid"],
            does_not_raise(),
            id="valid_handler",
        ),
        pytest.param(
            custom_exception_handlers["method"],
            does_not_raise(),
            id="valid_class_method",
        ),
    ],
)
def test_frontend_exception_handler_validation(handler_fn, expected):
    """Test that the custom frontend exception handler is properly validated.

    Partials, lambdas, and handlers with a wrong signature must raise
    ``ValueError``; a plain function and a bound method must be accepted.

    Args:
        handler_fn: The handler function.
        expected: Context manager asserting the outcome (``pytest.raises`` or
            ``does_not_raise``).
    """
    with expected:
        rx.App(frontend_exception_handler=handler_fn)._validate_exception_handlers()
def backend_exception_handler_with_wrong_return_type(exception: Exception) -> int:
    """Backend handler that is invalid on purpose: it returns an ``int``.

    Args:
        exception: The exception.

    Returns:
        int: The wrong return type (always 5).
    """
    print("Custom Backend Exception", exception, sep="\n")
    wrong_value = 5
    return wrong_value
@pytest.mark.parametrize(
    "handler_fn, expected",
    [
        pytest.param(
            backend_exception_handler_with_wrong_return_type,
            pytest.raises(ValueError),
            id="wrong_return_type",
        ),
        pytest.param(
            custom_exception_handlers["partial"],
            pytest.raises(ValueError),
            id="partial",
        ),
        pytest.param(
            custom_exception_handlers["lambda"],
            pytest.raises(ValueError),
            id="lambda",
        ),
        pytest.param(
            custom_exception_handlers["wrong_argspec"],
            pytest.raises(ValueError),
            id="wrong_argspec",
        ),
        pytest.param(
            custom_exception_handlers["wrong_arg_order"],
            pytest.raises(ValueError),
            id="wrong_arg_order",
        ),
        pytest.param(
            custom_exception_handlers["valid"],
            does_not_raise(),
            id="valid_handler",
        ),
        pytest.param(
            custom_exception_handlers["method"],
            does_not_raise(),
            id="valid_class_method",
        ),
    ],
)
def test_backend_exception_handler_validation(handler_fn, expected):
    """Test that the custom backend exception handler is properly validated.

    Same cases as the frontend test, plus a handler annotated to return
    ``int``, which must also be rejected with ``ValueError``.

    Args:
        handler_fn: The handler function.
        expected: Context manager asserting the outcome (``pytest.raises`` or
            ``does_not_raise``).
    """
    with expected:
        rx.App(backend_exception_handler=handler_fn)._validate_exception_handlers()