Browse Source

Merge remote-tracking branch 'upstream/main' into minify-state-names-v2

Benedikt Bartscher 5 months ago
parent
commit
5c9839e0fe
4 changed files with 156 additions and 9 deletions
  1. 4 2
      reflex/app.py
  2. 6 2
      reflex/state.py
  3. 15 0
      reflex/utils/console.py
  4. 131 5
      tests/units/test_state.py

+ 4 - 2
reflex/app.py

@@ -1462,10 +1462,10 @@ class EventNamespace(AsyncNamespace):
     app: App
     app: App
 
 
     # Keep a mapping between socket ID and client token.
     # Keep a mapping between socket ID and client token.
-    token_to_sid: dict[str, str] = {}
+    token_to_sid: dict[str, str]
 
 
     # Keep a mapping between client token and socket ID.
     # Keep a mapping between client token and socket ID.
-    sid_to_token: dict[str, str] = {}
+    sid_to_token: dict[str, str]
 
 
     def __init__(self, namespace: str, app: App):
     def __init__(self, namespace: str, app: App):
         """Initialize the event namespace.
         """Initialize the event namespace.
@@ -1475,6 +1475,8 @@ class EventNamespace(AsyncNamespace):
             app: The application object.
             app: The application object.
         """
         """
         super().__init__(namespace)
         super().__init__(namespace)
+        self.token_to_sid = {}
+        self.sid_to_token = {}
         self.app = app
         self.app = app
 
 
     def on_connect(self, sid, environ):
     def on_connect(self, sid, environ):

+ 6 - 2
reflex/state.py

@@ -1806,7 +1806,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 if value is None:
                 if value is None:
                     continue
                     continue
                 hinted_args = value_inside_optional(hinted_args)
                 hinted_args = value_inside_optional(hinted_args)
-            if isinstance(value, dict) and inspect.isclass(hinted_args):
+            if (
+                isinstance(value, dict)
+                and inspect.isclass(hinted_args)
+                and not types.is_generic_alias(hinted_args)  # py3.9-py3.10
+            ):
                 if issubclass(hinted_args, Model):
                 if issubclass(hinted_args, Model):
                     # Remove non-fields from the payload
                     # Remove non-fields from the payload
                     payload[arg] = hinted_args(
                     payload[arg] = hinted_args(
@@ -1817,7 +1821,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                         }
                         }
                     )
                     )
                 elif dataclasses.is_dataclass(hinted_args) or issubclass(
                 elif dataclasses.is_dataclass(hinted_args) or issubclass(
-                    hinted_args, Base
+                    hinted_args, (Base, BaseModelV1, BaseModelV2)
                 ):
                 ):
                     payload[arg] = hinted_args(**value)
                     payload[arg] = hinted_args(**value)
             if isinstance(value, list) and (hinted_args is set or hinted_args is Set):
             if isinstance(value, list) and (hinted_args is set or hinted_args is Set):

+ 15 - 0
reflex/utils/console.py

@@ -26,7 +26,22 @@ def set_log_level(log_level: LogLevel):
 
 
     Args:
     Args:
         log_level: The log level to set.
         log_level: The log level to set.
+
+    Raises:
+        ValueError: If the log level is invalid.
     """
     """
+    if not isinstance(log_level, LogLevel):
+        deprecate(
+            feature_name="Passing a string to set_log_level",
+            reason="use reflex.constants.LogLevel enum instead",
+            deprecation_version="0.6.6",
+            removal_version="0.7.0",
+        )
+        try:
+            log_level = getattr(LogLevel, log_level.upper())
+        except AttributeError as ae:
+            raise ValueError(f"Invalid log level: {log_level}") from ae
+
     global _LOG_LEVEL
     global _LOG_LEVEL
     _LOG_LEVEL = log_level
     _LOG_LEVEL = log_level
 
 

+ 131 - 5
tests/units/test_state.py

@@ -10,7 +10,17 @@ import os
 import sys
 import sys
 import threading
 import threading
 from textwrap import dedent
 from textwrap import dedent
-from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 from unittest.mock import AsyncMock, Mock
 from unittest.mock import AsyncMock, Mock
 
 
 import pytest
 import pytest
@@ -1829,12 +1839,11 @@ async def test_state_manager_lock_expire_contend(
 
 
 
 
 @pytest.fixture(scope="function")
 @pytest.fixture(scope="function")
-def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
-    """Mock app fixture.
+def mock_app_simple(monkeypatch) -> rx.App:
+    """Simple Mock app fixture.
 
 
     Args:
     Args:
         monkeypatch: Pytest monkeypatch object.
         monkeypatch: Pytest monkeypatch object.
-        state_manager: A state manager.
 
 
     Returns:
     Returns:
         The app, after mocking out prerequisites.get_app()
         The app, after mocking out prerequisites.get_app()
@@ -1845,7 +1854,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
 
 
     setattr(app_module, CompileVars.APP, app)
     setattr(app_module, CompileVars.APP, app)
     app.state = TestState
     app.state = TestState
-    app._state_manager = state_manager
     app.event_namespace.emit = AsyncMock()  # type: ignore
     app.event_namespace.emit = AsyncMock()  # type: ignore
 
 
     def _mock_get_app(*args, **kwargs):
     def _mock_get_app(*args, **kwargs):
@@ -1855,6 +1863,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
     return app
     return app
 
 
 
 
+@pytest.fixture(scope="function")
+def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
+    """Mock app fixture.
+
+    Args:
+        mock_app_simple: A simple mock app.
+        state_manager: A state manager.
+
+    Returns:
+        The app, after mocking out prerequisites.get_app()
+    """
+    mock_app_simple._state_manager = state_manager
+    return mock_app_simple
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
 async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     """Test that the state proxy works.
     """Test that the state proxy works.
@@ -3507,3 +3530,106 @@ def test_init_mixin() -> None:
 
 
     with pytest.raises(ReflexRuntimeError):
     with pytest.raises(ReflexRuntimeError):
         SubMixin()
         SubMixin()
+
+
+class ReflexModel(rx.Model):
+    """A model for testing."""
+
+    foo: str
+
+
+class UpcastState(rx.State):
+    """A state for testing upcasting."""
+
+    passed: bool = False
+
+    def rx_model(self, m: ReflexModel):  # noqa: D102
+        assert isinstance(m, ReflexModel)
+        self.passed = True
+
+    def rx_base(self, o: Object):  # noqa: D102
+        assert isinstance(o, Object)
+        self.passed = True
+
+    def rx_base_or_none(self, o: Optional[Object]):  # noqa: D102
+        if o is not None:
+            assert isinstance(o, Object)
+        self.passed = True
+
+    def rx_basemodelv1(self, m: ModelV1):  # noqa: D102
+        assert isinstance(m, ModelV1)
+        self.passed = True
+
+    def rx_basemodelv2(self, m: ModelV2):  # noqa: D102
+        assert isinstance(m, ModelV2)
+        self.passed = True
+
+    def rx_dataclass(self, dc: ModelDC):  # noqa: D102
+        assert isinstance(dc, ModelDC)
+        self.passed = True
+
+    def py_set(self, s: set):  # noqa: D102
+        assert isinstance(s, set)
+        self.passed = True
+
+    def py_Set(self, s: Set):  # noqa: D102
+        assert isinstance(s, Set)
+        self.passed = True
+
+    def py_tuple(self, t: tuple):  # noqa: D102
+        assert isinstance(t, tuple)
+        self.passed = True
+
+    def py_Tuple(self, t: Tuple):  # noqa: D102
+        assert isinstance(t, tuple)
+        self.passed = True
+
+    def py_dict(self, d: dict[str, str]):  # noqa: D102
+        assert isinstance(d, dict)
+        self.passed = True
+
+    def py_list(self, ls: list[str]):  # noqa: D102
+        assert isinstance(ls, list)
+        self.passed = True
+
+    def py_Any(self, a: Any):  # noqa: D102
+        assert isinstance(a, list)
+        self.passed = True
+
+    def py_unresolvable(self, u: "Unresolvable"):  # noqa: D102, F821  # type: ignore
+        assert isinstance(u, list)
+        self.passed = True
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("mock_app_simple")
+@pytest.mark.parametrize(
+    ("handler", "payload"),
+    [
+        (UpcastState.rx_model, {"m": {"foo": "bar"}}),
+        (UpcastState.rx_base, {"o": {"foo": "bar"}}),
+        (UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
+        (UpcastState.rx_base_or_none, {"o": None}),
+        (UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
+        (UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
+        (UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
+        (UpcastState.py_set, {"s": ["foo", "foo"]}),
+        (UpcastState.py_Set, {"s": ["foo", "foo"]}),
+        (UpcastState.py_tuple, {"t": ["foo", "foo"]}),
+        (UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
+        (UpcastState.py_dict, {"d": {"foo": "bar"}}),
+        (UpcastState.py_list, {"ls": ["foo", "foo"]}),
+        (UpcastState.py_Any, {"a": ["foo"]}),
+        (UpcastState.py_unresolvable, {"u": ["foo"]}),
+    ],
+)
+async def test_upcast_event_handler_arg(handler, payload):
+    """Test that upcast event handler args work correctly.
+
+    Args:
+        handler: The handler to test.
+        payload: The payload to test.
+    """
+    state = UpcastState()
+    async for update in state._process_event(handler, state, payload):
+        assert update.delta == {UpcastState.get_full_name(): {"passed": True}}