|
@@ -18,9 +18,10 @@ from typing import (
|
|
|
get_type_hints,
|
|
|
)
|
|
|
|
|
|
+import sqlalchemy
|
|
|
from pydantic.fields import ModelField
|
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
|
-from sqlalchemy.orm import DeclarativeBase, Mapped
|
|
|
+from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
|
|
|
|
|
|
from reflex.base import Base
|
|
|
from reflex.utils import serializers
|
|
@@ -105,6 +106,21 @@ def is_optional(cls: GenericType) -> bool:
|
|
|
return is_union(cls) and type(None) in get_args(cls)
|
|
|
|
|
|
|
|
|
+def get_property_hint(attr: Any | None) -> GenericType | None:
|
|
|
+ """Check if an attribute is a property and return its type hint.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ attr: The descriptor to check.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The type hint of the property, if it is a property, else None.
|
|
|
+ """
|
|
|
+ if not isinstance(attr, (property, hybrid_property)):
|
|
|
+ return None
|
|
|
+ hints = get_type_hints(attr.fget)
|
|
|
+ return hints.get("return", None)
|
|
|
+
|
|
|
+
|
|
|
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
|
|
"""Check if an attribute can be accessed on the cls and return its type.
|
|
|
|
|
@@ -119,6 +135,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|
|
"""
|
|
|
from reflex.model import Model
|
|
|
|
|
|
+ attr = getattr(cls, name, None)
|
|
|
+ if hint := get_property_hint(attr):
|
|
|
+ return hint
|
|
|
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
|
|
# pydantic models
|
|
|
field = cls.__fields__[name]
|
|
@@ -129,7 +148,21 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|
|
# Ensure frontend uses null coalescing when accessing.
|
|
|
type_ = Optional[type_]
|
|
|
return type_
|
|
|
- elif isinstance(cls, type) and issubclass(cls, (Model, DeclarativeBase)):
|
|
|
+ elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
|
|
+ insp = sqlalchemy.inspect(cls)
|
|
|
+ if name in insp.columns:
|
|
|
+ return insp.columns[name].type.python_type
|
|
|
+ if name not in insp.all_orm_descriptors.keys():
|
|
|
+ return None
|
|
|
+ descriptor = insp.all_orm_descriptors[name]
|
|
|
+ if hint := get_property_hint(descriptor):
|
|
|
+ return hint
|
|
|
+ if isinstance(descriptor, QueryableAttribute):
|
|
|
+ prop = descriptor.property
|
|
|
+ if not isinstance(prop, Relationship):
|
|
|
+ return None
|
|
|
+ return prop.mapper.class_
|
|
|
+ elif isinstance(cls, type) and issubclass(cls, Model):
|
|
|
# Check in the annotations directly (for sqlmodel.Relationship)
|
|
|
hints = get_type_hints(cls)
|
|
|
if name in hints:
|
|
@@ -140,11 +173,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|
|
if isinstance(type_, ModelField):
|
|
|
return type_.type_ # SQLAlchemy v1.4
|
|
|
return type_
|
|
|
- if name in cls.__dict__:
|
|
|
- value = cls.__dict__[name]
|
|
|
- if isinstance(value, hybrid_property):
|
|
|
- hints = get_type_hints(value.fget)
|
|
|
- return hints.get("return", None)
|
|
|
elif is_union(cls):
|
|
|
# Check in each arg of the annotation.
|
|
|
for arg in get_args(cls):
|