|
@@ -3,8 +3,22 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import contextlib
|
|
|
-import typing
|
|
|
-from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore
|
|
|
+import types
|
|
|
+from typing import (
|
|
|
+ Any,
|
|
|
+ Callable,
|
|
|
+ Iterable,
|
|
|
+ Literal,
|
|
|
+ Optional,
|
|
|
+ Type,
|
|
|
+ Union,
|
|
|
+ _GenericAlias, # type: ignore
|
|
|
+ get_args,
|
|
|
+ get_origin,
|
|
|
+ get_type_hints,
|
|
|
+)
|
|
|
+
|
|
|
+from pydantic.fields import ModelField
|
|
|
|
|
|
from reflex.base import Base
|
|
|
from reflex.utils import serializers
|
|
@@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple]
|
|
|
ArgsSpec = Callable
|
|
|
|
|
|
|
|
|
-def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
|
|
|
- """Get the arguments of a type alias.
|
|
|
-
|
|
|
- Args:
|
|
|
- alias: The type alias.
|
|
|
-
|
|
|
- Returns:
|
|
|
- The arguments of the type alias.
|
|
|
- """
|
|
|
- return alias.__args__
|
|
|
-
|
|
|
-
|
|
|
def is_generic_alias(cls: GenericType) -> bool:
|
|
|
"""Check whether the class is a generic alias.
|
|
|
|
|
@@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
|
|
|
Returns:
|
|
|
Whether the class is a Union.
|
|
|
"""
|
|
|
- with contextlib.suppress(ImportError):
|
|
|
- from typing import _UnionGenericAlias # type: ignore
|
|
|
+ # UnionType added in py3.10
|
|
|
+ if not hasattr(types, "UnionType"):
|
|
|
+ return get_origin(cls) is Union
|
|
|
|
|
|
- return isinstance(cls, _UnionGenericAlias)
|
|
|
- return cls.__origin__ == Union if is_generic_alias(cls) else False
|
|
|
+ return get_origin(cls) in [Union, types.UnionType]
|
|
|
|
|
|
|
|
|
def is_literal(cls: GenericType) -> bool:
|
|
@@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
|
|
|
Returns:
|
|
|
Whether the class is a literal.
|
|
|
"""
|
|
|
- return hasattr(cls, "__origin__") and cls.__origin__ is Literal
|
|
|
+ return get_origin(cls) is Literal
|
|
|
+
|
|
|
+
|
|
|
+def is_optional(cls: GenericType) -> bool:
|
|
|
+ """Check if a class is an Optional.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ cls: The class to check.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Whether the class is an Optional.
|
|
|
+ """
|
|
|
+ return is_union(cls) and type(None) in get_args(cls)
|
|
|
+
|
|
|
+
|
|
|
+def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
|
|
+ """Check if an attribute can be accessed on the cls and return its type.
|
|
|
+
|
|
|
+ Supports pydantic models, unions, and annotated attributes on rx.Model.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ cls: The class to check.
|
|
|
+ name: The name of the attribute to check.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The type of the attribute, if accessible, or None
|
|
|
+ """
|
|
|
+ from reflex.model import Model
|
|
|
+
|
|
|
+ if hasattr(cls, "__fields__") and name in cls.__fields__:
|
|
|
+ # pydantic models
|
|
|
+ field = cls.__fields__[name]
|
|
|
+ type_ = field.outer_type_
|
|
|
+ if isinstance(type_, ModelField):
|
|
|
+ type_ = type_.type_
|
|
|
+ if not field.required and field.default is None:
|
|
|
+ # Ensure frontend uses null coalescing when accessing.
|
|
|
+ type_ = Optional[type_]
|
|
|
+ return type_
|
|
|
+ elif isinstance(cls, type) and issubclass(cls, Model):
|
|
|
+ # Check in the annotations directly (for sqlmodel.Relationship)
|
|
|
+ hints = get_type_hints(cls)
|
|
|
+ if name in hints:
|
|
|
+ type_ = hints[name]
|
|
|
+ if isinstance(type_, ModelField):
|
|
|
+ return type_.type_
|
|
|
+ return type_
|
|
|
+ elif is_union(cls):
|
|
|
+ # Check in each arg of the annotation.
|
|
|
+ for arg in get_args(cls):
|
|
|
+ type_ = get_attribute_access_type(arg, name)
|
|
|
+ if type_ is not None:
|
|
|
+ # Return the first attribute type that is accessible.
|
|
|
+ return type_
|
|
|
+ return None # Attribute is not accessible.
|
|
|
|
|
|
|
|
|
def get_base_class(cls: GenericType) -> Type:
|
|
@@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
|
|
|
Returns:
|
|
|
Whether the value is a dataframe.
|
|
|
"""
|
|
|
- if is_generic_alias(value) or value == typing.Any:
|
|
|
+ if is_generic_alias(value) or value == Any:
|
|
|
return False
|
|
|
return value.__name__ == "DataFrame"
|
|
|
|
|
@@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
|
|
|
Returns:
|
|
|
Whether the type is a valid prop type.
|
|
|
"""
|
|
|
+ if is_union(type_):
|
|
|
+ return all((is_valid_var_type(arg) for arg in get_args(type_)))
|
|
|
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
|
|
|
|
|
|
|
|
@@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
|
|
|
return name.startswith("_") and not name.startswith("__")
|
|
|
|
|
|
|
|
|
-def check_type_in_allowed_types(
|
|
|
- value_type: Type, allowed_types: typing.Iterable
|
|
|
-) -> bool:
|
|
|
+def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
|
|
|
"""Check that a value type is found in a list of allowed types.
|
|
|
|
|
|
Args:
|