Browse Source

add Bare SQLAlchemy mutation tracking, improve typing (#3628)

benedikt-bartscher 10 months ago
parent
commit
d621115f9b
4 changed files with 120 additions and 31 deletions
  1. 2 1
      reflex/state.py
  2. 1 1
      tests/conftest.py
  3. 48 2
      tests/states/mutation.py
  4. 69 27
      tests/test_state.py

+ 2 - 1
reflex/state.py

@@ -29,6 +29,7 @@ from typing import (
 )
 
 import dill
+from sqlalchemy.orm import DeclarativeBase
 
 try:
     import pydantic.v1 as pydantic
@@ -2963,7 +2964,7 @@ class MutableProxy(wrapt.ObjectProxy):
         pydantic.BaseModel.__dict__
     )
 
-    __mutable_types__ = (list, dict, set, Base)
+    __mutable_types__ = (list, dict, set, Base, DeclarativeBase)
 
     def __init__(self, wrapped: Any, state: BaseState, field_name: str):
         """Create a proxy for a mutable object that tracks changes.

+ 1 - 1
tests/conftest.py

@@ -231,7 +231,7 @@ def tmp_working_dir(tmp_path):
 
 
 @pytest.fixture
-def mutable_state():
+def mutable_state() -> MutableTestState:
     """Create a Test state containing mutable types.
 
     Returns:

+ 48 - 2
tests/states/mutation.py

@@ -2,8 +2,12 @@
 
 from typing import Dict, List, Set, Union
 
+from sqlalchemy import ARRAY, JSON, String
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+
 import reflex as rx
 from reflex.state import BaseState
+from reflex.utils.serializers import serializer
 
 
 class DictMutationTestState(BaseState):
@@ -145,15 +149,47 @@ class CustomVar(rx.Base):
     custom: OtherBase = OtherBase()
 
 
+class MutableSQLABase(DeclarativeBase):
+    """SQLAlchemy base model for mutable vars."""
+
+    pass
+
+
+class MutableSQLAModel(MutableSQLABase):
+    """SQLAlchemy model for mutable vars."""
+
+    __tablename__: str = "mutable_test_state"
+
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+    strlist: Mapped[List[str]] = mapped_column(ARRAY(String))
+    hashmap: Mapped[Dict[str, str]] = mapped_column(JSON)
+    test_set: Mapped[Set[str]] = mapped_column(ARRAY(String))
+
+
+@serializer
+def serialize_mutable_sqla_model(
+    model: MutableSQLAModel,
+) -> Dict[str, Union[List[str], Dict[str, str]]]:
+    """Serialize the MutableSQLAModel.
+
+    Args:
+        model: The MutableSQLAModel instance to serialize.
+
+    Returns:
+        The serialized model.
+    """
+    return {"strlist": model.strlist, "hashmap": model.hashmap}
+
+
 class MutableTestState(BaseState):
     """A test state."""
 
-    array: List[Union[str, List, Dict[str, str]]] = [
+    array: List[Union[str, int, List, Dict[str, str]]] = [
         "value",
         [1, 2, 3],
         {"key": "value"},
     ]
-    hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
+    hashmap: Dict[str, Union[List, str, Dict[str, Union[str, Dict]]]] = {
         "key": ["list", "of", "values"],
         "another_key": "another_value",
         "third_key": {"key": "value"},
@@ -161,6 +197,11 @@ class MutableTestState(BaseState):
     test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
     custom: CustomVar = CustomVar()
     _be_custom: CustomVar = CustomVar()
+    sqla_model: MutableSQLAModel = MutableSQLAModel(
+        strlist=["a", "b", "c"],
+        hashmap={"key": "value"},
+        test_set={"one", "two", "three"},
+    )
 
     def reassign_mutables(self):
         """Assign mutable fields to different values."""
@@ -171,3 +212,8 @@ class MutableTestState(BaseState):
             "mod_third_key": {"key": "value"},
         }
         self.test_set = {1, 2, 3, 4, "five"}
+        self.sqla_model = MutableSQLAModel(
+            strlist=["d", "e", "f"],
+            hashmap={"key": "value"},
+            test_set={"one", "two", "three"},
+        )

+ 69 - 27
tests/test_state.py

@@ -8,7 +8,7 @@ import json
 import os
 import sys
 from textwrap import dedent
-from typing import Any, Dict, Generator, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Union
 from unittest.mock import AsyncMock, Mock
 
 import pytest
@@ -40,6 +40,7 @@ from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils.format import json_dumps
 from reflex.vars import BaseVar, ComputedVar
+from tests.states.mutation import MutableSQLAModel, MutableTestState
 
 from .states import GenState
 
@@ -1389,7 +1390,7 @@ def test_backend_method():
     assert bms._be_method()
 
 
-def test_setattr_of_mutable_types(mutable_state):
+def test_setattr_of_mutable_types(mutable_state: MutableTestState):
     """Test that mutable types are converted to corresponding Reflex wrappers.
 
     Args:
@@ -1398,6 +1399,7 @@ def test_setattr_of_mutable_types(mutable_state):
     array = mutable_state.array
     hashmap = mutable_state.hashmap
     test_set = mutable_state.test_set
+    sqla_model = mutable_state.sqla_model
 
     assert isinstance(array, MutableProxy)
     assert isinstance(array, list)
@@ -1425,11 +1427,21 @@ def test_setattr_of_mutable_types(mutable_state):
     assert isinstance(mutable_state.custom.test_set, set)
     assert isinstance(mutable_state.custom.custom, MutableProxy)
 
+    assert isinstance(sqla_model, MutableProxy)
+    assert isinstance(sqla_model, MutableSQLAModel)
+    assert isinstance(sqla_model.strlist, MutableProxy)
+    assert isinstance(sqla_model.strlist, list)
+    assert isinstance(sqla_model.hashmap, MutableProxy)
+    assert isinstance(sqla_model.hashmap, dict)
+    assert isinstance(sqla_model.test_set, MutableProxy)
+    assert isinstance(sqla_model.test_set, set)
+
     mutable_state.reassign_mutables()
 
     array = mutable_state.array
     hashmap = mutable_state.hashmap
     test_set = mutable_state.test_set
+    sqla_model = mutable_state.sqla_model
 
     assert isinstance(array, MutableProxy)
     assert isinstance(array, list)
@@ -1448,6 +1460,15 @@ def test_setattr_of_mutable_types(mutable_state):
     assert isinstance(test_set, MutableProxy)
     assert isinstance(test_set, set)
 
+    assert isinstance(sqla_model, MutableProxy)
+    assert isinstance(sqla_model, MutableSQLAModel)
+    assert isinstance(sqla_model.strlist, MutableProxy)
+    assert isinstance(sqla_model.strlist, list)
+    assert isinstance(sqla_model.hashmap, MutableProxy)
+    assert isinstance(sqla_model.hashmap, dict)
+    assert isinstance(sqla_model.test_set, MutableProxy)
+    assert isinstance(sqla_model.test_set, set)
+
 
 def test_error_on_state_method_shadow():
     """Test that an error is thrown when an event handler shadows a state method."""
@@ -2091,7 +2112,7 @@ async def test_background_task_no_chain():
         await bts.bad_chain2()
 
 
-def test_mutable_list(mutable_state):
+def test_mutable_list(mutable_state: MutableTestState):
     """Test that mutable lists are tracked correctly.
 
     Args:
@@ -2121,7 +2142,7 @@ def test_mutable_list(mutable_state):
     assert_array_dirty()
     mutable_state.array.reverse()
     assert_array_dirty()
-    mutable_state.array.sort()
+    mutable_state.array.sort()  # type: ignore[reportCallIssue,reportUnknownMemberType]
     assert_array_dirty()
     mutable_state.array[0] = 666
     assert_array_dirty()
@@ -2145,7 +2166,7 @@ def test_mutable_list(mutable_state):
         assert_array_dirty()
 
 
-def test_mutable_dict(mutable_state):
+def test_mutable_dict(mutable_state: MutableTestState):
     """Test that mutable dicts are tracked correctly.
 
     Args:
@@ -2159,40 +2180,40 @@ def test_mutable_dict(mutable_state):
         assert not mutable_state.dirty_vars
 
     # Test all dict operations
-    mutable_state.hashmap.update({"new_key": 43})
+    mutable_state.hashmap.update({"new_key": "43"})
     assert_hashmap_dirty()
-    assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
+    assert mutable_state.hashmap.setdefault("another_key", "66") == "another_value"
     assert_hashmap_dirty()
-    assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
+    assert mutable_state.hashmap.setdefault("setdefault_key", "67") == "67"
     assert_hashmap_dirty()
-    assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
+    assert mutable_state.hashmap.setdefault("setdefault_key", "68") == "67"
     assert_hashmap_dirty()
-    assert mutable_state.hashmap.pop("new_key") == 43
+    assert mutable_state.hashmap.pop("new_key") == "43"
     assert_hashmap_dirty()
     mutable_state.hashmap.popitem()
     assert_hashmap_dirty()
     mutable_state.hashmap.clear()
     assert_hashmap_dirty()
-    mutable_state.hashmap["new_key"] = 42
+    mutable_state.hashmap["new_key"] = "42"
     assert_hashmap_dirty()
     del mutable_state.hashmap["new_key"]
     assert_hashmap_dirty()
     if sys.version_info >= (3, 9):
-        mutable_state.hashmap |= {"new_key": 44}
+        mutable_state.hashmap |= {"new_key": "44"}
         assert_hashmap_dirty()
 
     # Test nested dict operations
     mutable_state.hashmap["array"] = []
     assert_hashmap_dirty()
-    mutable_state.hashmap["array"].append(1)
+    mutable_state.hashmap["array"].append("1")
     assert_hashmap_dirty()
     mutable_state.hashmap["dict"] = {}
     assert_hashmap_dirty()
-    mutable_state.hashmap["dict"]["key"] = 42
+    mutable_state.hashmap["dict"]["key"] = "42"
     assert_hashmap_dirty()
     mutable_state.hashmap["dict"]["dict"] = {}
     assert_hashmap_dirty()
-    mutable_state.hashmap["dict"]["dict"]["key"] = 43
+    mutable_state.hashmap["dict"]["dict"]["key"] = "43"
     assert_hashmap_dirty()
 
     # Test proxy returned from `setdefault` and `get`
@@ -2214,14 +2235,14 @@ def test_mutable_dict(mutable_state):
     mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
     assert not isinstance(mutable_value_third_ref, MutableProxy)
     assert_hashmap_dirty()
-    mutable_value_third_ref.append("baz")
+    mutable_value_third_ref.append("baz")  # type: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnusedCallResult]
     assert not mutable_state.dirty_vars
     # Unfortunately previous refs still will mark the state dirty... nothing doing about that
     assert mutable_value.pop()
     assert_hashmap_dirty()
 
 
-def test_mutable_set(mutable_state):
+def test_mutable_set(mutable_state: MutableTestState):
     """Test that mutable sets are tracked correctly.
 
     Args:
@@ -2263,7 +2284,7 @@ def test_mutable_set(mutable_state):
     assert_set_dirty()
 
 
-def test_mutable_custom(mutable_state):
+def test_mutable_custom(mutable_state: MutableTestState):
     """Test that mutable custom types derived from Base are tracked correctly.
 
     Args:
@@ -2278,17 +2299,38 @@ def test_mutable_custom(mutable_state):
 
     mutable_state.custom.foo = "bar"
     assert_custom_dirty()
-    mutable_state.custom.array.append(42)
+    mutable_state.custom.array.append("42")
     assert_custom_dirty()
-    mutable_state.custom.hashmap["key"] = 68
+    mutable_state.custom.hashmap["key"] = "value"
     assert_custom_dirty()
-    mutable_state.custom.test_set.add(42)
+    mutable_state.custom.test_set.add("foo")
     assert_custom_dirty()
     mutable_state.custom.custom.bar = "baz"
     assert_custom_dirty()
 
 
-def test_mutable_backend(mutable_state):
+def test_mutable_sqla_model(mutable_state: MutableTestState):
+    """Test that mutable SQLA models are tracked correctly.
+
+    Args:
+        mutable_state: A test state.
+    """
+    assert not mutable_state.dirty_vars
+
+    def assert_sqla_model_dirty():
+        assert mutable_state.dirty_vars == {"sqla_model"}
+        mutable_state._clean()
+        assert not mutable_state.dirty_vars
+
+    mutable_state.sqla_model.strlist.append("foo")
+    assert_sqla_model_dirty()
+    mutable_state.sqla_model.hashmap["key"] = "value"
+    assert_sqla_model_dirty()
+    mutable_state.sqla_model.test_set.add("bar")
+    assert_sqla_model_dirty()
+
+
+def test_mutable_backend(mutable_state: MutableTestState):
     """Test that mutable backend vars are tracked correctly.
 
     Args:
@@ -2303,11 +2345,11 @@ def test_mutable_backend(mutable_state):
 
     mutable_state._be_custom.foo = "bar"
     assert_custom_dirty()
-    mutable_state._be_custom.array.append(42)
+    mutable_state._be_custom.array.append("baz")
     assert_custom_dirty()
-    mutable_state._be_custom.hashmap["key"] = 68
+    mutable_state._be_custom.hashmap["key"] = "value"
     assert_custom_dirty()
-    mutable_state._be_custom.test_set.add(42)
+    mutable_state._be_custom.test_set.add("foo")
     assert_custom_dirty()
     mutable_state._be_custom.custom.bar = "baz"
     assert_custom_dirty()
@@ -2320,7 +2362,7 @@ def test_mutable_backend(mutable_state):
         (copy.deepcopy,),
     ],
 )
-def test_mutable_copy(mutable_state, copy_func):
+def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable):
     """Test that mutable types are copied correctly.
 
     Args:
@@ -2347,7 +2389,7 @@ def test_mutable_copy(mutable_state, copy_func):
         (copy.deepcopy,),
     ],
 )
-def test_mutable_copy_vars(mutable_state, copy_func):
+def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable):
     """Test that mutable types are copied correctly.
 
     Args: