|
@@ -3,11 +3,19 @@ from __future__ import annotations
|
|
from typing import Dict, List, Optional, Type, Union
|
|
from typing import Dict, List, Optional, Type, Union
|
|
|
|
|
|
import attrs
|
|
import attrs
|
|
|
|
+import pydantic.v1
|
|
import pytest
|
|
import pytest
|
|
import sqlalchemy
|
|
import sqlalchemy
|
|
|
|
+import sqlmodel
|
|
from sqlalchemy import JSON, TypeDecorator
|
|
from sqlalchemy import JSON, TypeDecorator
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
|
|
|
|
|
+from sqlalchemy.orm import (
|
|
|
|
+ DeclarativeBase,
|
|
|
|
+ Mapped,
|
|
|
|
+ MappedAsDataclass,
|
|
|
|
+ mapped_column,
|
|
|
|
+ relationship,
|
|
|
|
+)
|
|
|
|
|
|
import reflex as rx
|
|
import reflex as rx
|
|
from reflex.utils.types import GenericType, get_attribute_access_type
|
|
from reflex.utils.types import GenericType, get_attribute_access_type
|
|
@@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
|
|
id: Mapped[int] = mapped_column(primary_key=True)
|
|
id: Mapped[int] = mapped_column(primary_key=True)
|
|
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
|
|
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
|
|
test: Mapped[SQLAClass] = relationship(back_populates="labels")
|
|
test: Mapped[SQLAClass] = relationship(back_populates="labels")
|
|
|
|
+ test_dataclass_id: Mapped[int] = mapped_column(
|
|
|
|
+ sqlalchemy.ForeignKey("test_dataclass.id")
|
|
|
|
+ )
|
|
|
|
+ test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")
|
|
|
|
|
|
|
|
|
|
class SQLAClass(SQLABase):
|
|
class SQLAClass(SQLABase):
|
|
@@ -104,9 +116,64 @@ class SQLAClass(SQLABase):
|
|
return self.labels[0] if self.labels else None
|
|
return self.labels[0] if self.labels else None
|
|
|
|
|
|
|
|
|
|
|
|
+class SQLAClassDataclass(MappedAsDataclass, SQLABase):
|
|
|
|
+ """Test sqlalchemy model."""
|
|
|
|
+
|
|
|
|
+ id: Mapped[int] = mapped_column(primary_key=True)
|
|
|
|
+ no_default: Mapped[int] = mapped_column(nullable=True)
|
|
|
|
+ count: Mapped[int] = mapped_column()
|
|
|
|
+ name: Mapped[str] = mapped_column()
|
|
|
|
+ int_list: Mapped[List[int]] = mapped_column(
|
|
|
|
+ sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
|
|
|
|
+ )
|
|
|
|
+ str_list: Mapped[List[str]] = mapped_column(
|
|
|
|
+ sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
|
|
|
|
+ )
|
|
|
|
+ optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
|
|
|
|
+ sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
|
|
|
|
+ sqla_tag: Mapped[Optional[SQLATag]] = relationship()
|
|
|
|
+ labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
|
|
|
|
+ # do not use lower case dict here!
|
|
|
|
+ # https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
|
|
|
+ dict_str_str: Mapped[Dict[str, str]] = mapped_column()
|
|
|
|
+ default_factory: Mapped[List[int]] = mapped_column(
|
|
|
|
+ sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
|
|
|
|
+ default_factory=list,
|
|
|
|
+ )
|
|
|
|
+ __tablename__: str = "test_dataclass"
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def str_property(self) -> str:
|
|
|
|
+ """String property.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ Name attribute
|
|
|
|
+ """
|
|
|
|
+ return self.name
|
|
|
|
+
|
|
|
|
+ @hybrid_property
|
|
|
|
+ def str_or_int_property(self) -> Union[str, int]:
|
|
|
|
+ """String or int property.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ Name attribute
|
|
|
|
+ """
|
|
|
|
+ return self.name
|
|
|
|
+
|
|
|
|
+ @hybrid_property
|
|
|
|
+ def first_label(self) -> Optional[SQLALabel]:
|
|
|
|
+ """First label property.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ First label
|
|
|
|
+ """
|
|
|
|
+ return self.labels[0] if self.labels else None
|
|
|
|
+
|
|
|
|
+
|
|
class ModelClass(rx.Model):
|
|
class ModelClass(rx.Model):
|
|
"""Test reflex model."""
|
|
"""Test reflex model."""
|
|
|
|
|
|
|
|
+ no_default: Optional[int] = sqlmodel.Field(nullable=True)
|
|
count: int = 0
|
|
count: int = 0
|
|
name: str = "test"
|
|
name: str = "test"
|
|
int_list: List[int] = []
|
|
int_list: List[int] = []
|
|
@@ -115,6 +182,7 @@ class ModelClass(rx.Model):
|
|
sqla_tag: Optional[SQLATag] = None
|
|
sqla_tag: Optional[SQLATag] = None
|
|
labels: List[SQLALabel] = []
|
|
labels: List[SQLALabel] = []
|
|
dict_str_str: Dict[str, str] = {}
|
|
dict_str_str: Dict[str, str] = {}
|
|
|
|
+ default_factory: List[int] = sqlmodel.Field(default_factory=list)
|
|
|
|
|
|
@property
|
|
@property
|
|
def str_property(self) -> str:
|
|
def str_property(self) -> str:
|
|
@@ -147,6 +215,7 @@ class ModelClass(rx.Model):
|
|
class BaseClass(rx.Base):
|
|
class BaseClass(rx.Base):
|
|
"""Test rx.Base class."""
|
|
"""Test rx.Base class."""
|
|
|
|
|
|
|
|
+ no_default: Optional[int] = pydantic.v1.Field(required=False)
|
|
count: int = 0
|
|
count: int = 0
|
|
name: str = "test"
|
|
name: str = "test"
|
|
int_list: List[int] = []
|
|
int_list: List[int] = []
|
|
@@ -155,6 +224,7 @@ class BaseClass(rx.Base):
|
|
sqla_tag: Optional[SQLATag] = None
|
|
sqla_tag: Optional[SQLATag] = None
|
|
labels: List[SQLALabel] = []
|
|
labels: List[SQLALabel] = []
|
|
dict_str_str: Dict[str, str] = {}
|
|
dict_str_str: Dict[str, str] = {}
|
|
|
|
+ default_factory: List[int] = pydantic.v1.Field(default_factory=list)
|
|
|
|
|
|
@property
|
|
@property
|
|
def str_property(self) -> str:
|
|
def str_property(self) -> str:
|
|
@@ -236,6 +306,7 @@ class AttrClass:
|
|
sqla_tag: Optional[SQLATag] = None
|
|
sqla_tag: Optional[SQLATag] = None
|
|
labels: List[SQLALabel] = []
|
|
labels: List[SQLALabel] = []
|
|
dict_str_str: Dict[str, str] = {}
|
|
dict_str_str: Dict[str, str] = {}
|
|
|
|
+ default_factory: List[int] = attrs.field(factory=list)
|
|
|
|
|
|
@property
|
|
@property
|
|
def str_property(self) -> str:
|
|
def str_property(self) -> str:
|
|
@@ -265,27 +336,17 @@ class AttrClass:
|
|
return self.labels[0] if self.labels else None
|
|
return self.labels[0] if self.labels else None
|
|
|
|
|
|
|
|
|
|
-@pytest.fixture(
|
|
|
|
- params=[
|
|
|
|
|
|
+@pytest.mark.parametrize(
|
|
|
|
+ "cls",
|
|
|
|
+ [
|
|
SQLAClass,
|
|
SQLAClass,
|
|
|
|
+ SQLAClassDataclass,
|
|
BaseClass,
|
|
BaseClass,
|
|
BareClass,
|
|
BareClass,
|
|
ModelClass,
|
|
ModelClass,
|
|
AttrClass,
|
|
AttrClass,
|
|
- ]
|
|
|
|
|
|
+ ],
|
|
)
|
|
)
|
|
-def cls(request: pytest.FixtureRequest) -> type:
|
|
|
|
- """Fixture for the class to test.
|
|
|
|
-
|
|
|
|
- Args:
|
|
|
|
- request: pytest request object.
|
|
|
|
-
|
|
|
|
- Returns:
|
|
|
|
- Class to test.
|
|
|
|
- """
|
|
|
|
- return request.param
|
|
|
|
-
|
|
|
|
-
|
|
|
|
@pytest.mark.parametrize(
|
|
@pytest.mark.parametrize(
|
|
"attr, expected",
|
|
"attr, expected",
|
|
[
|
|
[
|
|
@@ -311,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
|
|
expected: Expected type.
|
|
expected: Expected type.
|
|
"""
|
|
"""
|
|
assert get_attribute_access_type(cls, attr) == expected
|
|
assert get_attribute_access_type(cls, attr) == expected
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@pytest.mark.parametrize(
|
|
|
|
+ "cls",
|
|
|
|
+ [
|
|
|
|
+ SQLAClassDataclass,
|
|
|
|
+ BaseClass,
|
|
|
|
+ ModelClass,
|
|
|
|
+ AttrClass,
|
|
|
|
+ ],
|
|
|
|
+)
|
|
|
|
+def test_get_attribute_access_type_default_factory(cls: type) -> None:
|
|
|
|
+ """Test get_attribute_access_type returns the correct type for default factory fields.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ cls: Class to test.
|
|
|
|
+ """
|
|
|
|
+ assert get_attribute_access_type(cls, "default_factory") == List[int]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@pytest.mark.parametrize(
|
|
|
|
+ "cls",
|
|
|
|
+ [
|
|
|
|
+ SQLAClassDataclass,
|
|
|
|
+ BaseClass,
|
|
|
|
+ ModelClass,
|
|
|
|
+ ],
|
|
|
|
+)
|
|
|
|
+def test_get_attribute_access_type_no_default(cls: type) -> None:
|
|
|
|
+ """Test get_attribute_access_type returns the correct type for fields with no default which are not required.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ cls: Class to test.
|
|
|
|
+ """
|
|
|
|
+ assert get_attribute_access_type(cls, "no_default") == Optional[int]
|