1
0
Эх сурвалжийг харах

fix sqla python_type issues, add tests (#3613)

benedikt-bartscher 10 сар өмнө
parent
commit
6d3321284c

+ 35 - 30
reflex/utils/types.py

@@ -245,36 +245,41 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
             # check for list types
             # check for list types
             column = insp.columns[name]
             column = insp.columns[name]
             column_type = column.type
             column_type = column.type
-            type_ = insp.columns[name].type.python_type
-            if hasattr(column_type, "item_type") and (
-                item_type := column_type.item_type.python_type  # type: ignore
-            ):
-                if type_ in PrimitiveToAnnotation:
-                    type_ = PrimitiveToAnnotation[type_]  # type: ignore
-                type_ = type_[item_type]  # type: ignore
-            if column.nullable:
-                type_ = Optional[type_]
-            return type_
-        if name not in insp.all_orm_descriptors:
-            return None
-        descriptor = insp.all_orm_descriptors[name]
-        if hint := get_property_hint(descriptor):
-            return hint
-        if isinstance(descriptor, QueryableAttribute):
-            prop = descriptor.property
-            if not isinstance(prop, Relationship):
-                return None
-            type_ = prop.mapper.class_
-            # TODO: check for nullable?
-            type_ = List[type_] if prop.uselist else Optional[type_]
-            return type_
-        if isinstance(attr, AssociationProxyInstance):
-            return List[
-                get_attribute_access_type(
-                    attr.target_class,
-                    attr.remote_attr.key,  # type: ignore[attr-defined]
-                )
-            ]
+            try:
+                type_ = insp.columns[name].type.python_type
+            except NotImplementedError:
+                type_ = None
+            if type_ is not None:
+                if hasattr(column_type, "item_type"):
+                    try:
+                        item_type = column_type.item_type.python_type  # type: ignore
+                    except NotImplementedError:
+                        item_type = None
+                    if item_type is not None:
+                        if type_ in PrimitiveToAnnotation:
+                            type_ = PrimitiveToAnnotation[type_]  # type: ignore
+                        type_ = type_[item_type]  # type: ignore
+                if column.nullable:
+                    type_ = Optional[type_]
+                return type_
+        if name in insp.all_orm_descriptors:
+            descriptor = insp.all_orm_descriptors[name]
+            if hint := get_property_hint(descriptor):
+                return hint
+            if isinstance(descriptor, QueryableAttribute):
+                prop = descriptor.property
+                if isinstance(prop, Relationship):
+                    type_ = prop.mapper.class_
+                    # TODO: check for nullable?
+                    type_ = List[type_] if prop.uselist else Optional[type_]
+                    return type_
+            if isinstance(attr, AssociationProxyInstance):
+                return List[
+                    get_attribute_access_type(
+                        attr.target_class,
+                        attr.remote_attr.key,  # type: ignore[attr-defined]
+                    )
+                ]
     elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
     elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
         # Check in the annotations directly (for sqlmodel.Relationship)
         # Check in the annotations directly (for sqlmodel.Relationship)
         hints = get_type_hints(cls)
         hints = get_type_hints(cls)

+ 39 - 3
tests/test_attribute_access_type.py

@@ -1,10 +1,11 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 
 import attrs
 import attrs
 import pytest
 import pytest
 import sqlalchemy
 import sqlalchemy
+from sqlalchemy import JSON, TypeDecorator
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 
 
@@ -12,10 +13,29 @@ import reflex as rx
 from reflex.utils.types import GenericType, get_attribute_access_type
 from reflex.utils.types import GenericType, get_attribute_access_type
 
 
 
 
+class SQLAType(TypeDecorator):
+    """SQLAlchemy custom dict type."""
+
+    impl = JSON
+
+    @property
+    def python_type(self) -> Type[Dict[str, str]]:
+        """Python type.
+
+        Returns:
+            Python Type of the column.
+        """
+        return Dict[str, str]
+
+
 class SQLABase(DeclarativeBase):
 class SQLABase(DeclarativeBase):
     """Base class for bare SQLAlchemy models."""
     """Base class for bare SQLAlchemy models."""
 
 
-    pass
+    type_annotation_map = {
+        # do not use lower case dict here!
+        # https://github.com/sqlalchemy/sqlalchemy/issues/9902
+        Dict[str, str]: SQLAType,
+    }
 
 
 
 
 class SQLATag(SQLABase):
 class SQLATag(SQLABase):
@@ -52,6 +72,9 @@ class SQLAClass(SQLABase):
     sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
     sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
     sqla_tag: Mapped[Optional[SQLATag]] = relationship()
     sqla_tag: Mapped[Optional[SQLATag]] = relationship()
     labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
     labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
+    # do not use lower case dict here!
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9902
+    dict_str_str: Mapped[Dict[str, str]] = mapped_column()
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -82,6 +105,7 @@ class ModelClass(rx.Model):
     optional_int: Optional[int] = None
     optional_int: Optional[int] = None
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
+    dict_str_str: Dict[str, str] = {}
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -112,6 +136,7 @@ class BaseClass(rx.Base):
     optional_int: Optional[int] = None
     optional_int: Optional[int] = None
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
+    dict_str_str: Dict[str, str] = {}
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -142,6 +167,7 @@ class BareClass:
     optional_int: Optional[int] = None
     optional_int: Optional[int] = None
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
+    dict_str_str: Dict[str, str] = {}
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -173,6 +199,7 @@ class AttrClass:
     optional_int: Optional[int] = None
     optional_int: Optional[int] = None
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
+    dict_str_str: Dict[str, str] = {}
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -193,7 +220,15 @@ class AttrClass:
         return self.name
         return self.name
 
 
 
 
-@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass])
+@pytest.fixture(
+    params=[
+        SQLAClass,
+        BaseClass,
+        BareClass,
+        ModelClass,
+        AttrClass,
+    ]
+)
 def cls(request: pytest.FixtureRequest) -> type:
 def cls(request: pytest.FixtureRequest) -> type:
     """Fixture for the class to test.
     """Fixture for the class to test.
 
 
@@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type:
         pytest.param("optional_int", Optional[int], id="Optional[int]"),
         pytest.param("optional_int", Optional[int], id="Optional[int]"),
         pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
         pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
         pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
         pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
+        pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"),
         pytest.param("str_property", str, id="str_property"),
         pytest.param("str_property", str, id="str_property"),
         pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
         pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
     ],
     ],