Преглед изворни кода

Improve Var type handling for better rx.Model attribute access (#2010)

Masen Furer пре 1 година
родитељ
комит
92dd68c51f
5 измењених фајлова са 175 додато и 36 уклоњено
  1. 30 1
      reflex/model.py
  2. 7 0
      reflex/state.py
  3. 79 23
      reflex/utils/types.py
  4. 8 10
      reflex/vars.py
  5. 51 2
      tests/test_state.py

+ 30 - 1
reflex/model.py

@@ -15,6 +15,7 @@ import alembic.runtime.environment
 import alembic.script
 import alembic.script
 import alembic.util
 import alembic.util
 import sqlalchemy
 import sqlalchemy
+import sqlalchemy.orm
 import sqlmodel
 import sqlmodel
 
 
 from reflex import constants
 from reflex import constants
@@ -68,6 +69,22 @@ class Model(Base, sqlmodel.SQLModel):
 
 
         super().__init_subclass__()
         super().__init_subclass__()
 
 
+    @classmethod
+    def _dict_recursive(cls, value):
+        """Recursively serialize the relationship object(s).
+
+        Args:
+            value: The value to serialize.
+
+        Returns:
+            The serialized value.
+        """
+        if hasattr(value, "dict"):
+            return value.dict()
+        elif isinstance(value, list):
+            return [cls._dict_recursive(item) for item in value]
+        return value
+
     def dict(self, **kwargs):
     def dict(self, **kwargs):
         """Convert the object to a dictionary.
         """Convert the object to a dictionary.
 
 
@@ -77,7 +94,19 @@ class Model(Base, sqlmodel.SQLModel):
         Returns:
         Returns:
             The object as a dictionary.
             The object as a dictionary.
         """
         """
-        return {name: getattr(self, name) for name in self.__fields__}
+        base_fields = {name: getattr(self, name) for name in self.__fields__}
+        relationships = {}
+        # SQLModel relationships do not appear in __fields__, but should be included if present.
+        for name in self.__sqlmodel_relationships__:
+            try:
+                relationships[name] = self._dict_recursive(getattr(self, name))
+            except sqlalchemy.orm.exc.DetachedInstanceError:
+                # This happens when the relationship was never loaded and the session is closed.
+                continue
+        return {
+            **base_fields,
+            **relationships,
+        }
 
 
     @staticmethod
     @staticmethod
     def create_all():
     def create_all():

+ 7 - 0
reflex/state.py

@@ -571,6 +571,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
             if default_value is not None:
             if default_value is not None:
                 field.required = False
                 field.required = False
                 field.default = default_value
                 field.default = default_value
+        if (
+            not field.required
+            and field.default is None
+            and not types.is_optional(prop._var_type)
+        ):
+            # Ensure frontend uses null coalescing when accessing.
+            prop._var_type = Optional[prop._var_type]
 
 
     @staticmethod
     @staticmethod
     def _get_base_functions() -> dict[str, FunctionType]:
     def _get_base_functions() -> dict[str, FunctionType]:

+ 79 - 23
reflex/utils/types.py

@@ -3,8 +3,22 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import contextlib
 import contextlib
-import typing
-from typing import Any, Callable, Literal, Type, Union, _GenericAlias  # type: ignore
+import types
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Literal,
+    Optional,
+    Type,
+    Union,
+    _GenericAlias,  # type: ignore
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+
+from pydantic.fields import ModelField
 
 
 from reflex.base import Base
 from reflex.base import Base
 from reflex.utils import serializers
 from reflex.utils import serializers
@@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple]
 ArgsSpec = Callable
 ArgsSpec = Callable
 
 
 
 
-def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
-    """Get the arguments of a type alias.
-
-    Args:
-        alias: The type alias.
-
-    Returns:
-        The arguments of the type alias.
-    """
-    return alias.__args__
-
-
 def is_generic_alias(cls: GenericType) -> bool:
 def is_generic_alias(cls: GenericType) -> bool:
     """Check whether the class is a generic alias.
     """Check whether the class is a generic alias.
 
 
@@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
     Returns:
     Returns:
         Whether the class is a Union.
         Whether the class is a Union.
     """
     """
-    with contextlib.suppress(ImportError):
-        from typing import _UnionGenericAlias  # type: ignore
+    # UnionType added in py3.10
+    if not hasattr(types, "UnionType"):
+        return get_origin(cls) is Union
 
 
-        return isinstance(cls, _UnionGenericAlias)
-    return cls.__origin__ == Union if is_generic_alias(cls) else False
+    return get_origin(cls) in [Union, types.UnionType]
 
 
 
 
 def is_literal(cls: GenericType) -> bool:
 def is_literal(cls: GenericType) -> bool:
@@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
     Returns:
     Returns:
         Whether the class is a literal.
         Whether the class is a literal.
     """
     """
-    return hasattr(cls, "__origin__") and cls.__origin__ is Literal
+    return get_origin(cls) is Literal
+
+
+def is_optional(cls: GenericType) -> bool:
+    """Check if a class is an Optional.
+
+    Args:
+        cls: The class to check.
+
+    Returns:
+        Whether the class is an Optional.
+    """
+    return is_union(cls) and type(None) in get_args(cls)
+
+
+def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
+    """Check if an attribute can be accessed on the cls and return its type.
+
+    Supports pydantic models, unions, and annotated attributes on rx.Model.
+
+    Args:
+        cls: The class to check.
+        name: The name of the attribute to check.
+
+    Returns:
+        The type of the attribute, if accessible, or None
+    """
+    from reflex.model import Model
+
+    if hasattr(cls, "__fields__") and name in cls.__fields__:
+        # pydantic models
+        field = cls.__fields__[name]
+        type_ = field.outer_type_
+        if isinstance(type_, ModelField):
+            type_ = type_.type_
+        if not field.required and field.default is None:
+            # Ensure frontend uses null coalescing when accessing.
+            type_ = Optional[type_]
+        return type_
+    elif isinstance(cls, type) and issubclass(cls, Model):
+        # Check in the annotations directly (for sqlmodel.Relationship)
+        hints = get_type_hints(cls)
+        if name in hints:
+            type_ = hints[name]
+            if isinstance(type_, ModelField):
+                return type_.type_
+            return type_
+    elif is_union(cls):
+        # Check in each arg of the annotation.
+        for arg in get_args(cls):
+            type_ = get_attribute_access_type(arg, name)
+            if type_ is not None:
+                # Return the first attribute type that is accessible.
+                return type_
+    return None  # Attribute is not accessible.
 
 
 
 
 def get_base_class(cls: GenericType) -> Type:
 def get_base_class(cls: GenericType) -> Type:
@@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
     Returns:
     Returns:
         Whether the value is a dataframe.
         Whether the value is a dataframe.
     """
     """
-    if is_generic_alias(value) or value == typing.Any:
+    if is_generic_alias(value) or value == Any:
         return False
         return False
     return value.__name__ == "DataFrame"
     return value.__name__ == "DataFrame"
 
 
@@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
     Returns:
     Returns:
         Whether the type is a valid prop type.
         Whether the type is a valid prop type.
     """
     """
+    if is_union(type_):
+        return all((is_valid_var_type(arg) for arg in get_args(type_)))
     return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
     return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
 
 
 
 
@@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
     return name.startswith("_") and not name.startswith("__")
     return name.startswith("_") and not name.startswith("__")
 
 
 
 
-def check_type_in_allowed_types(
-    value_type: Type, allowed_types: typing.Iterable
-) -> bool:
+def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
     """Check that a value type is found in a list of allowed types.
     """Check that a value type is found in a list of allowed types.
 
 
     Args:
     Args:

+ 8 - 10
reflex/vars.py

@@ -27,8 +27,6 @@ from typing import (
     get_type_hints,
     get_type_hints,
 )
 )
 
 
-from pydantic.fields import ModelField
-
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.utils import console, format, serializers, types
 from reflex.utils import console, format, serializers, types
@@ -420,15 +418,12 @@ class Var:
                 raise TypeError(
                 raise TypeError(
                     f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
                     f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
                 ) from None
                 ) from None
-            if (
-                hasattr(self._var_type, "__fields__")
-                and name in self._var_type.__fields__
-            ):
-                type_ = self._var_type.__fields__[name].outer_type_
-                if isinstance(type_, ModelField):
-                    type_ = type_.type_
+            is_optional = types.is_optional(self._var_type)
+            type_ = types.get_attribute_access_type(self._var_type, name)
+
+            if type_ is not None:
                 return BaseVar(
                 return BaseVar(
-                    _var_name=f"{self._var_name}.{name}",
+                    _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
                     _var_type=type_,
                     _var_type=type_,
                     _var_state=self._var_state,
                     _var_state=self._var_state,
                     _var_is_local=self._var_is_local,
                     _var_is_local=self._var_is_local,
@@ -1235,6 +1230,9 @@ class BaseVar(Var):
         Raises:
         Raises:
             ImportError: If the var is a dataframe and pandas is not installed.
             ImportError: If the var is a dataframe and pandas is not installed.
         """
         """
+        if types.is_optional(self._var_type):
+            return None
+
         type_ = (
         type_ = (
             get_origin(self._var_type)
             get_origin(self._var_type)
             if types.is_generic_alias(self._var_type)
             if types.is_generic_alias(self._var_type)

+ 51 - 2
tests/test_state.py

@@ -7,7 +7,7 @@ import functools
 import json
 import json
 import os
 import os
 import sys
 import sys
-from typing import Dict, Generator, List
+from typing import Dict, Generator, List, Optional, Union
 from unittest.mock import AsyncMock, Mock
 from unittest.mock import AsyncMock, Mock
 
 
 import pytest
 import pytest
@@ -30,7 +30,7 @@ from reflex.state import (
     StateProxy,
     StateProxy,
     StateUpdate,
     StateUpdate,
 )
 )
-from reflex.utils import prerequisites
+from reflex.utils import prerequisites, types
 from reflex.utils.format import json_dumps
 from reflex.utils.format import json_dumps
 from reflex.vars import BaseVar, ComputedVar
 from reflex.vars import BaseVar, ComputedVar
 
 
@@ -2239,3 +2239,52 @@ def test_reset_with_mutables():
     instance.items.append([3, 3])
     instance.items.append([3, 3])
     assert instance.items != default
     assert instance.items != default
     assert instance.items != copied_default
     assert instance.items != copied_default
+
+
+class Custom1(Base):
+    """A custom class with a str field."""
+
+    foo: str
+
+
+class Custom2(Base):
+    """A custom class with a Custom1 field."""
+
+    c1: Optional[Custom1] = None
+    c1r: Custom1
+
+
+class Custom3(Base):
+    """A custom class with a Custom2 field."""
+
+    c2: Optional[Custom2] = None
+    c2r: Custom2
+
+
+def test_state_union_optional():
+    """Test that state can be defined with Union and Optional vars."""
+
+    class UnionState(State):
+        int_float: Union[int, float] = 0
+        opt_int: Optional[int]
+        c3: Optional[Custom3]
+        c3i: Custom3  # implicitly required
+        c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
+        custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
+
+    assert UnionState.c3.c2._var_name == "c3?.c2"  # type: ignore
+    assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1"  # type: ignore
+    assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo"  # type: ignore
+    assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo"  # type: ignore
+    assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1"  # type: ignore
+    assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo"  # type: ignore
+    assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo"  # type: ignore
+    assert UnionState.c3i.c2._var_name == "c3i.c2"  # type: ignore
+    assert UnionState.c3r.c2._var_name == "c3r.c2"  # type: ignore
+    assert UnionState.custom_union.foo is not None  # type: ignore
+    assert UnionState.custom_union.c1 is not None  # type: ignore
+    assert UnionState.custom_union.c1r is not None  # type: ignore
+    assert UnionState.custom_union.c2 is not None  # type: ignore
+    assert UnionState.custom_union.c2r is not None  # type: ignore
+    assert types.is_optional(UnionState.opt_int._var_type)  # type: ignore
+    assert types.is_union(UnionState.int_float._var_type)  # type: ignore