瀏覽代碼

Remove Pydantic from some classes (#3907)

* half of the way there

* add dataclass support

* Forbid Computed var shadowing (#3843)

* get it right pyright

* fix unit tests

* rip out more pydantic

* fix weird issues with merge_imports

* add missing docstring

* make special props a list instead of a set

* fix moment pyi

* actually ignore the runtime error

* it's ruff out there

---------

Co-authored-by: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com>
Khaleel Al-Adhami 8 月之前
父節點
當前提交
8f937f0417

+ 4 - 1
reflex/app.py

@@ -9,6 +9,7 @@ import copy
 import functools
 import inspect
 import io
+import json
 import multiprocessing
 import os
 import platform
@@ -1096,6 +1097,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
             if delta:
                 # When the state is modified reset dirty status and emit the delta to the frontend.
                 state._clean()
+                print(dir(state.router))
                 await self.event_namespace.emit_update(
                     update=StateUpdate(delta=delta),
                     sid=state.router.session.session_id,
@@ -1531,8 +1533,9 @@ class EventNamespace(AsyncNamespace):
             sid: The Socket.IO session id.
             data: The event data.
         """
+        fields = json.loads(data)
         # Get the event.
-        event = Event.parse_raw(data)
+        event = Event(**{k: v for k, v in fields.items() if k != "handler"})
 
         self.token_to_sid[event.token] = sid
         self.sid_to_token[sid] = event.token

+ 6 - 3
reflex/components/component.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import copy
 import typing
+import warnings
 from abc import ABC, abstractmethod
 from functools import lru_cache, wraps
 from hashlib import md5
@@ -169,6 +170,8 @@ ComponentStyle = Dict[
 ]
 ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]
 
+warnings.filterwarnings("ignore", message="fields may not start with an underscore")
+
 
 class Component(BaseComponent, ABC):
     """A component with style, event trigger and other props."""
@@ -195,7 +198,7 @@ class Component(BaseComponent, ABC):
     class_name: Any = None
 
     # Special component props.
-    special_props: Set[ImmutableVar] = set()
+    special_props: List[ImmutableVar] = []
 
     # Whether the component should take the focus once the page is loaded
     autofocus: bool = False
@@ -655,7 +658,7 @@ class Component(BaseComponent, ABC):
         """
         # Create the base tag.
         tag = Tag(
-            name=self.tag if not self.alias else self.alias,
+            name=(self.tag if not self.alias else self.alias) or "",
             special_props=self.special_props,
         )
 
@@ -2244,7 +2247,7 @@ class StatefulComponent(BaseComponent):
         Returns:
             The tag to render.
         """
-        return dict(Tag(name=self.tag))
+        return dict(Tag(name=self.tag or ""))
 
     def __str__(self) -> str:
         """Represent the component in React.

+ 4 - 4
reflex/components/core/upload.py

@@ -247,9 +247,9 @@ class Upload(MemoizationLeaf):
         }
         # The file input to use.
         upload = Input.create(type="file")
-        upload.special_props = {
+        upload.special_props = [
             ImmutableVar(_var_name="{...getInputProps()}", _var_type=None)
-        }
+        ]
 
         # The dropzone to use.
         zone = Box.create(
@@ -257,9 +257,9 @@ class Upload(MemoizationLeaf):
             *children,
             **{k: v for k, v in props.items() if k not in supported_props},
         )
-        zone.special_props = {
+        zone.special_props = [
             ImmutableVar(_var_name="{...getRootProps()}", _var_type=None)
-        }
+        ]
 
         # Create the component.
         upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)

+ 3 - 3
reflex/components/el/elements/metadata.py

@@ -1,6 +1,6 @@
 """Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
 
-from typing import Set, Union
+from typing import List, Union
 
 from reflex.components.el.element import Element
 from reflex.ivars.base import ImmutableVar
@@ -90,9 +90,9 @@ class StyleEl(Element):  # noqa: E742
 
     media: Var[Union[str, int, bool]]
 
-    special_props: Set[ImmutableVar] = {
+    special_props: List[ImmutableVar] = [
         ImmutableVar.create_safe("suppressHydrationWarning")
-    }
+    ]
 
 
 base = Base.create

+ 3 - 3
reflex/components/markdown/markdown.py

@@ -195,17 +195,17 @@ class Markdown(Component):
         if tag not in self.component_map:
             raise ValueError(f"No markdown component found for tag: {tag}.")
 
-        special_props = {_PROPS_IN_TAG}
+        special_props = [_PROPS_IN_TAG]
         children = [_CHILDREN]
 
         # For certain tags, the props from the markdown renderer are not actually valid for the component.
         if tag in NO_PROPS_TAGS:
-            special_props = set()
+            special_props = []
 
         # If the children are set as a prop, don't pass them as children.
         children_prop = props.pop("children", None)
         if children_prop is not None:
-            special_props.add(
+            special_props.append(
                 ImmutableVar.create_safe(f"children={{{str(children_prop)}}}")
             )
             children = []

+ 12 - 11
reflex/components/moment/moment.py

@@ -1,26 +1,27 @@
 """Moment component for humanized date rendering."""
 
+import dataclasses
 from typing import List, Optional
 
-from reflex.base import Base
 from reflex.components.component import Component, NoSSRComponent
 from reflex.event import EventHandler
 from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
 
-class MomentDelta(Base):
+@dataclasses.dataclass(frozen=True)
+class MomentDelta:
     """A delta used for add/subtract prop in Moment."""
 
-    years: Optional[int]
-    quarters: Optional[int]
-    months: Optional[int]
-    weeks: Optional[int]
-    days: Optional[int]
-    hours: Optional[int]
-    minutess: Optional[int]
-    seconds: Optional[int]
-    milliseconds: Optional[int]
+    years: Optional[int] = dataclasses.field(default=None)
+    quarters: Optional[int] = dataclasses.field(default=None)
+    months: Optional[int] = dataclasses.field(default=None)
+    weeks: Optional[int] = dataclasses.field(default=None)
+    days: Optional[int] = dataclasses.field(default=None)
+    hours: Optional[int] = dataclasses.field(default=None)
+    minutess: Optional[int] = dataclasses.field(default=None)
+    seconds: Optional[int] = dataclasses.field(default=None)
+    milliseconds: Optional[int] = dataclasses.field(default=None)
 
 
 class Moment(NoSSRComponent):

+ 3 - 2
reflex/components/moment/moment.pyi

@@ -3,9 +3,9 @@
 # ------------------- DO NOT EDIT ----------------------
 # This file was generated by `reflex/utils/pyi_generator.py`!
 # ------------------------------------------------------
+import dataclasses
 from typing import Any, Callable, Dict, Optional, Union, overload
 
-from reflex.base import Base
 from reflex.components.component import NoSSRComponent
 from reflex.event import EventHandler, EventSpec
 from reflex.ivars.base import ImmutableVar
@@ -13,7 +13,8 @@ from reflex.style import Style
 from reflex.utils.imports import ImportDict
 from reflex.vars import Var
 
-class MomentDelta(Base):
+@dataclasses.dataclass(frozen=True)
+class MomentDelta:
     years: Optional[int]
     quarters: Optional[int]
     months: Optional[int]

+ 2 - 2
reflex/components/plotly/plotly.py

@@ -267,7 +267,7 @@ const extractPoints = (points) => {
             template_dict = LiteralVar.create({"layout": {"template": self.template}})
             merge_dicts.append(template_dict.without_data())
         if merge_dicts:
-            tag.special_props.add(
+            tag.special_props.append(
                 # Merge all dictionaries and spread the result over props.
                 ImmutableVar.create_safe(
                     f"{{...mergician({str(figure)},"
@@ -276,5 +276,5 @@ const extractPoints = (points) => {
             )
         else:
             # Spread the figure dict over props, nothing to merge.
-            tag.special_props.add(ImmutableVar.create_safe(f"{{...{str(figure)}}}"))
+            tag.special_props.append(ImmutableVar.create_safe(f"{{...{str(figure)}}}"))
         return tag

+ 6 - 3
reflex/components/tags/cond_tag.py

@@ -1,19 +1,22 @@
 """Tag to conditionally render components."""
 
+import dataclasses
 from typing import Any, Dict, Optional
 
 from reflex.components.tags.tag import Tag
+from reflex.ivars.base import LiteralVar
 from reflex.vars import Var
 
 
+@dataclasses.dataclass()
 class CondTag(Tag):
     """A conditional tag."""
 
     # The condition to determine which component to render.
-    cond: Var[Any]
+    cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True))
 
     # The code to render if the condition is true.
-    true_value: Dict
+    true_value: Dict = dataclasses.field(default_factory=dict)
 
     # The code to render if the condition is false.
-    false_value: Optional[Dict]
+    false_value: Optional[Dict] = None

+ 10 - 5
reflex/components/tags/iter_tag.py

@@ -2,31 +2,36 @@
 
 from __future__ import annotations
 
+import dataclasses
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Type, Union, get_args
 
 from reflex.components.tags.tag import Tag
 from reflex.ivars.base import ImmutableVar
-from reflex.vars import Var
+from reflex.ivars.sequence import LiteralArrayVar
+from reflex.vars import Var, get_unique_variable_name
 
 if TYPE_CHECKING:
     from reflex.components.component import Component
 
 
+@dataclasses.dataclass()
 class IterTag(Tag):
     """An iterator tag."""
 
     # The var to iterate over.
-    iterable: Var[List]
+    iterable: Var[List] = dataclasses.field(
+        default_factory=lambda: LiteralArrayVar.create([])
+    )
 
     # The component render function for each item in the iterable.
-    render_fn: Callable
+    render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
 
     # The name of the arg var.
-    arg_var_name: str
+    arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
 
     # The name of the index var.
-    index_var_name: str
+    index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
 
     def get_iterable_var_type(self) -> Type:
         """Get the type of the iterable var.

+ 6 - 3
reflex/components/tags/match_tag.py

@@ -1,19 +1,22 @@
 """Tag to conditionally match cases."""
 
+import dataclasses
 from typing import Any, List
 
 from reflex.components.tags.tag import Tag
+from reflex.ivars.base import LiteralVar
 from reflex.vars import Var
 
 
+@dataclasses.dataclass()
 class MatchTag(Tag):
     """A match tag."""
 
     # The condition to determine which case to match.
-    cond: Var[Any]
+    cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True))
 
     # The list of match cases to be matched.
-    match_cases: List[Any]
+    match_cases: List[Any] = dataclasses.field(default_factory=list)
 
     # The catchall case to match.
-    default: Any
+    default: Any = dataclasses.field(default=LiteralVar.create(None))

+ 38 - 21
reflex/components/tags/tag.py

@@ -2,22 +2,23 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-from reflex.base import Base
 from reflex.event import EventChain
 from reflex.ivars.base import ImmutableVar, LiteralVar
 from reflex.utils import format, types
 
 
-class Tag(Base):
+@dataclasses.dataclass()
+class Tag:
     """A React tag."""
 
     # The name of the tag.
     name: str = ""
 
     # The props of the tag.
-    props: Dict[str, Any] = {}
+    props: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     # The inner contents of the tag.
     contents: str = ""
@@ -26,25 +27,18 @@ class Tag(Base):
     args: Optional[Tuple[str, ...]] = None
 
     # Special props that aren't key value pairs.
-    special_props: Set[ImmutableVar] = set()
+    special_props: List[ImmutableVar] = dataclasses.field(default_factory=list)
 
     # The children components.
-    children: List[Any] = []
-
-    def __init__(self, *args, **kwargs):
-        """Initialize the tag.
-
-        Args:
-            *args: Args to initialize the tag.
-            **kwargs: Kwargs to initialize the tag.
-        """
-        # Convert any props to vars.
-        if "props" in kwargs:
-            kwargs["props"] = {
-                name: LiteralVar.create(value)
-                for name, value in kwargs["props"].items()
-            }
-        super().__init__(*args, **kwargs)
+    children: List[Any] = dataclasses.field(default_factory=list)
+
+    def __post_init__(self):
+        """Post initialize the tag."""
+        object.__setattr__(
+            self,
+            "props",
+            {name: LiteralVar.create(value) for name, value in self.props.items()},
+        )
 
     def format_props(self) -> List:
         """Format the tag's props.
@@ -54,6 +48,29 @@ class Tag(Base):
         """
         return format.format_props(*self.special_props, **self.props)
 
+    def set(self, **kwargs: Any):
+        """Set the tag's fields.
+
+        Args:
+            kwargs: The fields to set.
+
+        Returns:
+            The tag with the fields
+        """
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+
+        return self
+
+    def __iter__(self):
+        """Iterate over the tag's fields.
+
+        Yields:
+            Tuple[str, Any]: The field name and value.
+        """
+        for field in dataclasses.fields(self):
+            yield field.name, getattr(self, field.name)
+
     def add_props(self, **kwargs: Optional[Any]) -> Tag:
         """Add props to the tag.
 

+ 90 - 36
reflex/event.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 import inspect
 import types
 import urllib.parse
@@ -18,7 +19,6 @@ from typing import (
 )
 
 from reflex import constants
-from reflex.base import Base
 from reflex.ivars.base import ImmutableVar, LiteralVar
 from reflex.ivars.function import FunctionStringVar, FunctionVar
 from reflex.ivars.object import ObjectVar
@@ -33,7 +33,11 @@ except ImportError:
     from typing_extensions import Annotated
 
 
-class Event(Base):
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
+class Event:
     """An event that describes any state change in the app."""
 
     # The token to specify the client that the event is for.
@@ -43,10 +47,10 @@ class Event(Base):
     name: str
 
     # The routing data where event occurred
-    router_data: Dict[str, Any] = {}
+    router_data: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     # The event payload.
-    payload: Dict[str, Any] = {}
+    payload: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     @property
     def substate_token(self) -> str:
@@ -81,11 +85,15 @@ def background(fn):
     return fn
 
 
-class EventActionsMixin(Base):
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
+class EventActionsMixin:
     """Mixin for DOM event actions."""
 
     # Whether to `preventDefault` or `stopPropagation` on the event.
-    event_actions: Dict[str, Union[bool, int]] = {}
+    event_actions: Dict[str, Union[bool, int]] = dataclasses.field(default_factory=dict)
 
     @property
     def stop_propagation(self):
@@ -94,8 +102,9 @@ class EventActionsMixin(Base):
         Returns:
             New EventHandler-like with stopPropagation set to True.
         """
-        return self.copy(
-            update={"event_actions": {"stopPropagation": True, **self.event_actions}},
+        return dataclasses.replace(
+            self,
+            event_actions={"stopPropagation": True, **self.event_actions},
         )
 
     @property
@@ -105,8 +114,9 @@ class EventActionsMixin(Base):
         Returns:
             New EventHandler-like with preventDefault set to True.
         """
-        return self.copy(
-            update={"event_actions": {"preventDefault": True, **self.event_actions}},
+        return dataclasses.replace(
+            self,
+            event_actions={"preventDefault": True, **self.event_actions},
         )
 
     def throttle(self, limit_ms: int):
@@ -118,8 +128,9 @@ class EventActionsMixin(Base):
         Returns:
             New EventHandler-like with throttle set to limit_ms.
         """
-        return self.copy(
-            update={"event_actions": {"throttle": limit_ms, **self.event_actions}},
+        return dataclasses.replace(
+            self,
+            event_actions={"throttle": limit_ms, **self.event_actions},
         )
 
     def debounce(self, delay_ms: int):
@@ -131,26 +142,25 @@ class EventActionsMixin(Base):
         Returns:
             New EventHandler-like with debounce set to delay_ms.
         """
-        return self.copy(
-            update={"event_actions": {"debounce": delay_ms, **self.event_actions}},
+        return dataclasses.replace(
+            self,
+            event_actions={"debounce": delay_ms, **self.event_actions},
         )
 
 
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
 class EventHandler(EventActionsMixin):
     """An event handler responds to an event to update the state."""
 
     # The function to call in response to the event.
-    fn: Any
+    fn: Any = dataclasses.field(default=None)
 
     # The full name of the state class this event handler is attached to.
     # Empty string means this event handler is a server side event.
-    state_full_name: str = ""
-
-    class Config:
-        """The Pydantic config."""
-
-        # Needed to allow serialization of Callable.
-        frozen = True
+    state_full_name: str = dataclasses.field(default="")
 
     @classmethod
     def __class_getitem__(cls, args_spec: str) -> Annotated:
@@ -215,6 +225,10 @@ class EventHandler(EventActionsMixin):
         )
 
 
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
 class EventSpec(EventActionsMixin):
     """An event specification.
 
@@ -223,19 +237,37 @@ class EventSpec(EventActionsMixin):
     """
 
     # The event handler.
-    handler: EventHandler
+    handler: EventHandler = dataclasses.field(default=None)  # type: ignore
 
     # The handler on the client to process event.
-    client_handler_name: str = ""
+    client_handler_name: str = dataclasses.field(default="")
 
     # The arguments to pass to the function.
-    args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = ()
+    args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = dataclasses.field(
+        default_factory=tuple
+    )
 
-    class Config:
-        """The Pydantic config."""
+    def __init__(
+        self,
+        handler: EventHandler,
+        event_actions: Dict[str, Union[bool, int]] | None = None,
+        client_handler_name: str = "",
+        args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = tuple(),
+    ):
+        """Initialize an EventSpec.
 
-        # Required to allow tuple fields.
-        frozen = True
+        Args:
+            event_actions: The event actions.
+            handler: The event handler.
+            client_handler_name: The client handler name.
+            args: The arguments to pass to the function.
+        """
+        if event_actions is None:
+            event_actions = {}
+        object.__setattr__(self, "event_actions", event_actions)
+        object.__setattr__(self, "handler", handler)
+        object.__setattr__(self, "client_handler_name", client_handler_name)
+        object.__setattr__(self, "args", args or tuple())
 
     def with_args(
         self, args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...]
@@ -286,6 +318,9 @@ class EventSpec(EventActionsMixin):
         return self.with_args(self.args + new_payload)
 
 
+@dataclasses.dataclass(
+    frozen=True,
+)
 class CallableEventSpec(EventSpec):
     """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
 
@@ -305,10 +340,13 @@ class CallableEventSpec(EventSpec):
         if fn is not None:
             default_event_spec = fn()
             super().__init__(
-                fn=fn,  # type: ignore
-                **default_event_spec.dict(),
+                event_actions=default_event_spec.event_actions,
+                client_handler_name=default_event_spec.client_handler_name,
+                args=default_event_spec.args,
+                handler=default_event_spec.handler,
                 **kwargs,
             )
+            object.__setattr__(self, "fn", fn)
         else:
             super().__init__(**kwargs)
 
@@ -332,12 +370,16 @@ class CallableEventSpec(EventSpec):
         return self.fn(*args, **kwargs)
 
 
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
 class EventChain(EventActionsMixin):
     """Container for a chain of events that will be executed in order."""
 
-    events: List[EventSpec]
+    events: List[EventSpec] = dataclasses.field(default_factory=list)
 
-    args_spec: Optional[Callable]
+    args_spec: Optional[Callable] = dataclasses.field(default=None)
 
 
 # These chains can be used for their side effects when no other events are desired.
@@ -345,14 +387,22 @@ stop_propagation = EventChain(events=[], args_spec=lambda: []).stop_propagation
 prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default
 
 
-class Target(Base):
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
+class Target:
     """A Javascript event target."""
 
     checked: bool = False
     value: Any = None
 
 
-class FrontendEvent(Base):
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
+class FrontendEvent:
     """A Javascript event."""
 
     target: Target = Target()
@@ -360,7 +410,11 @@ class FrontendEvent(Base):
     value: Any = None
 
 
-class FileUpload(Base):
+@dataclasses.dataclass(
+    init=True,
+    frozen=True,
+)
+class FileUpload:
     """Class to represent a file upload."""
 
     upload_id: Optional[str] = None

+ 20 - 1
reflex/ivars/base.py

@@ -421,6 +421,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
         if issubclass(output, (ObjectVar, Base)):
             return ToObjectOperation.create(self, var_type or dict)
 
+        if dataclasses.is_dataclass(output):
+            return ToObjectOperation.create(self, var_type or dict)
+
         if issubclass(output, FunctionVar):
             # if fixed_type is not None and not issubclass(fixed_type, Callable):
             #     raise TypeError(
@@ -479,7 +482,11 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
             ):
                 return self.to(NumberVar, self._var_type)
 
-            if all(inspect.isclass(t) and issubclass(t, Base) for t in inner_types):
+            if all(
+                inspect.isclass(t)
+                and (issubclass(t, Base) or dataclasses.is_dataclass(t))
+                for t in inner_types
+            ):
                 return self.to(ObjectVar, self._var_type)
 
             return self
@@ -499,6 +506,8 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
             return self.to(StringVar, self._var_type)
         if issubclass(fixed_type, Base):
             return self.to(ObjectVar, self._var_type)
+        if dataclasses.is_dataclass(fixed_type):
+            return self.to(ObjectVar, self._var_type)
         return self
 
     def get_default_value(self) -> Any:
@@ -985,6 +994,16 @@ class LiteralVar(ImmutableVar):
                 )
             return LiteralVar.create(serialized_value, _var_data=_var_data)
 
+        if dataclasses.is_dataclass(value) and not isinstance(value, type):
+            return LiteralObjectVar.create(
+                {
+                    k: (None if callable(v) else v)
+                    for k, v in dataclasses.asdict(value).items()
+                },
+                _var_type=type(value),
+                _var_data=_var_data,
+            )
+
         raise TypeError(
             f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
         )

+ 2 - 0
reflex/middleware/hydrate_middleware.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 from typing import TYPE_CHECKING, Optional
 
 from reflex import constants
@@ -14,6 +15,7 @@ if TYPE_CHECKING:
     from reflex.app import App
 
 
+@dataclasses.dataclass(init=True)
 class HydrateMiddleware(Middleware):
     """Middleware to handle initial app hydration."""
 

+ 3 - 3
reflex/middleware/middleware.py

@@ -2,10 +2,9 @@
 
 from __future__ import annotations
 
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional
 
-from reflex.base import Base
 from reflex.event import Event
 from reflex.state import BaseState, StateUpdate
 
@@ -13,9 +12,10 @@ if TYPE_CHECKING:
     from reflex.app import App
 
 
-class Middleware(Base, ABC):
+class Middleware(ABC):
     """Middleware to preprocess and postprocess requests."""
 
+    @abstractmethod
     async def preprocess(
         self, app: App, state: BaseState, event: Event
     ) -> Optional[StateUpdate]:

+ 97 - 34
reflex/state.py

@@ -5,8 +5,10 @@ from __future__ import annotations
 import asyncio
 import contextlib
 import copy
+import dataclasses
 import functools
 import inspect
+import json
 import os
 import uuid
 from abc import ABC, abstractmethod
@@ -83,13 +85,15 @@ var = immutable_computed_var
 TOO_LARGE_SERIALIZED_STATE = 100 * 1024  # 100kb
 
 
-class HeaderData(Base):
+@dataclasses.dataclass(frozen=True)
+class HeaderData:
     """An object containing headers data."""
 
     host: str = ""
     origin: str = ""
     upgrade: str = ""
     connection: str = ""
+    cookie: str = ""
     pragma: str = ""
     cache_control: str = ""
     user_agent: str = ""
@@ -105,13 +109,16 @@ class HeaderData(Base):
         Args:
             router_data: the router_data dict.
         """
-        super().__init__()
         if router_data:
             for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items():
-                setattr(self, format.to_snake_case(k), v)
+                object.__setattr__(self, format.to_snake_case(k), v)
+        else:
+            for k in dataclasses.fields(self):
+                object.__setattr__(self, k.name, "")
 
 
-class PageData(Base):
+@dataclasses.dataclass(frozen=True)
+class PageData:
     """An object containing page data."""
 
     host: str = ""  # repeated with self.headers.origin (remove or keep the duplicate?)
@@ -119,7 +126,7 @@ class PageData(Base):
     raw_path: str = ""
     full_path: str = ""
     full_raw_path: str = ""
-    params: dict = {}
+    params: dict = dataclasses.field(default_factory=dict)
 
     def __init__(self, router_data: Optional[dict] = None):
         """Initalize the PageData object based on router_data.
@@ -127,17 +134,34 @@ class PageData(Base):
         Args:
             router_data: the router_data dict.
         """
-        super().__init__()
         if router_data:
-            self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin")
-            self.path = router_data.get(constants.RouteVar.PATH, "")
-            self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "")
-            self.full_path = f"{self.host}{self.path}"
-            self.full_raw_path = f"{self.host}{self.raw_path}"
-            self.params = router_data.get(constants.RouteVar.QUERY, {})
+            object.__setattr__(
+                self,
+                "host",
+                router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""),
+            )
+            object.__setattr__(
+                self, "path", router_data.get(constants.RouteVar.PATH, "")
+            )
+            object.__setattr__(
+                self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "")
+            )
+            object.__setattr__(self, "full_path", f"{self.host}{self.path}")
+            object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}")
+            object.__setattr__(
+                self, "params", router_data.get(constants.RouteVar.QUERY, {})
+            )
+        else:
+            object.__setattr__(self, "host", "")
+            object.__setattr__(self, "path", "")
+            object.__setattr__(self, "raw_path", "")
+            object.__setattr__(self, "full_path", "")
+            object.__setattr__(self, "full_raw_path", "")
+            object.__setattr__(self, "params", {})
 
 
-class SessionData(Base):
+@dataclasses.dataclass(frozen=True, init=False)
+class SessionData:
     """An object containing session data."""
 
     client_token: str = ""
@@ -150,19 +174,24 @@ class SessionData(Base):
         Args:
             router_data: the router_data dict.
         """
-        super().__init__()
         if router_data:
-            self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
-            self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
-            self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
+            client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
+            client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
+            session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
+        else:
+            client_token = client_ip = session_id = ""
+        object.__setattr__(self, "client_token", client_token)
+        object.__setattr__(self, "client_ip", client_ip)
+        object.__setattr__(self, "session_id", session_id)
 
 
-class RouterData(Base):
+@dataclasses.dataclass(frozen=True, init=False)
+class RouterData:
     """An object containing RouterData."""
 
-    session: SessionData = SessionData()
-    headers: HeaderData = HeaderData()
-    page: PageData = PageData()
+    session: SessionData = dataclasses.field(default_factory=SessionData)
+    headers: HeaderData = dataclasses.field(default_factory=HeaderData)
+    page: PageData = dataclasses.field(default_factory=PageData)
 
     def __init__(self, router_data: Optional[dict] = None):
         """Initialize the RouterData object.
@@ -170,10 +199,30 @@ class RouterData(Base):
         Args:
             router_data: the router_data dict.
         """
-        super().__init__()
-        self.session = SessionData(router_data)
-        self.headers = HeaderData(router_data)
-        self.page = PageData(router_data)
+        object.__setattr__(self, "session", SessionData(router_data))
+        object.__setattr__(self, "headers", HeaderData(router_data))
+        object.__setattr__(self, "page", PageData(router_data))
+
+    def toJson(self) -> str:
+        """Convert the object to a JSON string.
+
+        Returns:
+            The JSON string.
+        """
+        return json.dumps(dataclasses.asdict(self))
+
+
+@serializer
+def serialize_routerdata(value: RouterData) -> str:
+    """Serialize a RouterData instance.
+
+    Args:
+        value: The RouterData to serialize.
+
+    Returns:
+        The serialized RouterData.
+    """
+    return value.toJson()
 
 
 def _no_chain_background_task(
@@ -249,10 +298,11 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
     return token, state_name
 
 
+@dataclasses.dataclass(frozen=True, init=False)
 class EventHandlerSetVar(EventHandler):
     """A special event handler to wrap setvar functionality."""
 
-    state_cls: Type[BaseState]
+    state_cls: Type[BaseState] = dataclasses.field(init=False)
 
     def __init__(self, state_cls: Type[BaseState]):
         """Initialize the EventHandlerSetVar.
@@ -263,8 +313,8 @@ class EventHandlerSetVar(EventHandler):
         super().__init__(
             fn=type(self).setvar,
             state_full_name=state_cls.get_full_name(),
-            state_cls=state_cls,  # type: ignore
         )
+        object.__setattr__(self, "state_cls", state_cls)
 
     def setvar(self, var_name: str, value: Any):
         """Set the state variable to the value of the event.
@@ -1826,8 +1876,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             self.dirty_vars.update(self._always_dirty_computed_vars)
             self._mark_dirty()
 
+        def dictify(value: Any):
+            if dataclasses.is_dataclass(value) and not isinstance(value, type):
+                return dataclasses.asdict(value)
+            return value
+
         base_vars = {
-            prop_name: self.get_value(getattr(self, prop_name))
+            prop_name: dictify(self.get_value(getattr(self, prop_name)))
             for prop_name in self.base_vars
         }
         if initial and include_computed:
@@ -1907,9 +1962,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
 
 
-EventHandlerSetVar.update_forward_refs()
-
-
 class State(BaseState):
     """The app Base State."""
 
@@ -2341,18 +2393,29 @@ class StateProxy(wrapt.ObjectProxy):
             self._self_mutable = original_mutable
 
 
-class StateUpdate(Base):
+@dataclasses.dataclass(
+    frozen=True,
+)
+class StateUpdate:
     """A state update sent to the frontend."""
 
     # The state delta.
-    delta: Delta = {}
+    delta: Delta = dataclasses.field(default_factory=dict)
 
     # Events to be added to the event queue.
-    events: List[Event] = []
+    events: List[Event] = dataclasses.field(default_factory=list)
 
     # Whether this is the final state update for the event.
     final: bool = True
 
+    def json(self) -> str:
+        """Convert the state update to a JSON string.
+
+        Returns:
+            The state update as a JSON string.
+        """
+        return json.dumps(dataclasses.asdict(self))
+
 
 class StateManager(Base, ABC):
     """A class to manage many client states."""

+ 9 - 0
reflex/utils/format.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 import inspect
 import json
 import os
@@ -623,6 +624,14 @@ def format_state(value: Any, key: Optional[str] = None) -> Any:
     if isinstance(value, dict):
         return {k: format_state(v, k) for k, v in value.items()}
 
+    # Hand dataclasses.
+    if dataclasses.is_dataclass(value):
+        if isinstance(value, type):
+            raise TypeError(
+                f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass."
+            )
+        return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()}
+
     # Handle lists, sets, typles.
     if isinstance(value, types.StateIterBases):
         return [format_state(v) for v in value]

+ 16 - 73
reflex/utils/imports.py

@@ -2,10 +2,9 @@
 
 from __future__ import annotations
 
+import dataclasses
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union
-
-from reflex.base import Base
+from typing import DefaultDict, Dict, List, Optional, Tuple, Union
 
 
 def merge_imports(
@@ -19,12 +18,22 @@ def merge_imports(
     Returns:
         The merged import dicts.
     """
-    all_imports = defaultdict(list)
+    all_imports: DefaultDict[str, List[ImportVar]] = defaultdict(list)
     for import_dict in imports:
         for lib, fields in (
             import_dict if isinstance(import_dict, tuple) else import_dict.items()
         ):
-            all_imports[lib].extend(fields)
+            if isinstance(fields, (list, tuple, set)):
+                all_imports[lib].extend(
+                    (
+                        ImportVar(field) if isinstance(field, str) else field
+                        for field in fields
+                    )
+                )
+            else:
+                all_imports[lib].append(
+                    ImportVar(fields) if isinstance(fields, str) else fields
+                )
     return all_imports
 
 
@@ -75,7 +84,8 @@ def collapse_imports(
     }
 
 
-class ImportVar(Base):
+@dataclasses.dataclass(order=True, frozen=True)
+class ImportVar:
     """An import var."""
 
     # The name of the import tag.
@@ -111,73 +121,6 @@ class ImportVar(Base):
         else:
             return self.tag or ""
 
-    def __lt__(self, other: ImportVar) -> bool:
-        """Compare two ImportVar objects.
-
-        Args:
-            other: The other ImportVar object to compare.
-
-        Returns:
-            Whether this ImportVar object is less than the other.
-        """
-        return (
-            self.tag,
-            self.is_default,
-            self.alias,
-            self.install,
-            self.render,
-            self.transpile,
-        ) < (
-            other.tag,
-            other.is_default,
-            other.alias,
-            other.install,
-            other.render,
-            other.transpile,
-        )
-
-    def __eq__(self, other: ImportVar) -> bool:
-        """Check if two ImportVar objects are equal.
-
-        Args:
-            other: The other ImportVar object to compare.
-
-        Returns:
-            Whether the two ImportVar objects are equal.
-        """
-        return (
-            self.tag,
-            self.is_default,
-            self.alias,
-            self.install,
-            self.render,
-            self.transpile,
-        ) == (
-            other.tag,
-            other.is_default,
-            other.alias,
-            other.install,
-            other.render,
-            other.transpile,
-        )
-
-    def __hash__(self) -> int:
-        """Hash the ImportVar object.
-
-        Returns:
-            The hash of the ImportVar object.
-        """
-        return hash(
-            (
-                self.tag,
-                self.is_default,
-                self.alias,
-                self.install,
-                self.render,
-                self.transpile,
-            )
-        )
-
 
 ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
 ImportDict = Dict[str, ImportTypes]

+ 6 - 4
reflex/utils/prerequisites.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 import functools
 import glob
 import importlib
@@ -32,7 +33,6 @@ from redis import exceptions
 from redis.asyncio import Redis
 
 from reflex import constants, model
-from reflex.base import Base
 from reflex.compiler import templates
 from reflex.config import Config, get_config
 from reflex.utils import console, net, path_ops, processes
@@ -43,7 +43,8 @@ from reflex.utils.registry import _get_best_registry
 CURRENTLY_INSTALLING_NODE = False
 
 
-class Template(Base):
+@dataclasses.dataclass(frozen=True)
+class Template:
     """A template for a Reflex app."""
 
     name: str
@@ -52,7 +53,8 @@ class Template(Base):
     demo_url: str
 
 
-class CpuInfo(Base):
+@dataclasses.dataclass(frozen=True)
+class CpuInfo:
     """Model to save cpu info."""
 
     manufacturer_id: Optional[str]
@@ -1279,7 +1281,7 @@ def fetch_app_templates(version: str) -> dict[str, Template]:
             None,
         )
     return {
-        tp["name"]: Template.parse_obj(tp)
+        tp["name"]: Template(**tp)
         for tp in templates_data
         if not tp["hidden"] and tp["code_url"] is not None
     }

+ 2 - 1
reflex/utils/telemetry.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 import multiprocessing
 import platform
 import warnings
@@ -144,7 +145,7 @@ def _prepare_event(event: str, **kwargs) -> dict:
             "python_version": get_python_version(),
             "cpu_count": get_cpu_count(),
             "memory": get_memory(),
-            "cpu_info": dict(cpuinfo) if cpuinfo else {},
+            "cpu_info": dataclasses.asdict(cpuinfo) if cpuinfo else {},
             **additional_fields,
         },
         "timestamp": stamp,

+ 6 - 1
reflex/utils/types.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import contextlib
+import dataclasses
 import inspect
 import sys
 import types
@@ -480,7 +481,11 @@ def is_valid_var_type(type_: Type) -> bool:
 
     if is_union(type_):
         return all((is_valid_var_type(arg) for arg in get_args(type_)))
-    return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
+    return (
+        _issubclass(type_, StateVar)
+        or serializers.has_serializer(type_)
+        or dataclasses.is_dataclass(type_)
+    )
 
 
 def is_backend_base_variable(name: str, cls: Type) -> bool:

+ 17 - 17
tests/components/test_component.py

@@ -637,21 +637,21 @@ def test_component_create_unallowed_types(children, test_component):
                 "props": [],
                 "contents": "",
                 "args": None,
-                "special_props": set(),
+                "special_props": [],
                 "children": [
                     {
                         "name": "RadixThemesText",
                         "props": ['as={"p"}'],
                         "contents": "",
                         "args": None,
-                        "special_props": set(),
+                        "special_props": [],
                         "children": [
                             {
                                 "name": "",
                                 "props": [],
                                 "contents": '{"first_text"}',
                                 "args": None,
-                                "special_props": set(),
+                                "special_props": [],
                                 "children": [],
                                 "autofocus": False,
                             }
@@ -679,13 +679,13 @@ def test_component_create_unallowed_types(children, test_component):
                                 "contents": '{"first_text"}',
                                 "name": "",
                                 "props": [],
-                                "special_props": set(),
+                                "special_props": [],
                             }
                         ],
                         "contents": "",
                         "name": "RadixThemesText",
                         "props": ['as={"p"}'],
-                        "special_props": set(),
+                        "special_props": [],
                     },
                     {
                         "args": None,
@@ -698,19 +698,19 @@ def test_component_create_unallowed_types(children, test_component):
                                 "contents": '{"second_text"}',
                                 "name": "",
                                 "props": [],
-                                "special_props": set(),
+                                "special_props": [],
                             }
                         ],
                         "contents": "",
                         "name": "RadixThemesText",
                         "props": ['as={"p"}'],
-                        "special_props": set(),
+                        "special_props": [],
                     },
                 ],
                 "contents": "",
                 "name": "Fragment",
                 "props": [],
-                "special_props": set(),
+                "special_props": [],
             },
         ),
         (
@@ -730,13 +730,13 @@ def test_component_create_unallowed_types(children, test_component):
                                 "contents": '{"first_text"}',
                                 "name": "",
                                 "props": [],
-                                "special_props": set(),
+                                "special_props": [],
                             }
                         ],
                         "contents": "",
                         "name": "RadixThemesText",
                         "props": ['as={"p"}'],
-                        "special_props": set(),
+                        "special_props": [],
                     },
                     {
                         "args": None,
@@ -757,31 +757,31 @@ def test_component_create_unallowed_types(children, test_component):
                                                 "contents": '{"second_text"}',
                                                 "name": "",
                                                 "props": [],
-                                                "special_props": set(),
+                                                "special_props": [],
                                             }
                                         ],
                                         "contents": "",
                                         "name": "RadixThemesText",
                                         "props": ['as={"p"}'],
-                                        "special_props": set(),
+                                        "special_props": [],
                                     }
                                 ],
                                 "contents": "",
                                 "name": "Fragment",
                                 "props": [],
-                                "special_props": set(),
+                                "special_props": [],
                             }
                         ],
                         "contents": "",
                         "name": "RadixThemesBox",
                         "props": [],
-                        "special_props": set(),
+                        "special_props": [],
                     },
                 ],
                 "contents": "",
                 "name": "Fragment",
                 "props": [],
-                "special_props": set(),
+                "special_props": [],
             },
         ),
     ],
@@ -1289,12 +1289,12 @@ class EventState(rx.State):
             id="fstring-class_name",
         ),
         pytest.param(
-            rx.fragment(special_props={TEST_VAR}),
+            rx.fragment(special_props=[TEST_VAR]),
             [TEST_VAR],
             id="direct-special_props",
         ),
         pytest.param(
-            rx.fragment(special_props={LiteralVar.create(f"foo{TEST_VAR}bar")}),
+            rx.fragment(special_props=[LiteralVar.create(f"foo{TEST_VAR}bar")]),
             [FORMATTED_TEST_VAR],
             id="fstring-special_props",
         ),

+ 2 - 1
tests/test_app.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import dataclasses
 import functools
 import io
 import json
@@ -1052,7 +1053,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
                     f"comp_{arg_name}": exp_val,
                     constants.CompileVars.IS_HYDRATED: False,
                     # "side_effect_counter": exp_index,
-                    "router": exp_router,
+                    "router": dataclasses.asdict(exp_router),
                 }
             },
             events=[

+ 20 - 14
tests/test_state.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import copy
+import dataclasses
 import datetime
 import functools
 import json
@@ -58,6 +59,7 @@ formatted_router = {
         "origin": "",
         "upgrade": "",
         "connection": "",
+        "cookie": "",
         "pragma": "",
         "cache_control": "",
         "user_agent": "",
@@ -865,8 +867,10 @@ def test_get_headers(test_state, router_data, router_data_headers):
         router_data: The router data fixture.
         router_data_headers: The expected headers.
     """
+    print(router_data_headers)
     test_state.router = RouterData(router_data)
-    assert test_state.router.headers.dict() == {
+    print(test_state.router.headers)
+    assert dataclasses.asdict(test_state.router.headers) == {
         format.to_snake_case(k): v for k, v in router_data_headers.items()
     }
 
@@ -1908,19 +1912,21 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
     mock_app.event_namespace.emit.assert_called_once()
     mcall = mock_app.event_namespace.emit.mock_calls[0]
     assert mcall.args[0] == str(SocketEvent.EVENT)
-    assert json.loads(mcall.args[1]) == StateUpdate(
-        delta={
-            parent_state.get_full_name(): {
-                "upper": "",
-                "sum": 3.14,
-            },
-            grandchild_state.get_full_name(): {
-                "value2": "42",
-            },
-            GrandchildState3.get_full_name(): {
-                "computed": "",
-            },
-        }
+    assert json.loads(mcall.args[1]) == dataclasses.asdict(
+        StateUpdate(
+            delta={
+                parent_state.get_full_name(): {
+                    "upper": "",
+                    "sum": 3.14,
+                },
+                grandchild_state.get_full_name(): {
+                    "value2": "42",
+                },
+                GrandchildState3.get_full_name(): {
+                    "computed": "",
+                },
+            }
+        )
     )
     assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
 

+ 1 - 0
tests/utils/test_format.py

@@ -553,6 +553,7 @@ formatted_router = {
         "origin": "",
         "upgrade": "",
         "connection": "",
+        "cookie": "",
         "pragma": "",
         "cache_control": "",
         "user_agent": "",

+ 7 - 3
tests/utils/test_imports.py

@@ -54,17 +54,21 @@ def test_import_var(import_var, expected_name):
         (
             {"react": {"Component"}},
             {"react": {"Component"}, "react-dom": {"render"}},
-            {"react": {"Component"}, "react-dom": {"render"}},
+            {"react": {ImportVar("Component")}, "react-dom": {ImportVar("render")}},
         ),
         (
             {"react": {"Component"}, "next/image": {"Image"}},
             {"react": {"Component"}, "react-dom": {"render"}},
-            {"react": {"Component"}, "react-dom": {"render"}, "next/image": {"Image"}},
+            {
+                "react": {ImportVar("Component")},
+                "react-dom": {ImportVar("render")},
+                "next/image": {ImportVar("Image")},
+            },
         ),
         (
             {"react": {"Component"}},
             {"": {"some/custom.css"}},
-            {"react": {"Component"}, "": {"some/custom.css"}},
+            {"react": {ImportVar("Component")}, "": {ImportVar("some/custom.css")}},
         ),
     ],
 )