|
@@ -10,7 +10,17 @@ import os
|
|
|
import sys
|
|
|
import threading
|
|
|
from textwrap import dedent
|
|
|
-from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
|
|
|
+from typing import (
|
|
|
+ Any,
|
|
|
+ AsyncGenerator,
|
|
|
+ Callable,
|
|
|
+ Dict,
|
|
|
+ List,
|
|
|
+ Optional,
|
|
|
+ Set,
|
|
|
+ Tuple,
|
|
|
+ Union,
|
|
|
+)
|
|
|
from unittest.mock import AsyncMock, Mock
|
|
|
|
|
|
import pytest
|
|
@@ -1829,12 +1839,11 @@ async def test_state_manager_lock_expire_contend(
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="function")
|
|
|
-def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
|
|
- """Mock app fixture.
|
|
|
+def mock_app_simple(monkeypatch) -> rx.App:
|
|
|
+ """Simple Mock app fixture.
|
|
|
|
|
|
Args:
|
|
|
monkeypatch: Pytest monkeypatch object.
|
|
|
- state_manager: A state manager.
|
|
|
|
|
|
Returns:
|
|
|
The app, after mocking out prerequisites.get_app()
|
|
@@ -1845,7 +1854,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
|
|
|
|
|
setattr(app_module, CompileVars.APP, app)
|
|
|
app.state = TestState
|
|
|
- app._state_manager = state_manager
|
|
|
app.event_namespace.emit = AsyncMock() # type: ignore
|
|
|
|
|
|
def _mock_get_app(*args, **kwargs):
|
|
@@ -1855,6 +1863,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
|
|
return app
|
|
|
|
|
|
|
|
|
+@pytest.fixture(scope="function")
|
|
|
+def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
|
|
|
+ """Mock app fixture.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ mock_app_simple: A simple mock app.
|
|
|
+ state_manager: A state manager.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The app, after mocking out prerequisites.get_app()
|
|
|
+ """
|
|
|
+ mock_app_simple._state_manager = state_manager
|
|
|
+ return mock_app_simple
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
|
|
"""Test that the state proxy works.
|
|
@@ -3507,3 +3530,106 @@ def test_init_mixin() -> None:
|
|
|
|
|
|
with pytest.raises(ReflexRuntimeError):
|
|
|
SubMixin()
|
|
|
+
|
|
|
+
|
|
|
+class ReflexModel(rx.Model):
|
|
|
+ """A model for testing."""
|
|
|
+
|
|
|
+ foo: str
|
|
|
+
|
|
|
+
|
|
|
+class UpcastState(rx.State):
|
|
|
+ """A state for testing upcasting."""
|
|
|
+
|
|
|
+ passed: bool = False
|
|
|
+
|
|
|
+ def rx_model(self, m: ReflexModel): # noqa: D102
|
|
|
+ assert isinstance(m, ReflexModel)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def rx_base(self, o: Object): # noqa: D102
|
|
|
+ assert isinstance(o, Object)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def rx_base_or_none(self, o: Optional[Object]): # noqa: D102
|
|
|
+ if o is not None:
|
|
|
+ assert isinstance(o, Object)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def rx_basemodelv1(self, m: ModelV1): # noqa: D102
|
|
|
+ assert isinstance(m, ModelV1)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def rx_basemodelv2(self, m: ModelV2): # noqa: D102
|
|
|
+ assert isinstance(m, ModelV2)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def rx_dataclass(self, dc: ModelDC): # noqa: D102
|
|
|
+ assert isinstance(dc, ModelDC)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_set(self, s: set): # noqa: D102
|
|
|
+ assert isinstance(s, set)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_Set(self, s: Set): # noqa: D102
|
|
|
+ assert isinstance(s, Set)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_tuple(self, t: tuple): # noqa: D102
|
|
|
+ assert isinstance(t, tuple)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_Tuple(self, t: Tuple): # noqa: D102
|
|
|
+ assert isinstance(t, tuple)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_dict(self, d: dict[str, str]): # noqa: D102
|
|
|
+ assert isinstance(d, dict)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_list(self, ls: list[str]): # noqa: D102
|
|
|
+ assert isinstance(ls, list)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_Any(self, a: Any): # noqa: D102
|
|
|
+ assert isinstance(a, list)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+ def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
|
|
|
+ assert isinstance(u, list)
|
|
|
+ self.passed = True
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+@pytest.mark.usefixtures("mock_app_simple")
|
|
|
+@pytest.mark.parametrize(
|
|
|
+ ("handler", "payload"),
|
|
|
+ [
|
|
|
+ (UpcastState.rx_model, {"m": {"foo": "bar"}}),
|
|
|
+ (UpcastState.rx_base, {"o": {"foo": "bar"}}),
|
|
|
+ (UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
|
|
|
+ (UpcastState.rx_base_or_none, {"o": None}),
|
|
|
+ (UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
|
|
|
+ (UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
|
|
|
+ (UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
|
|
|
+ (UpcastState.py_set, {"s": ["foo", "foo"]}),
|
|
|
+ (UpcastState.py_Set, {"s": ["foo", "foo"]}),
|
|
|
+ (UpcastState.py_tuple, {"t": ["foo", "foo"]}),
|
|
|
+ (UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
|
|
|
+ (UpcastState.py_dict, {"d": {"foo": "bar"}}),
|
|
|
+ (UpcastState.py_list, {"ls": ["foo", "foo"]}),
|
|
|
+ (UpcastState.py_Any, {"a": ["foo"]}),
|
|
|
+ (UpcastState.py_unresolvable, {"u": ["foo"]}),
|
|
|
+ ],
|
|
|
+)
|
|
|
+async def test_upcast_event_handler_arg(handler, payload):
|
|
|
+ """Test that upcast event handler args work correctly.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ handler: The handler to test.
|
|
|
+ payload: The payload to test.
|
|
|
+ """
|
|
|
+ state = UpcastState()
|
|
|
+ async for update in state._process_event(handler, state, payload):
|
|
|
+ assert update.delta == {UpcastState.get_full_name(): {"passed": True}}
|