Parcourir la source

don't use _outer_type if we don't have to (#4528)

* don't use _outer_type if we don't have to

* apparently we should use .annotation, and .allow_none is useless

* have a shorter path for get_field_type if it's nice

* check against optional in annotation str

* add check for default value being null

* post merge

* we still console erroring

* bring back nested

* get_type_hints is slow af

* simplify value inside optional

* optimize get_event_triggers a tad bit

* optimize subclass checks

* optimize things even more why not

* what if we don't validate components
Khaleel Al-Adhami il y a 2 mois
Parent
commit
21ba01cb22

+ 1 - 1
reflex/components/base/bare.py

@@ -76,7 +76,7 @@ class Bare(Component):
                 validate_str(contents)
                 validate_str(contents)
             contents = str(contents) if contents is not None else ""
             contents = str(contents) if contents is not None else ""
 
 
-        return cls(contents=contents)
+        return cls._create(children=[], contents=contents)
 
 
     def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
     def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
         """Include the hooks for the component.
         """Include the hooks for the component.

+ 69 - 47
reflex/components/component.py

@@ -19,13 +19,12 @@ from typing import (
     Sequence,
     Sequence,
     Set,
     Set,
     Type,
     Type,
+    TypeVar,
     Union,
     Union,
     get_args,
     get_args,
     get_origin,
     get_origin,
 )
 )
 
 
-from typing_extensions import Self
-
 import reflex.state
 import reflex.state
 from reflex.base import Base
 from reflex.base import Base
 from reflex.compiler.templates import STATEFUL_COMPONENT
 from reflex.compiler.templates import STATEFUL_COMPONENT
@@ -210,6 +209,27 @@ def _components_from(
     return ()
     return ()
 
 
 
 
+DEFAULT_TRIGGERS: dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
+    EventTriggers.ON_FOCUS: no_args_event_spec,
+    EventTriggers.ON_BLUR: no_args_event_spec,
+    EventTriggers.ON_CLICK: no_args_event_spec,
+    EventTriggers.ON_CONTEXT_MENU: no_args_event_spec,
+    EventTriggers.ON_DOUBLE_CLICK: no_args_event_spec,
+    EventTriggers.ON_MOUSE_DOWN: no_args_event_spec,
+    EventTriggers.ON_MOUSE_ENTER: no_args_event_spec,
+    EventTriggers.ON_MOUSE_LEAVE: no_args_event_spec,
+    EventTriggers.ON_MOUSE_MOVE: no_args_event_spec,
+    EventTriggers.ON_MOUSE_OUT: no_args_event_spec,
+    EventTriggers.ON_MOUSE_OVER: no_args_event_spec,
+    EventTriggers.ON_MOUSE_UP: no_args_event_spec,
+    EventTriggers.ON_SCROLL: no_args_event_spec,
+    EventTriggers.ON_MOUNT: no_args_event_spec,
+    EventTriggers.ON_UNMOUNT: no_args_event_spec,
+}
+
+T = TypeVar("T", bound="Component")
+
+
 class Component(BaseComponent, ABC):
 class Component(BaseComponent, ABC):
     """A component with style, event trigger and other props."""
     """A component with style, event trigger and other props."""
 
 
@@ -364,12 +384,16 @@ class Component(BaseComponent, ABC):
             if field.name not in props:
             if field.name not in props:
                 continue
                 continue
 
 
+            field_type = types.value_inside_optional(
+                types.get_field_type(cls, field.name)
+            )
+
             # Set default values for any props.
             # Set default values for any props.
-            if types._issubclass(field.type_, Var):
+            if types._issubclass(field_type, Var):
                 field.required = False
                 field.required = False
                 if field.default is not None:
                 if field.default is not None:
                     field.default = LiteralVar.create(field.default)
                     field.default = LiteralVar.create(field.default)
-            elif types._issubclass(field.type_, EventHandler):
+            elif types._issubclass(field_type, EventHandler):
                 field.required = False
                 field.required = False
 
 
         # Ensure renamed props from parent classes are applied to the subclass.
         # Ensure renamed props from parent classes are applied to the subclass.
@@ -380,7 +404,7 @@ class Component(BaseComponent, ABC):
                     inherited_rename_props.update(parent._rename_props)
                     inherited_rename_props.update(parent._rename_props)
             cls._rename_props = inherited_rename_props
             cls._rename_props = inherited_rename_props
 
 
-    def __init__(self, *args, **kwargs):
+    def _post_init(self, *args, **kwargs):
         """Initialize the component.
         """Initialize the component.
 
 
         Args:
         Args:
@@ -393,16 +417,6 @@ class Component(BaseComponent, ABC):
         """
         """
         # Set the id and children initially.
         # Set the id and children initially.
         children = kwargs.get("children", [])
         children = kwargs.get("children", [])
-        initial_kwargs = {
-            "id": kwargs.get("id"),
-            "children": children,
-            **{
-                prop: LiteralVar.create(kwargs[prop])
-                for prop in self.get_initial_props()
-                if prop in kwargs
-            },
-        }
-        super().__init__(**initial_kwargs)
 
 
         self._validate_component_children(children)
         self._validate_component_children(children)
 
 
@@ -433,7 +447,9 @@ class Component(BaseComponent, ABC):
                 field_type = EventChain
                 field_type = EventChain
             elif key in props:
             elif key in props:
                 # Set the field type.
                 # Set the field type.
-                field_type = fields[key].type_
+                field_type = types.value_inside_optional(
+                    types.get_field_type(type(self), key)
+                )
 
 
             else:
             else:
                 continue
                 continue
@@ -455,7 +471,10 @@ class Component(BaseComponent, ABC):
                 try:
                 try:
                     kwargs[key] = determine_key(value)
                     kwargs[key] = determine_key(value)
 
 
-                    expected_type = fields[key].outer_type_.__args__[0]
+                    expected_type = types.get_args(
+                        types.get_field_type(type(self), key)
+                    )[0]
+
                     # validate literal fields.
                     # validate literal fields.
                     types.validate_literal(
                     types.validate_literal(
                         key, value, expected_type, type(self).__name__
                         key, value, expected_type, type(self).__name__
@@ -470,7 +489,7 @@ class Component(BaseComponent, ABC):
                 except TypeError:
                 except TypeError:
                     # If it is not a valid var, check the base types.
                     # If it is not a valid var, check the base types.
                     passed_type = type(value)
                     passed_type = type(value)
-                    expected_type = fields[key].outer_type_
+                    expected_type = types.get_field_type(type(self), key)
                 if types.is_union(passed_type):
                 if types.is_union(passed_type):
                     # We need to check all possible types in the union.
                     # We need to check all possible types in the union.
                     passed_types = (
                     passed_types = (
@@ -552,7 +571,8 @@ class Component(BaseComponent, ABC):
                 kwargs["class_name"] = " ".join(class_name)
                 kwargs["class_name"] = " ".join(class_name)
 
 
         # Construct the component.
         # Construct the component.
-        super().__init__(*args, **kwargs)
+        for key, value in kwargs.items():
+            setattr(self, key, value)
 
 
     def get_event_triggers(
     def get_event_triggers(
         self,
         self,
@@ -562,34 +582,17 @@ class Component(BaseComponent, ABC):
         Returns:
         Returns:
             The event triggers.
             The event triggers.
         """
         """
-        default_triggers: dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
-            EventTriggers.ON_FOCUS: no_args_event_spec,
-            EventTriggers.ON_BLUR: no_args_event_spec,
-            EventTriggers.ON_CLICK: no_args_event_spec,
-            EventTriggers.ON_CONTEXT_MENU: no_args_event_spec,
-            EventTriggers.ON_DOUBLE_CLICK: no_args_event_spec,
-            EventTriggers.ON_MOUSE_DOWN: no_args_event_spec,
-            EventTriggers.ON_MOUSE_ENTER: no_args_event_spec,
-            EventTriggers.ON_MOUSE_LEAVE: no_args_event_spec,
-            EventTriggers.ON_MOUSE_MOVE: no_args_event_spec,
-            EventTriggers.ON_MOUSE_OUT: no_args_event_spec,
-            EventTriggers.ON_MOUSE_OVER: no_args_event_spec,
-            EventTriggers.ON_MOUSE_UP: no_args_event_spec,
-            EventTriggers.ON_SCROLL: no_args_event_spec,
-            EventTriggers.ON_MOUNT: no_args_event_spec,
-            EventTriggers.ON_UNMOUNT: no_args_event_spec,
-        }
-
+        triggers = DEFAULT_TRIGGERS.copy()
         # Look for component specific triggers,
         # Look for component specific triggers,
         # e.g. variable declared as EventHandler types.
         # e.g. variable declared as EventHandler types.
         for field in self.get_fields().values():
         for field in self.get_fields().values():
-            if types._issubclass(field.outer_type_, EventHandler):
+            if field.type_ is EventHandler:
                 args_spec = None
                 args_spec = None
                 annotation = field.annotation
                 annotation = field.annotation
                 if (metadata := getattr(annotation, "__metadata__", None)) is not None:
                 if (metadata := getattr(annotation, "__metadata__", None)) is not None:
                     args_spec = metadata[0]
                     args_spec = metadata[0]
-                default_triggers[field.name] = args_spec or (no_args_event_spec)
-        return default_triggers
+                triggers[field.name] = args_spec or (no_args_event_spec)
+        return triggers
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Represent the component in React.
         """Represent the component in React.
@@ -703,9 +706,11 @@ class Component(BaseComponent, ABC):
         """
         """
         return {
         return {
             name
             name
-            for name, field in cls.get_fields().items()
+            for name in cls.get_fields()
             if name in cls.get_props()
             if name in cls.get_props()
-            and types._issubclass(field.outer_type_, Component)
+            and types._issubclass(
+                types.value_inside_optional(types.get_field_type(cls, name)), Component
+            )
         }
         }
 
 
     def _get_components_in_props(self) -> Sequence[BaseComponent]:
     def _get_components_in_props(self) -> Sequence[BaseComponent]:
@@ -729,7 +734,7 @@ class Component(BaseComponent, ABC):
         ]
         ]
 
 
     @classmethod
     @classmethod
-    def create(cls, *children, **props) -> Self:
+    def create(cls: Type[T], *children, **props) -> T:
         """Create the component.
         """Create the component.
 
 
         Args:
         Args:
@@ -774,7 +779,22 @@ class Component(BaseComponent, ABC):
             for child in children
             for child in children
         ]
         ]
 
 
-        return cls(children=children, **props)
+        return cls._create(children, **props)
+
+    @classmethod
+    def _create(cls: Type[T], children: list[Component], **props: Any) -> T:
+        """Create the component.
+
+        Args:
+            children: The children of the component.
+            **props: The props of the component.
+
+        Returns:
+            The component.
+        """
+        comp = cls.construct(id=props.get("id"), children=children)
+        comp._post_init(children=children, **props)
+        return comp
 
 
     def add_style(self) -> dict[str, Any] | None:
     def add_style(self) -> dict[str, Any] | None:
         """Add style to the component.
         """Add style to the component.
@@ -1659,7 +1679,7 @@ class CustomComponent(Component):
     # The props of the component.
     # The props of the component.
     props: dict[str, Any] = {}
     props: dict[str, Any] = {}
 
 
-    def __init__(self, **kwargs):
+    def _post_init(self, **kwargs):
         """Initialize the custom component.
         """Initialize the custom component.
 
 
         Args:
         Args:
@@ -1702,7 +1722,7 @@ class CustomComponent(Component):
                 )
                 )
             )
             )
 
 
-        super().__init__(
+        super()._post_init(
             event_triggers={
             event_triggers={
                 key: EventChain.create(
                 key: EventChain.create(
                     value=props[key],
                     value=props[key],
@@ -1863,7 +1883,9 @@ def custom_component(
     def wrapper(*children, **props) -> CustomComponent:
     def wrapper(*children, **props) -> CustomComponent:
         # Remove the children from the props.
         # Remove the children from the props.
         props.pop("children", None)
         props.pop("children", None)
-        return CustomComponent(component_fn=component_fn, children=children, **props)
+        return CustomComponent._create(
+            children=list(children), component_fn=component_fn, **props
+        )
 
 
     return wrapper
     return wrapper
 
 

+ 9 - 2
reflex/config.py

@@ -36,7 +36,12 @@ from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.utils import console
 from reflex.utils import console
 from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
 from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
-from reflex.utils.types import GenericType, is_union, value_inside_optional
+from reflex.utils.types import (
+    GenericType,
+    is_union,
+    true_type_for_pydantic_field,
+    value_inside_optional,
+)
 
 
 try:
 try:
     from dotenv import load_dotenv  # pyright: ignore [reportMissingImports]
     from dotenv import load_dotenv  # pyright: ignore [reportMissingImports]
@@ -943,7 +948,9 @@ class Config(Base):
             # If the env var is set, override the config value.
             # If the env var is set, override the config value.
             if env_var is not None:
             if env_var is not None:
                 # Interpret the value.
                 # Interpret the value.
-                value = interpret_env_var_value(env_var, field.outer_type_, field.name)
+                value = interpret_env_var_value(
+                    env_var, true_type_for_pydantic_field(field), field.name
+                )
 
 
                 # Set the value.
                 # Set the value.
                 updated_values[key] = value
                 updated_values[key] = value

+ 8 - 6
reflex/state.py

@@ -89,9 +89,9 @@ from reflex.utils.serializers import serializer
 from reflex.utils.types import (
 from reflex.utils.types import (
     _isinstance,
     _isinstance,
     get_origin,
     get_origin,
-    is_optional,
     is_union,
     is_union,
     override,
     override,
+    true_type_for_pydantic_field,
     value_inside_optional,
     value_inside_optional,
 )
 )
 from reflex.vars import VarData
 from reflex.vars import VarData
@@ -272,7 +272,11 @@ class EventHandlerSetVar(EventHandler):
         return super().__call__(*args)
         return super().__call__(*args)
 
 
 
 
-def _unwrap_field_type(type_: Type) -> Type:
+if TYPE_CHECKING:
+    from pydantic.v1.fields import ModelField
+
+
+def _unwrap_field_type(type_: types.GenericType) -> Type:
     """Unwrap rx.Field type annotations.
     """Unwrap rx.Field type annotations.
 
 
     Args:
     Args:
@@ -303,7 +307,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     return dispatch(
     return dispatch(
         field_name=field_name,
         field_name=field_name,
         var_data=VarData.from_state(cls, f.name),
         var_data=VarData.from_state(cls, f.name),
-        result_var_type=_unwrap_field_type(f.outer_type_),
+        result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
     )
     )
 
 
 
 
@@ -1350,9 +1354,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         if name in fields:
         if name in fields:
             field = fields[name]
             field = fields[name]
-            field_type = _unwrap_field_type(field.outer_type_)
-            if field.allow_none and not is_optional(field_type):
-                field_type = field_type | None
+            field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
             if not _isinstance(value, field_type, nested=1, treat_var_as_type=False):
             if not _isinstance(value, field_type, nested=1, treat_var_as_type=False):
                 console.error(
                 console.error(
                     f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
                     f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"

+ 1 - 0
reflex/utils/pyi_generator.py

@@ -65,6 +65,7 @@ DEFAULT_TYPING_IMPORTS = {
     "Literal",
     "Literal",
     "Optional",
     "Optional",
     "Union",
     "Union",
+    "Annotated",
 }
 }
 
 
 # TODO: fix import ordering and unused imports with ruff later
 # TODO: fix import ordering and unused imports with ruff later

+ 70 - 20
reflex/utils/types.py

@@ -14,11 +14,13 @@ from typing import (
     Callable,
     Callable,
     ClassVar,
     ClassVar,
     Dict,
     Dict,
+    ForwardRef,
     FrozenSet,
     FrozenSet,
     Iterable,
     Iterable,
     List,
     List,
     Literal,
     Literal,
     Mapping,
     Mapping,
+    Optional,
     Sequence,
     Sequence,
     Tuple,
     Tuple,
     Type,
     Type,
@@ -26,9 +28,9 @@ from typing import (
     _GenericAlias,  # pyright: ignore [reportAttributeAccessIssue]
     _GenericAlias,  # pyright: ignore [reportAttributeAccessIssue]
     _SpecialGenericAlias,  # pyright: ignore [reportAttributeAccessIssue]
     _SpecialGenericAlias,  # pyright: ignore [reportAttributeAccessIssue]
     get_args,
     get_args,
-    get_type_hints,
 )
 )
 from typing import get_origin as get_origin_og
 from typing import get_origin as get_origin_og
+from typing import get_type_hints as get_type_hints_og
 
 
 import sqlalchemy
 import sqlalchemy
 from pydantic.v1.fields import ModelField
 from pydantic.v1.fields import ModelField
@@ -48,8 +50,8 @@ from reflex.utils import console
 # Potential GenericAlias types for isinstance checks.
 # Potential GenericAlias types for isinstance checks.
 GenericAliasTypes = (_GenericAlias, GenericAlias, _SpecialGenericAlias)
 GenericAliasTypes = (_GenericAlias, GenericAlias, _SpecialGenericAlias)
 
 
-# Potential Union types for isinstance checks (UnionType added in py3.10).
-UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
+# Potential Union types for isinstance checks.
+UnionTypes = (Union, types.UnionType)
 
 
 # Union of generic types.
 # Union of generic types.
 GenericType = Type | _GenericAlias
 GenericType = Type | _GenericAlias
@@ -140,6 +142,19 @@ def is_generic_alias(cls: GenericType) -> bool:
     return isinstance(cls, GenericAliasTypes)  # pyright: ignore [reportArgumentType]
     return isinstance(cls, GenericAliasTypes)  # pyright: ignore [reportArgumentType]
 
 
 
 
+@lru_cache()
+def get_type_hints(obj: Any) -> Dict[str, Any]:
+    """Get the type hints of a class.
+
+    Args:
+        obj: The class to get the type hints of.
+
+    Returns:
+        The type hints of the class.
+    """
+    return get_type_hints_og(obj)
+
+
 def unionize(*args: GenericType) -> Type:
 def unionize(*args: GenericType) -> Type:
     """Unionize the types.
     """Unionize the types.
 
 
@@ -231,6 +246,33 @@ def is_optional(cls: GenericType) -> bool:
     return is_union(cls) and type(None) in get_args(cls)
     return is_union(cls) and type(None) in get_args(cls)
 
 
 
 
+def true_type_for_pydantic_field(f: ModelField):
+    """Get the type for a pydantic field.
+
+    Args:
+        f: The field to get the type for.
+
+    Returns:
+        The type for the field.
+    """
+    if not isinstance(f.annotation, (str, ForwardRef)):
+        return f.annotation
+
+    type_ = f.outer_type_
+
+    if (
+        f.field_info.default is None
+        or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
+        or (
+            isinstance(f.annotation, ForwardRef)
+            and f.annotation.__forward_arg__.startswith("Optional")
+        )
+    ) and not is_optional(type_):
+        return Optional[type_]
+
+    return type_
+
+
 def value_inside_optional(cls: GenericType) -> GenericType:
 def value_inside_optional(cls: GenericType) -> GenericType:
     """Get the value inside an Optional type or the original type.
     """Get the value inside an Optional type or the original type.
 
 
@@ -241,10 +283,33 @@ def value_inside_optional(cls: GenericType) -> GenericType:
         The value inside the Optional type or the original type.
         The value inside the Optional type or the original type.
     """
     """
     if is_union(cls) and len(args := get_args(cls)) >= 2 and type(None) in args:
     if is_union(cls) and len(args := get_args(cls)) >= 2 and type(None) in args:
+        if len(args) == 2:
+            return args[0] if args[1] is type(None) else args[1]
         return unionize(*[arg for arg in args if arg is not type(None)])
         return unionize(*[arg for arg in args if arg is not type(None)])
     return cls
     return cls
 
 
 
 
+def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
+    """Get the type of a field in a class.
+
+    Args:
+        cls: The class to check.
+        field_name: The name of the field to check.
+
+    Returns:
+        The type of the field, if it exists, else None.
+    """
+    if (
+        hasattr(cls, "__fields__")
+        and field_name in cls.__fields__
+        and hasattr(cls.__fields__[field_name], "annotation")
+        and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
+    ):
+        return cls.__fields__[field_name].annotation
+    type_hints = get_type_hints(cls)
+    return type_hints.get(field_name, None)
+
+
 def get_property_hint(attr: Any | None) -> GenericType | None:
 def get_property_hint(attr: Any | None) -> GenericType | None:
     """Check if an attribute is a property and return its type hint.
     """Check if an attribute is a property and return its type hint.
 
 
@@ -282,24 +347,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
     if hint := get_property_hint(attr):
     if hint := get_property_hint(attr):
         return hint
         return hint
 
 
-    if (
-        hasattr(cls, "__fields__")
-        and name in cls.__fields__
-        and hasattr(cls.__fields__[name], "outer_type_")
-    ):
+    if hasattr(cls, "__fields__") and name in cls.__fields__:
         # pydantic models
         # pydantic models
-        field = cls.__fields__[name]
-        type_ = field.outer_type_
-        if isinstance(type_, ModelField):
-            type_ = type_.type_
-        if (
-            not field.required
-            and field.default is None
-            and field.default_factory is None
-        ):
-            # Ensure frontend uses null coalescing when accessing.
-            type_ = type_ | None
-        return type_
+        return get_field_type(cls, name)
     elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
     elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
         insp = sqlalchemy.inspect(cls)
         insp = sqlalchemy.inspect(cls)
         if name in insp.columns:
         if name in insp.columns:

+ 4 - 4
tests/units/components/test_component.py

@@ -268,7 +268,7 @@ def test_set_style_attrs(component1):
     Args:
     Args:
         component1: A test component.
         component1: A test component.
     """
     """
-    component = component1(color="white", text_align="center")
+    component = component1.create(color="white", text_align="center")
     assert str(component.style["color"]) == '"white"'
     assert str(component.style["color"]) == '"white"'
     assert str(component.style["textAlign"]) == '"center"'
     assert str(component.style["textAlign"]) == '"center"'
 
 
@@ -876,7 +876,7 @@ def test_create_custom_component(my_component):
     Args:
     Args:
         my_component: A test custom component.
         my_component: A test custom component.
     """
     """
-    component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
+    component = rx.memo(my_component)(prop1="test", prop2=1)
     assert component.tag == "MyComponent"
     assert component.tag == "MyComponent"
     assert component.get_props() == {"prop1", "prop2"}
     assert component.get_props() == {"prop1", "prop2"}
     assert component._get_all_custom_components() == {component}
     assert component._get_all_custom_components() == {component}
@@ -888,8 +888,8 @@ def test_custom_component_hash(my_component):
     Args:
     Args:
         my_component: A test custom component.
         my_component: A test custom component.
     """
     """
-    component1 = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
-    component2 = CustomComponent(component_fn=my_component, prop1="test", prop2=2)
+    component1 = rx.memo(my_component)(prop1="test", prop2=1)
+    component2 = rx.memo(my_component)(prop1="test", prop2=2)
     assert {component1, component2} == {component1}
     assert {component1, component2} == {component1}