Explorar el Código

Improved get_attribute_access_type (#3156)

benedikt-bartscher hace 1 año
padre
commit
c2017b295e
Se han modificado 2 ficheros con 208 adiciones y 8 borrados
  1. 47 8
      reflex/utils/types.py
  2. 161 0
      tests/test_attribute_access_type.py

+ 47 - 8
reflex/utils/types.py

@@ -4,16 +4,19 @@ from __future__ import annotations
 
 
 import contextlib
 import contextlib
 import inspect
 import inspect
+import sys
 import types
 import types
 from functools import wraps
 from functools import wraps
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
     Any,
     Any,
     Callable,
     Callable,
+    Dict,
     Iterable,
     Iterable,
     List,
     List,
     Literal,
     Literal,
     Optional,
     Optional,
+    Tuple,
     Type,
     Type,
     Union,
     Union,
     _GenericAlias,  # type: ignore
     _GenericAlias,  # type: ignore
@@ -37,11 +40,16 @@ except ModuleNotFoundError:
 
 
 from sqlalchemy.ext.associationproxy import AssociationProxyInstance
 from sqlalchemy.ext.associationproxy import AssociationProxyInstance
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    QueryableAttribute,
+    Relationship,
+)
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
-from reflex.utils import serializers
+from reflex.utils import console, serializers
 
 
 # Potential GenericAlias types for isinstance checks.
 # Potential GenericAlias types for isinstance checks.
 GenericAliasTypes = [_GenericAlias]
 GenericAliasTypes = [_GenericAlias]
@@ -76,6 +84,13 @@ StateIterVar = Union[list, set, tuple]
 ArgsSpec = Callable
 ArgsSpec = Callable
 
 
 
 
+PrimitiveToAnnotation = {
+    list: List,
+    tuple: Tuple,
+    dict: Dict,
+}
+
+
 class Unset:
 class Unset:
     """A class to represent an unset value.
     """A class to represent an unset value.
 
 
@@ -192,7 +207,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
     elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
     elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
         insp = sqlalchemy.inspect(cls)
         insp = sqlalchemy.inspect(cls)
         if name in insp.columns:
         if name in insp.columns:
-            return insp.columns[name].type.python_type
+            # check for list types
+            column = insp.columns[name]
+            column_type = column.type
+            type_ = insp.columns[name].type.python_type
+            if hasattr(column_type, "item_type") and (
+                item_type := column_type.item_type.python_type  # type: ignore
+            ):
+                if type_ in PrimitiveToAnnotation:
+                    type_ = PrimitiveToAnnotation[type_]  # type: ignore
+                type_ = type_[item_type]  # type: ignore
+            if column.nullable:
+                type_ = Optional[type_]
+            return type_
         if name not in insp.all_orm_descriptors:
         if name not in insp.all_orm_descriptors:
             return None
             return None
         descriptor = insp.all_orm_descriptors[name]
         descriptor = insp.all_orm_descriptors[name]
@@ -202,11 +229,10 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
             prop = descriptor.property
             prop = descriptor.property
             if not isinstance(prop, Relationship):
             if not isinstance(prop, Relationship):
                 return None
                 return None
-            class_ = prop.mapper.class_
-            if prop.uselist:
-                return List[class_]
-            else:
-                return class_
+            type_ = prop.mapper.class_
+            # TODO: check for nullable?
+            type_ = List[type_] if prop.uselist else Optional[type_]
+            return type_
         if isinstance(attr, AssociationProxyInstance):
         if isinstance(attr, AssociationProxyInstance):
             return List[
             return List[
                 get_attribute_access_type(
                 get_attribute_access_type(
@@ -232,6 +258,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
             if type_ is not None:
             if type_ is not None:
                 # Return the first attribute type that is accessible.
                 # Return the first attribute type that is accessible.
                 return type_
                 return type_
+    elif isinstance(cls, type):
+        # Bare class
+        if sys.version_info >= (3, 10):
+            exceptions = NameError
+        else:
+            exceptions = (NameError, TypeError)
+        try:
+            hints = get_type_hints(cls)
+            if name in hints:
+                return hints[name]
+        except exceptions as e:
+            console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
+            pass
     return None  # Attribute is not accessible.
     return None  # Attribute is not accessible.
 
 
 
 

+ 161 - 0
tests/test_attribute_access_type.py

@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+from typing import List, Optional
+
+import pytest
+import sqlalchemy
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+
+import reflex as rx
+from reflex.utils.types import GenericType, get_attribute_access_type
+
+
+class SQLABase(DeclarativeBase):
+    """Base class for bare SQLAlchemy models."""
+
+    pass
+
+
+class SQLATag(SQLABase):
+    """Tag sqlalchemy model."""
+
+    __tablename__: str = "tag"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column()
+
+
+class SQLALabel(SQLABase):
+    """Label sqlalchemy model."""
+
+    __tablename__: str = "label"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
+    test: Mapped[SQLAClass] = relationship(back_populates="labels")
+
+
+class SQLAClass(SQLABase):
+    """Test sqlalchemy model."""
+
+    __tablename__: str = "test"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    count: Mapped[int] = mapped_column()
+    name: Mapped[str] = mapped_column()
+    int_list: Mapped[List[int]] = mapped_column(
+        sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
+    )
+    str_list: Mapped[List[str]] = mapped_column(
+        sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
+    )
+    optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
+    sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
+    sqla_tag: Mapped[Optional[SQLATag]] = relationship()
+    labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
+
+    @property
+    def str_property(self) -> str:
+        """String property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+
+class ModelClass(rx.Model):
+    """Test reflex model."""
+
+    count: int = 0
+    name: str = "test"
+    int_list: List[int] = []
+    str_list: List[str] = []
+    optional_int: Optional[int] = None
+    sqla_tag: Optional[SQLATag] = None
+    labels: List[SQLALabel] = []
+
+    @property
+    def str_property(self) -> str:
+        """String property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+
+class BaseClass(rx.Base):
+    """Test rx.Base class."""
+
+    count: int = 0
+    name: str = "test"
+    int_list: List[int] = []
+    str_list: List[str] = []
+    optional_int: Optional[int] = None
+    sqla_tag: Optional[SQLATag] = None
+    labels: List[SQLALabel] = []
+
+    @property
+    def str_property(self) -> str:
+        """String property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+
+class BareClass:
+    """Bare python class."""
+
+    count: int = 0
+    name: str = "test"
+    int_list: List[int] = []
+    str_list: List[str] = []
+    optional_int: Optional[int] = None
+    sqla_tag: Optional[SQLATag] = None
+    labels: List[SQLALabel] = []
+
+    @property
+    def str_property(self) -> str:
+        """String property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+
+@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass])
+def cls(request: pytest.FixtureRequest) -> type:
+    """Fixture for the class to test.
+
+    Args:
+        request: pytest request object.
+
+    Returns:
+        Class to test.
+    """
+    return request.param
+
+
+@pytest.mark.parametrize(
+    "attr, expected",
+    [
+        pytest.param("count", int, id="int"),
+        pytest.param("name", str, id="str"),
+        pytest.param("int_list", List[int], id="List[int]"),
+        pytest.param("str_list", List[str], id="List[str]"),
+        pytest.param("optional_int", Optional[int], id="Optional[int]"),
+        pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
+        pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
+        pytest.param("str_property", str, id="str_property"),
+    ],
+)
+def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None:
+    """Test get_attribute_access_type returns the correct type.
+
+    Args:
+        cls: Class to test.
+        attr: Attribute to test.
+        expected: Expected type.
+    """
+    assert get_attribute_access_type(cls, attr) == expected