瀏覽代碼

improve sqlalchemy type parsing (#2474)

* improve sqlalchemy type parsing

* add support for propertys and relationships

* cleanup duplicate property check

* avoid confusion, improve readability
benedikt-bartscher 1 年之前
父節點
當前提交
be7f7969ed
共有 1 個文件被更改,包括 35 次插入7 次删除
  1. 35 7
      reflex/utils/types.py

+ 35 - 7
reflex/utils/types.py

@@ -18,9 +18,10 @@ from typing import (
     get_type_hints,
 )
 
+import sqlalchemy
 from pydantic.fields import ModelField
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import DeclarativeBase, Mapped
+from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
 
 from reflex.base import Base
 from reflex.utils import serializers
@@ -105,6 +106,21 @@ def is_optional(cls: GenericType) -> bool:
     return is_union(cls) and type(None) in get_args(cls)
 
 
+def get_property_hint(attr: Any | None) -> GenericType | None:
+    """Check if an attribute is a property and return its type hint.
+
+    Args:
+        attr: The descriptor to check.
+
+    Returns:
+        The type hint of the property, if it is a property, else None.
+    """
+    if not isinstance(attr, (property, hybrid_property)):
+        return None
+    hints = get_type_hints(attr.fget)
+    return hints.get("return", None)
+
+
 def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
     """Check if an attribute can be accessed on the cls and return its type.
 
@@ -119,6 +135,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
     """
     from reflex.model import Model
 
+    attr = getattr(cls, name, None)
+    if hint := get_property_hint(attr):
+        return hint
     if hasattr(cls, "__fields__") and name in cls.__fields__:
         # pydantic models
         field = cls.__fields__[name]
@@ -129,7 +148,21 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
             # Ensure frontend uses null coalescing when accessing.
             type_ = Optional[type_]
         return type_
-    elif isinstance(cls, type) and issubclass(cls, (Model, DeclarativeBase)):
+    elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
+        insp = sqlalchemy.inspect(cls)
+        if name in insp.columns:
+            return insp.columns[name].type.python_type
+        if name not in insp.all_orm_descriptors.keys():
+            return None
+        descriptor = insp.all_orm_descriptors[name]
+        if hint := get_property_hint(descriptor):
+            return hint
+        if isinstance(descriptor, QueryableAttribute):
+            prop = descriptor.property
+            if not isinstance(prop, Relationship):
+                return None
+            return prop.mapper.class_
+    elif isinstance(cls, type) and issubclass(cls, Model):
         # Check in the annotations directly (for sqlmodel.Relationship)
         hints = get_type_hints(cls)
         if name in hints:
@@ -140,11 +173,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
             if isinstance(type_, ModelField):
                 return type_.type_  # SQLAlchemy v1.4
             return type_
-        if name in cls.__dict__:
-            value = cls.__dict__[name]
-            if isinstance(value, hybrid_property):
-                hints = get_type_hints(value.fget)
-                return hints.get("return", None)
     elif is_union(cls):
         # Check in each arg of the annotation.
         for arg in get_args(cls):