소스 검색

fix: handle default_factory in get_attribute_access_type (#4517)

* fix: handle default_factory in get_attribute_access_type, add tests for sqla dataclasses

* only test classes which have default_factory + add test for no default
benedikt-bartscher 5 달 전
부모
커밋
e4b5755568
2개의 변경된 파일117개의 추가작업 그리고 17개의 파일을 삭제
  1. 5 1
      reflex/utils/types.py
  2. 112 16
      tests/units/test_attribute_access_type.py

+ 5 - 1
reflex/utils/types.py

@@ -331,7 +331,11 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
         type_ = field.outer_type_
         type_ = field.outer_type_
         if isinstance(type_, ModelField):
         if isinstance(type_, ModelField):
             type_ = type_.type_
             type_ = type_.type_
-        if not field.required and field.default is None:
+        if (
+            not field.required
+            and field.default is None
+            and field.default_factory is None
+        ):
             # Ensure frontend uses null coalescing when accessing.
             # Ensure frontend uses null coalescing when accessing.
             type_ = Optional[type_]
             type_ = Optional[type_]
         return type_
         return type_

+ 112 - 16
tests/units/test_attribute_access_type.py

@@ -3,11 +3,19 @@ from __future__ import annotations
 from typing import Dict, List, Optional, Type, Union
 from typing import Dict, List, Optional, Type, Union
 
 
 import attrs
 import attrs
+import pydantic.v1
 import pytest
 import pytest
 import sqlalchemy
 import sqlalchemy
+import sqlmodel
 from sqlalchemy import JSON, TypeDecorator
 from sqlalchemy import JSON, TypeDecorator
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    MappedAsDataclass,
+    mapped_column,
+    relationship,
+)
 
 
 import reflex as rx
 import reflex as rx
 from reflex.utils.types import GenericType, get_attribute_access_type
 from reflex.utils.types import GenericType, get_attribute_access_type
@@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
     id: Mapped[int] = mapped_column(primary_key=True)
     id: Mapped[int] = mapped_column(primary_key=True)
     test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
     test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
     test: Mapped[SQLAClass] = relationship(back_populates="labels")
     test: Mapped[SQLAClass] = relationship(back_populates="labels")
+    test_dataclass_id: Mapped[int] = mapped_column(
+        sqlalchemy.ForeignKey("test_dataclass.id")
+    )
+    test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")
 
 
 
 
 class SQLAClass(SQLABase):
 class SQLAClass(SQLABase):
@@ -104,9 +116,64 @@ class SQLAClass(SQLABase):
         return self.labels[0] if self.labels else None
         return self.labels[0] if self.labels else None
 
 
 
 
+class SQLAClassDataclass(MappedAsDataclass, SQLABase):
+    """Test sqlalchemy model."""
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    no_default: Mapped[int] = mapped_column(nullable=True)
+    count: Mapped[int] = mapped_column()
+    name: Mapped[str] = mapped_column()
+    int_list: Mapped[List[int]] = mapped_column(
+        sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
+    )
+    str_list: Mapped[List[str]] = mapped_column(
+        sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
+    )
+    optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
+    sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
+    sqla_tag: Mapped[Optional[SQLATag]] = relationship()
+    labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
+    # do not use lower case dict here!
+    # https://github.com/sqlalchemy/sqlalchemy/issues/9902
+    dict_str_str: Mapped[Dict[str, str]] = mapped_column()
+    default_factory: Mapped[List[int]] = mapped_column(
+        sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
+        default_factory=list,
+    )
+    __tablename__: str = "test_dataclass"
+
+    @property
+    def str_property(self) -> str:
+        """String property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+    @hybrid_property
+    def str_or_int_property(self) -> Union[str, int]:
+        """String or int property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+    @hybrid_property
+    def first_label(self) -> Optional[SQLALabel]:
+        """First label property.
+
+        Returns:
+            First label
+        """
+        return self.labels[0] if self.labels else None
+
+
 class ModelClass(rx.Model):
 class ModelClass(rx.Model):
     """Test reflex model."""
     """Test reflex model."""
 
 
+    no_default: Optional[int] = sqlmodel.Field(nullable=True)
     count: int = 0
     count: int = 0
     name: str = "test"
     name: str = "test"
     int_list: List[int] = []
     int_list: List[int] = []
@@ -115,6 +182,7 @@ class ModelClass(rx.Model):
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
     dict_str_str: Dict[str, str] = {}
     dict_str_str: Dict[str, str] = {}
+    default_factory: List[int] = sqlmodel.Field(default_factory=list)
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -147,6 +215,7 @@ class ModelClass(rx.Model):
 class BaseClass(rx.Base):
 class BaseClass(rx.Base):
     """Test rx.Base class."""
     """Test rx.Base class."""
 
 
+    no_default: Optional[int] = pydantic.v1.Field(required=False)
     count: int = 0
     count: int = 0
     name: str = "test"
     name: str = "test"
     int_list: List[int] = []
     int_list: List[int] = []
@@ -155,6 +224,7 @@ class BaseClass(rx.Base):
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
     dict_str_str: Dict[str, str] = {}
     dict_str_str: Dict[str, str] = {}
+    default_factory: List[int] = pydantic.v1.Field(default_factory=list)
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -236,6 +306,7 @@ class AttrClass:
     sqla_tag: Optional[SQLATag] = None
     sqla_tag: Optional[SQLATag] = None
     labels: List[SQLALabel] = []
     labels: List[SQLALabel] = []
     dict_str_str: Dict[str, str] = {}
     dict_str_str: Dict[str, str] = {}
+    default_factory: List[int] = attrs.field(factory=list)
 
 
     @property
     @property
     def str_property(self) -> str:
     def str_property(self) -> str:
@@ -265,27 +336,17 @@ class AttrClass:
         return self.labels[0] if self.labels else None
         return self.labels[0] if self.labels else None
 
 
 
 
-@pytest.fixture(
-    params=[
+@pytest.mark.parametrize(
+    "cls",
+    [
         SQLAClass,
         SQLAClass,
+        SQLAClassDataclass,
         BaseClass,
         BaseClass,
         BareClass,
         BareClass,
         ModelClass,
         ModelClass,
         AttrClass,
         AttrClass,
-    ]
+    ],
 )
 )
-def cls(request: pytest.FixtureRequest) -> type:
-    """Fixture for the class to test.
-
-    Args:
-        request: pytest request object.
-
-    Returns:
-        Class to test.
-    """
-    return request.param
-
-
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "attr, expected",
     "attr, expected",
     [
     [
@@ -311,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
         expected: Expected type.
         expected: Expected type.
     """
     """
     assert get_attribute_access_type(cls, attr) == expected
     assert get_attribute_access_type(cls, attr) == expected
+
+
+@pytest.mark.parametrize(
+    "cls",
+    [
+        SQLAClassDataclass,
+        BaseClass,
+        ModelClass,
+        AttrClass,
+    ],
+)
+def test_get_attribute_access_type_default_factory(cls: type) -> None:
+    """Test get_attribute_access_type returns the correct type for default factory fields.
+
+    Args:
+        cls: Class to test.
+    """
+    assert get_attribute_access_type(cls, "default_factory") == List[int]
+
+
+@pytest.mark.parametrize(
+    "cls",
+    [
+        SQLAClassDataclass,
+        BaseClass,
+        ModelClass,
+    ],
+)
+def test_get_attribute_access_type_no_default(cls: type) -> None:
+    """Test get_attribute_access_type returns the correct type for fields with no default which are not required.
+
+    Args:
+        cls: Class to test.
+    """
+    assert get_attribute_access_type(cls, "no_default") == Optional[int]