Browse Source

Dynamically add vars to a State (#381)

Thomas Brandého 2 years ago
parent
commit
b06f612a7d
3 changed files with 92 additions and 13 deletions
  1. 20 0
      pynecone/base.py
  2. 55 13
      pynecone/state.py
  3. 17 0
      tests/test_state.py

+ 20 - 0
pynecone/base.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import Any, Dict, TypeVar
 
 import pydantic
+from pydantic.fields import ModelField
 
 # Typevar to represent any class subclassing Base.
 PcType = TypeVar("PcType")
@@ -55,6 +56,25 @@ class Base(pydantic.BaseModel):
         """
         return cls.__fields__
 
+    @classmethod
+    def add_field(cls, var: Any, default_value: Any):
+        """Add a pydantic field after class definition.
+
+        Used by State.add_var() to correctly handle the new variable.
+
+        Args:
+            var: The variable to add a pydantic field for.
+            default_value: The default value of the field
+        """
+        new_field = ModelField.infer(
+            name=var.name,
+            value=default_value,
+            annotation=var.type_,
+            class_validators=None,
+            config=cls.__config__,
+        )
+        cls.__fields__.update({var.name: new_field})
+
     def get_value(self, key: str) -> Any:
         """Get the value of a field.
 

+ 55 - 13
pynecone/state.py

@@ -109,9 +109,6 @@ class State(Base, ABC):
 
         Args:
             **kwargs: The kwargs to pass to the pydantic init_subclass method.
-
-        Raises:
-            TypeError: If the class has a var with an invalid type.
         """
         super().__init_subclass__(**kwargs)
 
@@ -146,16 +143,7 @@ class State(Base, ABC):
 
         # Setup the base vars at the class level.
         for prop in cls.base_vars.values():
-            if not utils.is_valid_var_type(prop.type_):
-                raise TypeError(
-                    "State vars must be primitive Python types, "
-                    "Plotly figures, Pandas dataframes, "
-                    "or subclasses of pc.Base. "
-                    f'Found var "{prop.name}" with type {prop.type_}.'
-                )
-            cls._set_var(prop)
-            cls._create_setter(prop)
-            cls._set_default_value(prop)
+            cls._init_var(prop)
 
         # Set up the event handlers.
         events = {
@@ -261,6 +249,60 @@ class State(Base, ABC):
             raise ValueError(f"Invalid path: {path}")
         return getattr(substate, name)
 
+    @classmethod
+    def _init_var(cls, prop: BaseVar):
+        """Initialize a variable.
+
+        Args:
+            prop (BaseVar): The variable to initialize
+
+        Raises:
+            TypeError: if the variable has an incorrect type
+        """
+        if not utils.is_valid_var_type(prop.type_):
+            raise TypeError(
+                "State vars must be primitive Python types, "
+                "Plotly figures, Pandas dataframes, "
+                "or subclasses of pc.Base. "
+                f'Found var "{prop.name}" with type {prop.type_}.'
+            )
+        cls._set_var(prop)
+        cls._create_setter(prop)
+        cls._set_default_value(prop)
+
+    @classmethod
+    def add_var(cls, name: str, type_: Any, default_value: Any = None):
+        """Add dynamically a variable to the State.
+
+        The variable added this way can be used in the same way as a variable
+        defined statically in the model.
+
+        Args:
+            name (str): The name of the variable
+            type_ (Any): The type of the variable
+            default_value (Any): The default value of the variable
+
+        Raises:
+            NameError: if a variable of this name already exists
+        """
+        if name in cls.__fields__:
+            raise NameError(
+                f"The variable '{name}' already exist. Use a different name"
+            )
+
+        # create the variable based on name and type
+        var = BaseVar(name=name, type_=type_)
+        var.set_state(cls)
+
+        # add the pydantic field dynamically (must be done before _init_var)
+        cls.add_field(var, default_value)
+
+        cls._init_var(var)
+
+        # update the internal dicts so the new variable is correctly handled
+        cls.base_vars.update({name: var})
+        cls.vars.update({name: var})
+
     @classmethod
     def _set_var(cls, prop: BaseVar):
         """Set the var as a class member.

+ 17 - 0
tests/test_state.py

@@ -638,3 +638,20 @@ def test_get_query_params(test_state):
     test_state.router_data = {RouteVar.QUERY: params}
 
     assert test_state.get_query_params() == params
+
+
+def test_add_var(test_state):
+    test_state.add_var("dynamic_int", int, 42)
+    assert test_state.dynamic_int == 42
+
+    test_state.add_var("dynamic_list", List[int], [5, 10])
+    assert test_state.dynamic_list == [5, 10]
+    assert getattr(test_state, "dynamic_list") == [5, 10]
+
+    # how to test that one?
+    # test_state.dynamic_list.append(15)
+    # assert test_state.dynamic_list == [5, 10, 15]
+
+    test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
+    assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
+    assert getattr(test_state, "dynamic_dict") == {"k1": 5, "k2": 10}