1
0
Эх сурвалжийг харах

Improve event processing performance (#153)

Nikhil Rao 2 жил өмнө
parent
commit
57e278ae1c

+ 3 - 3
poetry.lock

@@ -604,14 +604,14 @@ plugins = ["importlib-metadata"]
 
 [[package]]
 name = "pyright"
-version = "1.1.284"
+version = "1.1.285"
 description = "Command line wrapper for pyright"
 category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "pyright-1.1.284-py3-none-any.whl", hash = "sha256:e3bfbd33c20af48eed9d20138767265161ba8a4b55c740476a36ce822bd482d1"},
-    {file = "pyright-1.1.284.tar.gz", hash = "sha256:ef7c0e46e38be95687f5a0633e55c5171ca166048b9560558168a976162e287c"},
+    {file = "pyright-1.1.285-py3-none-any.whl", hash = "sha256:8a6b60b3ff0d000c549621c367cdf0013abdaf24d09e6f0b4b95031b357cc4b1"},
+    {file = "pyright-1.1.285.tar.gz", hash = "sha256:ecd28e8556352e2c7eb5f412c6841ec768d25e8a6136326d4a6a67d94370eba1"},
 ]
 
 [package.dependencies]

+ 9 - 1
pynecone/.templates/web/utils/state.js

@@ -123,8 +123,16 @@ export const updateState = async (state, result, setResult, router, socket) => {
  * @param setResult The function to set the result.
  * @param endpoint The endpoint to connect to.
  */
-export const connect = async (socket, state, setResult, endpoint) => {
+export const connect = async (socket, state, result, setResult, router, endpoint) => {
+  // Create the socket.
   socket.current = new WebSocket(endpoint);
+
+  // Once the socket is open, hydrate the page.
+  socket.current.onopen = () => {
+    updateState(state, result, setResult, router, socket.current)
+  }
+
+  // On each received message, apply the delta and set the result.
   socket.current.onmessage = function (update) {
     update = JSON.parse(update.data);
     applyDelta(state, update.delta);

+ 1 - 1
pynecone/compiler/templates.py

@@ -144,7 +144,7 @@ USE_EFFECT = join(
     [
         "useEffect(() => {{",
         f"  if (!{SOCKET}.current) {{{{",
-        f"    connect({SOCKET}, {{state}}, {SET_RESULT}, {EVENT_ENDPOINT})",
+        f"    connect({SOCKET}, {{state}}, {RESULT}, {SET_RESULT}, {ROUTER}, {EVENT_ENDPOINT})",
         "  }}",
         "  const update = async () => {{",
         f"    if ({RESULT}.{STATE} != null) {{{{",

+ 35 - 16
pynecone/state.py

@@ -257,6 +257,17 @@ class State(Base, ABC):
             field.required = False
             field.default = default_value
 
+    def getattr(self, name: str) -> Any:
+        """Get a non-prop attribute.
+
+        Args:
+            name: The name of the attribute.
+
+        Returns:
+            The attribute.
+        """
+        return super().__getattribute__(name)
+
     def __getattribute__(self, name: str) -> Any:
         """Get the attribute.
 
@@ -287,17 +298,20 @@ class State(Base, ABC):
             name: The name of the attribute.
             value: The value of the attribute.
         """
-        if name != "inherited_vars" and name in self.inherited_vars:
-            setattr(self.parent_state, name, value)
+        # NOTE: We use super().__getattribute__ for performance reasons.
+        if name != "inherited_vars" and name in super().__getattribute__(
+            "inherited_vars"
+        ):
+            setattr(super().__getattribute__("parent_state"), name, value)
             return
 
         # Set the attribute.
         super().__setattr__(name, value)
 
         # Add the var to the dirty list.
-        if name in self.vars:
-            self.dirty_vars.add(name)
-            self.mark_dirty()
+        if name in super().__getattribute__("vars"):
+            super().__getattribute__("dirty_vars").add(name)
+            super().__getattribute__("mark_dirty")()
 
     def reset(self):
         """Reset all the base vars to their default values."""
@@ -344,10 +358,11 @@ class State(Base, ABC):
         Returns:
             The state update after processing the event.
         """
+        # NOTE: We use super().__getattribute__ for performance reasons.
         # Get the event handler.
         path = event.name.split(".")
         path, name = path[:-1], path[-1]
-        substate = self.get_substate(path)
+        substate = super().__getattribute__("get_substate")(path)
         handler = getattr(substate, name)
 
         # Process the event.
@@ -368,10 +383,10 @@ class State(Base, ABC):
         events = utils.fix_events(events, event.token)
 
         # Get the delta after processing the event.
-        delta = self.get_delta()
+        delta = super().__getattribute__("get_delta")()
 
         # Reset the dirty vars.
-        self.clean()
+        super().__getattribute__("clean")()
 
         # Return the state update.
         return StateUpdate(delta=delta, events=events)
@@ -382,19 +397,22 @@ class State(Base, ABC):
         Returns:
             The delta for the state.
         """
+        # NOTE: We use super().__getattribute__ for performance reasons.
         delta = {}
 
         # Return the dirty vars, as well as all computed vars.
         subdelta = {
             prop: getattr(self, prop)
-            for prop in self.dirty_vars | set(self.computed_vars.keys())
+            for prop in super().__getattribute__("dirty_vars")
+            | set(super().__getattribute__("computed_vars").keys())
         }
         if len(subdelta) > 0:
-            delta[self.get_full_name()] = subdelta
+            delta[super().__getattribute__("get_full_name")()] = subdelta
 
         # Recursively find the substate deltas.
-        for substate in self.dirty_substates:
-            delta.update(self.substates[substate].get_delta())
+        substates = super().__getattribute__("substates")
+        for substate in super().__getattribute__("dirty_substates"):
+            delta.update(substates[substate].getattr("get_delta")())
 
         # Format the delta.
         delta = utils.format_state(delta)
@@ -410,13 +428,14 @@ class State(Base, ABC):
 
     def clean(self):
         """Reset the dirty vars."""
+        # NOTE: We use super().__getattribute__ for performance reasons.
         # Recursively clean the substates.
-        for substate in self.dirty_substates:
-            self.substates[substate].clean()
+        for substate in super().__getattribute__("dirty_substates"):
+            super().__getattribute__("substates")[substate].getattr("clean")()
 
         # Clean this state.
-        self.dirty_vars = set()
-        self.dirty_substates = set()
+        super().__setattr__("dirty_vars", set())
+        super().__setattr__("dirty_substates", set())
 
     def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
         """Convert the object to a dictionary.

+ 18 - 14
pynecone/utils.py

@@ -851,8 +851,16 @@ def format_state(value: Any) -> Dict:
     Raises:
         TypeError: If the given value is not a valid state.
     """
+    # Handle dicts.
+    if isinstance(value, dict):
+        return {k: format_state(v) for k, v in value.items()}
+
+    # Return state vars as is.
+    if isinstance(value, StateBases):
+        return value
+
     # Convert plotly figures to JSON.
-    if _isinstance(value, go.Figure):
+    if isinstance(value, go.Figure):
         return json.loads(to_json(value))["data"]
 
     # Convert pandas dataframes to JSON.
@@ -862,19 +870,11 @@ def format_state(value: Any) -> Dict:
             "data": value.values.tolist(),
         }
 
-    # Handle dicts.
-    if _isinstance(value, dict):
-        return {k: format_state(v) for k, v in value.items()}
-
-    # Make sure the value is JSON serializable.
-    if not _isinstance(value, StateVar):
-        raise TypeError(
-            "State vars must be primitive Python types, "
-            "or subclasses of pc.Base. "
-            f"Got var of type {type(value)}."
-        )
-
-    return value
+    raise TypeError(
+        "State vars must be primitive Python types, "
+        "or subclasses of pc.Base. "
+        f"Got var of type {type(value)}."
+    )
 
 
 def get_event(state, event):
@@ -1069,3 +1069,7 @@ def get_redis() -> Optional[Redis]:
     redis_url, redis_port = config.redis_url.split(":")
     print("Using redis at", config.redis_url)
     return Redis(host=redis_url, port=int(redis_port), db=0)
+
+
+# Store this here for performance.
+StateBases = get_base_class(StateVar)