Browse Source

Improved get_attribute_access_type tests (#3180)

* test get_attribute_access_type against attrs

* add union and hybrid_property tests
benedikt-bartscher 1 year ago
parent
commit
9ead091fec
1 changed files with 72 additions and 2 deletions
  1. 72 2
      tests/test_attribute_access_type.py

+ 72 - 2
tests/test_attribute_access_type.py

@@ -1,9 +1,11 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 
+import attrs
 import pytest
 import pytest
 import sqlalchemy
 import sqlalchemy
+from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 
 
 import reflex as rx
 import reflex as rx
@@ -60,6 +62,15 @@ class SQLAClass(SQLABase):
         """
         """
         return self.name
         return self.name
 
 
+    @hybrid_property
+    def str_or_int_property(self) -> Union[str, int]:
+        """String or int property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
 
 
 class ModelClass(rx.Model):
 class ModelClass(rx.Model):
     """Test reflex model."""
     """Test reflex model."""
@@ -81,6 +92,15 @@ class ModelClass(rx.Model):
         """
         """
         return self.name
         return self.name
 
 
+    @property
+    def str_or_int_property(self) -> Union[str, int]:
+        """String or int property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
 
 
 class BaseClass(rx.Base):
 class BaseClass(rx.Base):
     """Test rx.Base class."""
     """Test rx.Base class."""
@@ -102,6 +122,15 @@ class BaseClass(rx.Base):
         """
         """
         return self.name
         return self.name
 
 
+    @property
+    def str_or_int_property(self) -> Union[str, int]:
+        """String or int property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
 
 
 class BareClass:
 class BareClass:
     """Bare python class."""
     """Bare python class."""
@@ -123,8 +152,48 @@ class BareClass:
         """
         """
         return self.name
         return self.name
 
 
+    @property
+    def str_or_int_property(self) -> Union[str, int]:
+        """String or int property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+
+@attrs.define
+class AttrClass:
+    """Test attrs class."""
+
+    count: int = 0
+    name: str = "test"
+    int_list: List[int] = []
+    str_list: List[str] = []
+    optional_int: Optional[int] = None
+    sqla_tag: Optional[SQLATag] = None
+    labels: List[SQLALabel] = []
+
+    @property
+    def str_property(self) -> str:
+        """String property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
+    @property
+    def str_or_int_property(self) -> Union[str, int]:
+        """String or int property.
+
+        Returns:
+            Name attribute
+        """
+        return self.name
+
 
 
-@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass])
+@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass])
 def cls(request: pytest.FixtureRequest) -> type:
 def cls(request: pytest.FixtureRequest) -> type:
     """Fixture for the class to test.
     """Fixture for the class to test.
 
 
@@ -148,6 +217,7 @@ def cls(request: pytest.FixtureRequest) -> type:
         pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
         pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
         pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
         pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
         pytest.param("str_property", str, id="str_property"),
         pytest.param("str_property", str, id="str_property"),
+        pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
     ],
     ],
 )
 )
 def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None:
 def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None: