Pārlūkot izejas kodu

Improvements to custom styles in rx.markdown (#1852)

Nikhil Rao 1 gadu atpakaļ
vecāks
revīzija
8231993e5a

+ 142 - 142
integration/test_var_operations.py

@@ -21,8 +21,11 @@ def VarOperations():
         float_var2: float = 5.5
         float_var2: float = 5.5
         list1: list = [1, 2]
         list1: list = [1, 2]
         list2: list = [3, 4]
         list2: list = [3, 4]
+        list3: list = ["first", "second", "third"]
         str_var1: str = "first"
         str_var1: str = "first"
         str_var2: str = "second"
         str_var2: str = "second"
+        str_var3: str = "ThIrD"
+        str_var4: str = "a long string"
         dict1: dict = {1: 2}
         dict1: dict = {1: 2}
         dict2: dict = {3: 4}
         dict2: dict = {3: 4}
 
 
@@ -514,6 +517,11 @@ def VarOperations():
             rx.text(
             rx.text(
                 VarOperationState.dict1.contains(1).to_string(), id="dict_contains"
                 VarOperationState.dict1.contains(1).to_string(), id="dict_contains"
             ),
             ),
+            rx.text(VarOperationState.str_var3.lower(), id="str_lower"),
+            rx.text(VarOperationState.str_var3.upper(), id="str_upper"),
+            rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
+            rx.text(VarOperationState.list3.join(""), id="list_join"),
+            rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
         )
         )
 
 
     app.compile()
     app.compile()
@@ -567,145 +575,137 @@ def test_var_operations(driver, var_operations: AppHarness):
         driver: selenium WebDriver open to the app
         driver: selenium WebDriver open to the app
         var_operations: AppHarness for the var operations app
         var_operations: AppHarness for the var operations app
     """
     """
-    assert var_operations.app_instance is not None, "app is not running"
-    # INT INT
-    assert driver.find_element(By.ID, "int_add_int").text == "15"
-    assert driver.find_element(By.ID, "int_mult_int").text == "50"
-    assert driver.find_element(By.ID, "int_sub_int").text == "5"
-    assert driver.find_element(By.ID, "int_exp_int").text == "100000"
-    assert driver.find_element(By.ID, "int_div_int").text == "2"
-    assert driver.find_element(By.ID, "int_floor_int").text == "1"
-    assert driver.find_element(By.ID, "int_mod_int").text == "0"
-    assert driver.find_element(By.ID, "int_gt_int").text == "true"
-    assert driver.find_element(By.ID, "int_lt_int").text == "false"
-    assert driver.find_element(By.ID, "int_gte_int").text == "true"
-    assert driver.find_element(By.ID, "int_lte_int").text == "false"
-    assert driver.find_element(By.ID, "int_and_int").text == "5"
-    assert driver.find_element(By.ID, "int_or_int").text == "10"
-    assert driver.find_element(By.ID, "int_eq_int").text == "false"
-    assert driver.find_element(By.ID, "int_neq_int").text == "true"
-
-    # INT FLOAT OR FLOAT INT
-    assert driver.find_element(By.ID, "float_add_int").text == "15.5"
-    assert driver.find_element(By.ID, "float_mult_int").text == "52.5"
-    assert driver.find_element(By.ID, "float_sub_int").text == "5.5"
-    assert driver.find_element(By.ID, "float_exp_int").text == "127628.15625"
-    assert driver.find_element(By.ID, "float_div_int").text == "2.1"
-    assert driver.find_element(By.ID, "float_floor_int").text == "1"
-    assert driver.find_element(By.ID, "float_mod_int").text == "0.5"
-    assert driver.find_element(By.ID, "float_gt_int").text == "true"
-    assert driver.find_element(By.ID, "float_lt_int").text == "false"
-    assert driver.find_element(By.ID, "float_gte_int").text == "true"
-    assert driver.find_element(By.ID, "float_lte_int").text == "false"
-    assert driver.find_element(By.ID, "float_eq_int").text == "false"
-    assert driver.find_element(By.ID, "float_neq_int").text == "true"
-    assert driver.find_element(By.ID, "float_and_int").text == "5"
-    assert driver.find_element(By.ID, "float_or_int").text == "10.5"
-
-    # INT, DICT
-    assert driver.find_element(By.ID, "int_or_dict").text == "10"
-    assert driver.find_element(By.ID, "int_and_dict").text == '{"1":2}'
-    assert driver.find_element(By.ID, "int_eq_dict").text == "false"
-    assert driver.find_element(By.ID, "int_neq_dict").text == "true"
-
-    # FLOAT FLOAT
-    assert driver.find_element(By.ID, "float_add_float").text == "16"
-    assert driver.find_element(By.ID, "float_mult_float").text == "57.75"
-    assert driver.find_element(By.ID, "float_sub_float").text == "5"
-    assert driver.find_element(By.ID, "float_exp_float").text == "413562.49323606625"
-    assert driver.find_element(By.ID, "float_div_float").text == "1.9090909090909092"
-    assert driver.find_element(By.ID, "float_floor_float").text == "1"
-    assert driver.find_element(By.ID, "float_mod_float").text == "5"
-    assert driver.find_element(By.ID, "float_gt_float").text == "true"
-    assert driver.find_element(By.ID, "float_lt_float").text == "false"
-    assert driver.find_element(By.ID, "float_gte_float").text == "true"
-    assert driver.find_element(By.ID, "float_lte_float").text == "false"
-    assert driver.find_element(By.ID, "float_eq_float").text == "false"
-    assert driver.find_element(By.ID, "float_neq_float").text == "true"
-    assert driver.find_element(By.ID, "float_and_float").text == "5.5"
-    assert driver.find_element(By.ID, "float_or_float").text == "10.5"
-
-    # FLOAT STR
-    assert driver.find_element(By.ID, "float_or_str").text == "10.5"
-    assert driver.find_element(By.ID, "float_and_str").text == "first"
-    assert driver.find_element(By.ID, "float_eq_str").text == "false"
-    assert driver.find_element(By.ID, "float_neq_str").text == "true"
-
-    # FLOAT,LIST
-    assert driver.find_element(By.ID, "float_or_list").text == "10.5"
-    assert driver.find_element(By.ID, "float_and_list").text == "[1,2]"
-    assert driver.find_element(By.ID, "float_eq_list").text == "false"
-    assert driver.find_element(By.ID, "float_neq_list").text == "true"
-
-    # FLOAT, DICT
-    assert driver.find_element(By.ID, "float_or_dict").text == "10.5"
-    assert driver.find_element(By.ID, "float_and_dict").text == '{"1":2}'
-    assert driver.find_element(By.ID, "float_eq_dict").text == "false"
-    assert driver.find_element(By.ID, "float_neq_dict").text == "true"
-
-    # STR STR
-    assert driver.find_element(By.ID, "str_add_str").text == "firstsecond"
-    assert driver.find_element(By.ID, "str_gt_str").text == "false"
-    assert driver.find_element(By.ID, "str_lt_str").text == "true"
-    assert driver.find_element(By.ID, "str_gte_str").text == "false"
-    assert driver.find_element(By.ID, "str_lte_str").text == "true"
-    assert driver.find_element(By.ID, "str_eq_str").text == "false"
-    assert driver.find_element(By.ID, "str_neq_str").text == "true"
-    assert driver.find_element(By.ID, "str_and_str").text == "second"
-    assert driver.find_element(By.ID, "str_or_str").text == "first"
-    assert driver.find_element(By.ID, "str_contains").text == "true"
-
-    # STR INT
-    assert (
-        driver.find_element(By.ID, "str_mult_int").text == "firstfirstfirstfirstfirst"
-    )
-    assert driver.find_element(By.ID, "str_and_int").text == "5"
-    assert driver.find_element(By.ID, "str_or_int").text == "first"
-    assert driver.find_element(By.ID, "str_eq_int").text == "false"
-    assert driver.find_element(By.ID, "str_neq_int").text == "true"
-
-    # STR, LIST
-    assert driver.find_element(By.ID, "str_and_list").text == "[1,2]"
-    assert driver.find_element(By.ID, "str_or_list").text == "first"
-    assert driver.find_element(By.ID, "str_eq_list").text == "false"
-    assert driver.find_element(By.ID, "str_neq_list").text == "true"
-
-    # STR, DICT
-
-    assert driver.find_element(By.ID, "str_or_dict").text == "first"
-    assert driver.find_element(By.ID, "str_and_dict").text == '{"1":2}'
-    assert driver.find_element(By.ID, "str_eq_dict").text == "false"
-    assert driver.find_element(By.ID, "str_neq_dict").text == "true"
-
-    # LIST,LIST
-    assert driver.find_element(By.ID, "list_add_list").text == "[1,2,3,4]"
-    assert driver.find_element(By.ID, "list_gt_list").text == "false"
-    assert driver.find_element(By.ID, "list_lt_list").text == "true"
-    assert driver.find_element(By.ID, "list_gte_list").text == "false"
-    assert driver.find_element(By.ID, "list_lte_list").text == "true"
-    assert driver.find_element(By.ID, "list_eq_list").text == "false"
-    assert driver.find_element(By.ID, "list_neq_list").text == "true"
-    assert driver.find_element(By.ID, "list_and_list").text == "[3,4]"
-    assert driver.find_element(By.ID, "list_or_list").text == "[1,2]"
-    assert driver.find_element(By.ID, "list_contains").text == "true"
-    assert driver.find_element(By.ID, "list_reverse").text == "[2,1]"
-
-    # LIST INT
-    assert driver.find_element(By.ID, "list_mult_int").text == "[1,2,1,2,1,2,1,2,1,2]"
-    assert driver.find_element(By.ID, "list_or_int").text == "[1,2]"
-    assert driver.find_element(By.ID, "list_and_int").text == "10"
-    assert driver.find_element(By.ID, "list_eq_int").text == "false"
-    assert driver.find_element(By.ID, "list_neq_int").text == "true"
-
-    # LIST DICT
-    assert driver.find_element(By.ID, "list_and_dict").text == '{"1":2}'
-    assert driver.find_element(By.ID, "list_or_dict").text == "[1,2]"
-    assert driver.find_element(By.ID, "list_eq_dict").text == "false"
-    assert driver.find_element(By.ID, "list_neq_dict").text == "true"
-
-    # DICT, DICT
-    assert driver.find_element(By.ID, "dict_or_dict").text == '{"1":2}'
-    assert driver.find_element(By.ID, "dict_and_dict").text == '{"3":4}'
-    assert driver.find_element(By.ID, "dict_eq_dict").text == "false"
-    assert driver.find_element(By.ID, "dict_neq_dict").text == "true"
-    assert driver.find_element(By.ID, "dict_contains").text == "true"
+    tests = [
+        # int, int
+        ("int_add_int", "15"),
+        ("int_mult_int", "50"),
+        ("int_sub_int", "5"),
+        ("int_exp_int", "100000"),
+        ("int_div_int", "2"),
+        ("int_floor_int", "1"),
+        ("int_mod_int", "0"),
+        ("int_gt_int", "true"),
+        ("int_lt_int", "false"),
+        ("int_gte_int", "true"),
+        ("int_lte_int", "false"),
+        ("int_and_int", "5"),
+        ("int_or_int", "10"),
+        ("int_eq_int", "false"),
+        ("int_neq_int", "true"),
+        # int, float
+        ("float_add_int", "15.5"),
+        ("float_mult_int", "52.5"),
+        ("float_sub_int", "5.5"),
+        ("float_exp_int", "127628.15625"),
+        ("float_div_int", "2.1"),
+        ("float_floor_int", "1"),
+        ("float_mod_int", "0.5"),
+        ("float_gt_int", "true"),
+        ("float_lt_int", "false"),
+        ("float_gte_int", "true"),
+        ("float_lte_int", "false"),
+        ("float_eq_int", "false"),
+        ("float_neq_int", "true"),
+        ("float_and_int", "5"),
+        ("float_or_int", "10.5"),
+        # int, dict
+        ("int_or_dict", "10"),
+        ("int_and_dict", '{"1":2}'),
+        ("int_eq_dict", "false"),
+        ("int_neq_dict", "true"),
+        # float, float
+        ("float_add_float", "16"),
+        ("float_mult_float", "57.75"),
+        ("float_sub_float", "5"),
+        ("float_exp_float", "413562.49323606625"),
+        ("float_div_float", "1.9090909090909092"),
+        ("float_floor_float", "1"),
+        ("float_mod_float", "5"),
+        ("float_gt_float", "true"),
+        ("float_lt_float", "false"),
+        ("float_gte_float", "true"),
+        ("float_lte_float", "false"),
+        ("float_eq_float", "false"),
+        ("float_neq_float", "true"),
+        ("float_and_float", "5.5"),
+        ("float_or_float", "10.5"),
+        # float, str
+        ("float_or_str", "10.5"),
+        ("float_and_str", "first"),
+        ("float_eq_str", "false"),
+        ("float_neq_str", "true"),
+        # float, list
+        ("float_or_list", "10.5"),
+        ("float_and_list", "[1,2]"),
+        ("float_eq_list", "false"),
+        ("float_neq_list", "true"),
+        # float, dict
+        ("float_or_dict", "10.5"),
+        ("float_and_dict", '{"1":2}'),
+        ("float_eq_dict", "false"),
+        ("float_neq_dict", "true"),
+        # str, str
+        ("str_add_str", "firstsecond"),
+        ("str_gt_str", "false"),
+        ("str_lt_str", "true"),
+        ("str_gte_str", "false"),
+        ("str_lte_str", "true"),
+        ("str_eq_str", "false"),
+        ("str_neq_str", "true"),
+        ("str_and_str", "second"),
+        ("str_or_str", "first"),
+        ("str_contains", "true"),
+        ("str_lower", "third"),
+        ("str_upper", "THIRD"),
+        ("str_split", '["a","long","string"]'),
+        # str, int
+        ("str_mult_int", "firstfirstfirstfirstfirst"),
+        ("str_and_int", "5"),
+        ("str_or_int", "first"),
+        ("str_eq_int", "false"),
+        ("str_neq_int", "true"),
+        # str, list
+        ("str_and_list", "[1,2]"),
+        ("str_or_list", "first"),
+        ("str_eq_list", "false"),
+        ("str_neq_list", "true"),
+        # str, dict
+        ("str_or_dict", "first"),
+        ("str_and_dict", '{"1":2}'),
+        ("str_eq_dict", "false"),
+        ("str_neq_dict", "true"),
+        # list, list
+        ("list_add_list", "[1,2,3,4]"),
+        ("list_gt_list", "false"),
+        ("list_lt_list", "true"),
+        ("list_gte_list", "false"),
+        ("list_lte_list", "true"),
+        ("list_eq_list", "false"),
+        ("list_neq_list", "true"),
+        ("list_and_list", "[3,4]"),
+        ("list_or_list", "[1,2]"),
+        ("list_contains", "true"),
+        ("list_reverse", "[2,1]"),
+        ("list_join", "firstsecondthird"),
+        ("list_join_comma", "first,second,third"),
+        # list, int
+        ("list_mult_int", "[1,2,1,2,1,2,1,2,1,2]"),
+        ("list_or_int", "[1,2]"),
+        ("list_and_int", "10"),
+        ("list_eq_int", "false"),
+        ("list_neq_int", "true"),
+        # list, dict
+        ("list_and_dict", '{"1":2}'),
+        ("list_or_dict", "[1,2]"),
+        ("list_eq_dict", "false"),
+        ("list_neq_dict", "true"),
+        # dict, dict
+        ("dict_or_dict", '{"1":2}'),
+        ("dict_and_dict", '{"3":4}'),
+        ("dict_eq_dict", "false"),
+        ("dict_neq_dict", "true"),
+        ("dict_contains", "true"),
+    ]
+
+    for tag, expected in tests:
+        assert driver.find_element(By.ID, tag).text == expected

+ 1 - 1
reflex/components/base/script.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventHandler, EventChain, EventSpec
 from reflex.event import EventHandler, EventChain, EventSpec

+ 6 - 1
reflex/components/component.py

@@ -447,7 +447,12 @@ class Component(Base, ABC):
 
 
         return cls(children=children, **props)
         return cls(children=children, **props)
 
 
-    def _add_style(self, style):
+    def _add_style(self, style: dict):
+        """Add additional style to the component.
+
+        Args:
+            style: A style dict to apply.
+        """
         self.style.update(style)
         self.style.update(style)
 
 
     def add_style(self, style: ComponentStyle) -> Component:
     def add_style(self, style: ComponentStyle) -> Component:

+ 1 - 2
reflex/components/datadisplay/code.py

@@ -39,7 +39,7 @@ class CodeBlock(Component):
     wrap_long_lines: Var[bool]
     wrap_long_lines: Var[bool]
 
 
     # A custom style for the code block.
     # A custom style for the code block.
-    custom_style: Var[Dict[str, str]]
+    custom_style: Dict[str, str] = {}
 
 
     # Props passed down to the code tag.
     # Props passed down to the code tag.
     code_tag_props: Var[Dict[str, str]]
     code_tag_props: Var[Dict[str, str]]
@@ -107,7 +107,6 @@ class CodeBlock(Component):
             return code_block
             return code_block
 
 
     def _add_style(self, style):
     def _add_style(self, style):
-        self.custom_style = self.custom_style or {}
         self.custom_style.update(style)  # type: ignore
         self.custom_style.update(style)  # type: ignore
 
 
     def _render(self):
     def _render(self):

+ 5 - 1
reflex/components/datadisplay/list.py

@@ -1,5 +1,7 @@
 """List components."""
 """List components."""
 
 
+from __future__ import annotations
+
 from reflex.components import Component
 from reflex.components import Component
 from reflex.components.layout.foreach import Foreach
 from reflex.components.layout.foreach import Foreach
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
@@ -21,7 +23,9 @@ class List(ChakraComponent):
     style_type: Var[str]
     style_type: Var[str]
 
 
     @classmethod
     @classmethod
-    def create(cls, *children, items=None, **props) -> Component:
+    def create(
+        cls, *children, items: list | Var[list] | None = None, **props
+    ) -> Component:
         """Create a list component.
         """Create a list component.
 
 
         Args:
         Args:

+ 3 - 3
reflex/components/datadisplay/list.pyi

@@ -12,7 +12,7 @@ from reflex.event import EventHandler, EventChain, EventSpec
 class List(ChakraComponent):
 class List(ChakraComponent):
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(cls, *children, items, spacing: Optional[Union[Var[str], str]] = None, style_position: Optional[Union[Var[str], str]] = None, style_type: Optional[Union[Var[str], str]] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "List":  # type: ignore
+    def create(cls, *children, items: Optional[list | Var[list] | None] = None, spacing: Optional[Union[Var[str], str]] = None, style_position: Optional[Union[Var[str], str]] = None, style_type: Optional[Union[Var[str], str]] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "List":  # type: ignore
         """Create a list component.
         """Create a list component.
 
 
         Args:
         Args:
@@ -49,7 +49,7 @@ class ListItem(ChakraComponent):
 class OrderedList(List):
 class OrderedList(List):
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(cls, *children, items, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "OrderedList":  # type: ignore
+    def create(cls, *children, items: Optional[list | Var[list] | None] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "OrderedList":  # type: ignore
         """Create a list component.
         """Create a list component.
 
 
         Args:
         Args:
@@ -65,7 +65,7 @@ class OrderedList(List):
 class UnorderedList(List):
 class UnorderedList(List):
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(cls, *children, items, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "UnorderedList":  # type: ignore
+    def create(cls, *children, items: Optional[list | Var[list] | None] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "UnorderedList":  # type: ignore
         """Create a list component.
         """Create a list component.
 
 
         Args:
         Args:

+ 1 - 1
reflex/components/forms/checkbox.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/forms/editable.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 2 - 2
reflex/components/forms/pininput.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
@@ -44,7 +44,7 @@ class PinInput(ChakraComponent):
 class PinInputField(ChakraComponent):
 class PinInputField(ChakraComponent):
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(cls, *children, index: Optional[Union[Var[int], int]] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "PinInputField":  # type: ignore
+    def create(cls, *children, index: Optional[Var[int]] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "PinInputField":  # type: ignore
         """Create the component.
         """Create the component.
 
 
                Args:
                Args:

+ 1 - 1
reflex/components/forms/rangeslider.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, List, Optional, Union, overload
+from typing import Any, List, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/forms/slider.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/forms/switch.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/forms/textarea.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/forms/upload.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, List, Optional, Union, overload
+from typing import Any, Dict, List, Optional, Union, overload
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventHandler, EventChain, EventSpec
 from reflex.event import EventHandler, EventChain, EventSpec

+ 1 - 1
reflex/components/media/avatar.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/overlay/alertdialog.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/overlay/drawer.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/overlay/menu.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, List, Optional, Union, overload
+from typing import Any, List, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/overlay/modal.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/overlay/popover.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 1 - 1
reflex/components/overlay/tooltip.pyi

@@ -3,7 +3,7 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Dict, Optional, Union, overload
+from typing import Any, Optional, Union, overload
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.libs.chakra import ChakraComponent
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar

+ 146 - 86
reflex/components/typography/markdown.py

@@ -1,32 +1,60 @@
 """Markdown component."""
 """Markdown component."""
 
 
+from __future__ import annotations
+
 import textwrap
 import textwrap
-from typing import Callable, Dict, List, Union
+from typing import Any, Callable, Dict, Union
 
 
 from reflex.compiler import utils
 from reflex.compiler import utils
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
 from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
 from reflex.components.navigation import Link
 from reflex.components.navigation import Link
+from reflex.components.tags.tag import Tag
 from reflex.components.typography.heading import Heading
 from reflex.components.typography.heading import Heading
 from reflex.components.typography.text import Text
 from reflex.components.typography.text import Text
-from reflex.style import Style
-from reflex.utils import types
-from reflex.vars import BaseVar, ImportVar, Var
-
-# Mapping from markdown tags to components.
-components_by_tag: Dict[str, Callable] = {
-    "h1": Heading,
-    "h2": Heading,
-    "h3": Heading,
-    "h4": Heading,
-    "h5": Heading,
-    "h6": Heading,
-    "p": Text,
-    "ul": UnorderedList,
-    "ol": OrderedList,
-    "li": ListItem,
-    "a": Link,
-}
+from reflex.utils import console, imports, types
+from reflex.vars import ImportVar, Var
+
+# Special vars used in the component map.
+_CHILDREN = Var.create_safe("children", is_local=False)
+_PROPS = Var.create_safe("...props", is_local=False)
+
+# Special remark plugins.
+_REMARK_MATH = Var.create_safe("remarkMath", is_local=False)
+_REMARK_GFM = Var.create_safe("remarkGfm", is_local=False)
+_REMARK_PLUGINS = Var.create_safe([_REMARK_MATH, _REMARK_GFM])
+
+# Special rehype plugins.
+_REHYPE_KATEX = Var.create_safe("rehypeKatex", is_local=False)
+_REHYPE_RAW = Var.create_safe("rehypeRaw", is_local=False)
+_REHYPE_PLUGINS = Var.create_safe([_REHYPE_KATEX, _REHYPE_RAW])
+
+# Component Mapping
+def get_base_component_map() -> dict[str, Callable]:
+    """Get the base component map.
+
+    Returns:
+        The base component map.
+    """
+    from reflex.components.datadisplay.code import Code, CodeBlock
+
+    return {
+        "h1": lambda value: Heading.create(value, as_="h1", size="2xl"),
+        "h2": lambda value: Heading.create(value, as_="h2", size="xl"),
+        "h3": lambda value: Heading.create(value, as_="h3", size="lg"),
+        "h4": lambda value: Heading.create(value, as_="h4", size="md"),
+        "h5": lambda value: Heading.create(value, as_="h5", size="sm"),
+        "h6": lambda value: Heading.create(value, as_="h6", size="xs"),
+        "p": lambda value: Text.create(value),
+        "ul": lambda value: UnorderedList.create(value),  # type: ignore
+        "ol": lambda value: OrderedList.create(value),  # type: ignore
+        "li": lambda value: ListItem.create(value),
+        "a": lambda value: Link.create(value),
+        "code": lambda value: Code.create(value),
+        "codeblock": lambda *children, **props: CodeBlock.create(
+            *children, theme="light", **props
+        ),
+    }
 
 
 
 
 class Markdown(Component):
 class Markdown(Component):
@@ -38,36 +66,11 @@ class Markdown(Component):
 
 
     is_default = True
     is_default = True
 
 
-    # Custom defined styles for the markdown elements.
-    custom_styles: Dict[str, Style] = {
-        k: Style(v)
-        for k, v in {
-            "h1": {
-                "as_": "h1",
-                "size": "2xl",
-            },
-            "h2": {
-                "as_": "h2",
-                "size": "xl",
-            },
-            "h3": {
-                "as_": "h3",
-                "size": "lg",
-            },
-            "h4": {
-                "as_": "h4",
-                "size": "md",
-            },
-            "h5": {
-                "as_": "h5",
-                "size": "sm",
-            },
-            "h6": {
-                "as_": "h6",
-                "size": "xs",
-            },
-        }.items()
-    }
+    # The component map from a tag to a lambda that creates a component.
+    component_map: Dict[str, Any] = {}
+
+    # Custom styles for the markdown (deprecated in v0.2.9).
+    custom_styles: Dict[str, Any] = {}
 
 
     @classmethod
     @classmethod
     def create(cls, *children, **props) -> Component:
     def create(cls, *children, **props) -> Component:
@@ -84,13 +87,29 @@ class Markdown(Component):
             children[0], Union[str, Var]
             children[0], Union[str, Var]
         ), "Markdown component must have exactly one child containing the markdown source."
         ), "Markdown component must have exactly one child containing the markdown source."
 
 
+        # Custom styles are deprecated.
+        if "custom_styles" in props:
+            console.deprecate(
+                "rx.markdown custom_styles",
+                "Use the component_map prop instead.",
+                "0.2.9",
+                "0.2.11",
+            )
+
+        # Update the base component map with the custom component map.
+        component_map = {**get_base_component_map(), **props.pop("component_map", {})}
+
         # Get the markdown source.
         # Get the markdown source.
         src = children[0]
         src = children[0]
+
+        # Dedent the source.
         if isinstance(src, str):
         if isinstance(src, str):
             src = textwrap.dedent(src)
             src = textwrap.dedent(src)
-        return super().create(src, **props)
 
 
-    def _get_imports(self):
+        # Create the component.
+        return super().create(src, component_map=component_map, **props)
+
+    def _get_imports(self) -> imports.ImportDict:
         # Import here to avoid circular imports.
         # Import here to avoid circular imports.
         from reflex.components.datadisplay.code import Code, CodeBlock
         from reflex.components.datadisplay.code import Code, CodeBlock
 
 
@@ -100,16 +119,22 @@ class Markdown(Component):
         imports.update(
         imports.update(
             {
             {
                 "": {ImportVar(tag="katex/dist/katex.min.css")},
                 "": {ImportVar(tag="katex/dist/katex.min.css")},
-                "rehype-katex@^6.0.3": {ImportVar(tag="rehypeKatex", is_default=True)},
-                "remark-math@^5.1.1": {ImportVar(tag="remarkMath", is_default=True)},
-                "rehype-raw@^6.1.1": {ImportVar(tag="rehypeRaw", is_default=True)},
-                "remark-gfm@^3.0.1": {ImportVar(tag="remarkGfm", is_default=True)},
+                "remark-math@^5.1.1": {
+                    ImportVar(tag=_REMARK_MATH.name, is_default=True)
+                },
+                "remark-gfm@^3.0.1": {ImportVar(tag=_REMARK_GFM.name, is_default=True)},
+                "rehype-katex@^6.0.3": {
+                    ImportVar(tag=_REHYPE_KATEX.name, is_default=True)
+                },
+                "rehype-raw@^6.1.1": {ImportVar(tag=_REHYPE_RAW.name, is_default=True)},
             }
             }
         )
         )
 
 
         # Get the imports for each component.
         # Get the imports for each component.
-        for component in components_by_tag.values():
-            imports = utils.merge_imports(imports, component()._get_imports())
+        for component in self.component_map.values():
+            imports = utils.merge_imports(
+                imports, component(Var.create("")).get_imports()
+            )
 
 
         # Get the imports for the code components.
         # Get the imports for the code components.
         imports = utils.merge_imports(
         imports = utils.merge_imports(
@@ -118,52 +143,87 @@ class Markdown(Component):
         imports = utils.merge_imports(imports, Code.create()._get_imports())
         imports = utils.merge_imports(imports, Code.create()._get_imports())
         return imports
         return imports
 
 
-    def _render(self):
-        # Import here to avoid circular imports.
-        from reflex.components.datadisplay.code import Code, CodeBlock
-        from reflex.components.tags.tag import Tag
+    def get_component(self, tag: str, **props) -> Component:
+        """Get the component for a tag and props.
 
 
-        def format_props(tag):
-            return "".join(
-                Tag(
-                    name="", props=Style(self.custom_styles.get(tag, {}))
-                ).format_props()
-            )
+        Args:
+            tag: The tag of the component.
+            **props: The props of the component.
+
+        Returns:
+            The component.
+
+        Raises:
+            ValueError: If the tag is invalid.
+        """
+        # Check the tag is valid.
+        if tag not in self.component_map:
+            raise ValueError(f"No markdown component found for tag: {tag}.")
+
+        special_props = {_PROPS}
+        children = [_CHILDREN]
+
+        # If the children are set as a prop, don't pass them as children.
+        children_prop = props.pop("children", None)
+        if children_prop is not None:
+            special_props.add(Var.create_safe(f"children={str(children_prop)}"))
+            children = []
+
+        # Get the component.
+        component = self.component_map[tag](*children, **props).set(
+            special_props=special_props
+        )
+        component._add_style(self.custom_styles.get(tag, {}))
+        return component
 
 
+    def format_component(self, tag: str, **props) -> str:
+        """Format a component for rendering in the component map.
+
+        Args:
+            tag: The tag of the component.
+            **props: Extra props to pass to the component function.
+
+        Returns:
+            The formatted component.
+        """
+        return str(self.get_component(tag, **props)).replace("\n", " ")
+
+    def format_component_map(self) -> dict[str, str]:
+        """Format the component map for rendering.
+
+        Returns:
+            The formatted component map.
+        """
         components = {
         components = {
-            tag: f"{{({{node, ...props}}) => <{(component().tag)} {{...props}} {format_props(tag)} />}}"
-            for tag, component in components_by_tag.items()
+            tag: f"{{({{{_CHILDREN.name}, {_PROPS.name}}}) => {self.format_component(tag)}}}"
+            for tag in self.component_map
         }
         }
+
+        # Separate out inline code and code blocks.
         components[
         components[
             "code"
             "code"
-        ] = f"""{{({{node, inline, className, children, ...props}}) => {{
+        ] = f"""{{({{inline, className, {_CHILDREN.name}, {_PROPS.name}}}) => {{
     const match = (className || '').match(/language-(?<lang>.*)/);
     const match = (className || '').match(/language-(?<lang>.*)/);
+    const language = match ? match[1] : '';
     return !inline ? (
     return !inline ? (
-        <{CodeBlock().tag}
-        children={{String(children).replace(/\n$/, '')}}
-        language={{match ? match[1] : ''}}
-        style={{light}}
-        {{...props}}
-        {format_props("pre")}
-        />
+        {self.format_component("codeblock", language=Var.create_safe("language", is_local=False), children=Var.create_safe("String(children)", is_local=False))}
     ) : (
     ) : (
-        <{Code.create().tag} {{...props}} {format_props("code")}>
-        {{children}}
-        </{Code.create().tag}>
+        {self.format_component("code")}
     );
     );
       }}}}""".replace(
       }}}}""".replace(
             "\n", " "
             "\n", " "
         )
         )
 
 
+        return components
+
+    def _render(self) -> Tag:
         return (
         return (
             super()
             super()
             ._render()
             ._render()
             .add_props(
             .add_props(
-                components=components,
-                remark_plugins=BaseVar(name="[remarkMath, remarkGfm]", type_=List[str]),
-                rehype_plugins=BaseVar(
-                    name="[rehypeKatex, rehypeRaw]", type_=List[str]
-                ),
+                components=self.format_component_map(),
+                remark_plugins=_REMARK_PLUGINS,
+                rehype_plugins=_REHYPE_PLUGINS,
             )
             )
-            .remove_props("custom_components")
+            .remove_props("componentMap")
         )
         )

+ 5 - 5
reflex/components/typography/markdown.pyi

@@ -3,23 +3,23 @@
 # This file was generated by `scripts/pyi_generator.py`!
 # This file was generated by `scripts/pyi_generator.py`!
 # ------------------------------------------------------
 # ------------------------------------------------------
 
 
-from typing import Callable, Dict, List, Optional, Union, overload
+from typing import Any, Callable, Dict, Optional, Union, overload
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventHandler, EventChain, EventSpec
 from reflex.event import EventHandler, EventChain, EventSpec
 
 
-components_by_tag: Dict[str, Callable]
+def get_base_component_map() -> dict[str, Callable]: ...
 
 
 class Markdown(Component):
 class Markdown(Component):
     @overload
     @overload
     @classmethod
     @classmethod
-    def create(cls, *children, lib_dependencies: Optional[List[str]] = None, custom_styles: Optional[Dict[str, Style]] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "Markdown":  # type: ignore
+    def create(cls, *children, component_map: Optional[Dict[str, Any]] = None, custom_styles: Optional[Dict[str, Any]] = None, on_blur: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_context_menu: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_double_click: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_focus: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_down: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_enter: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_leave: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_move: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_out: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_over: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_mouse_up: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_scroll: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, on_unmount: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, **props) -> "Markdown":  # type: ignore
         """Create a markdown component.
         """Create a markdown component.
 
 
         Args:
         Args:
             *children: The children of the component.
             *children: The children of the component.
-            lib_dependencies:
-            custom_styles: Custom defined styles for the markdown elements.
+            component_map: The component map from a tag to a lambda that creates a component.
+            custom_styles: Custom styles for the markdown (deprecated in v0.2.9).
             **props: The properties of the component.
             **props: The properties of the component.
 
 
         Returns:
         Returns:

+ 31 - 1
reflex/utils/format.py

@@ -355,7 +355,7 @@ def format_props(*single_props, **key_value_props) -> list[str]:
         f"{name}={format_prop(prop)}"
         f"{name}={format_prop(prop)}"
         for name, prop in sorted(key_value_props.items())
         for name, prop in sorted(key_value_props.items())
         if prop is not None
         if prop is not None
-    ] + [str(prop) for prop in sorted(single_props)]
+    ] + [str(prop) for prop in single_props]
 
 
 
 
 def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
 def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
@@ -574,3 +574,33 @@ def json_dumps(obj: Any) -> str:
         A string
         A string
     """
     """
     return json.dumps(obj, ensure_ascii=False, default=list)
     return json.dumps(obj, ensure_ascii=False, default=list)
+
+
+def unwrap_vars(value: str) -> str:
+    """Unwrap var values from a JSON string.
+
+    For example, "{var}" will be unwrapped to "var".
+
+    Args:
+        value: The JSON string to unwrap.
+
+    Returns:
+        The unwrapped JSON string.
+    """
+
+    def unescape_double_quotes_in_var(m: re.Match) -> str:
+        # Since the outer quotes are removed, the inner escaped quotes must be unescaped.
+        return re.sub('\\\\"', '"', m.group(1))
+
+    # This substitution is necessary to unwrap var values.
+    return re.sub(
+        pattern=r"""
+            (?<!\\)      # must NOT start with a backslash
+            "            # match opening double quote of JSON value
+            {(.*?)}      # extract the value between curly braces (non-greedy)
+            "            # match must end with an unescaped double quote
+        """,
+        repl=unescape_double_quotes_in_var,
+        string=value,
+        flags=re.VERBOSE,
+    )

+ 40 - 26
reflex/utils/serializers.py

@@ -2,12 +2,12 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-import re
 import types as builtin_types
 import types as builtin_types
 from datetime import date, datetime, time, timedelta
 from datetime import date, datetime, time, timedelta
-from typing import Any, Callable, Dict, Type, Union, get_type_hints
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints
 
 
-from reflex.utils import exceptions, types
+from reflex.base import Base
+from reflex.utils import exceptions, format, types
 
 
 # Mapping from type to a serializer.
 # Mapping from type to a serializer.
 # The serializer should convert the type to a JSON object.
 # The serializer should convert the type to a JSON object.
@@ -126,6 +126,38 @@ def serialize_str(value: str) -> str:
     return value
     return value
 
 
 
 
+@serializer
+def serialize_primitive(value: Union[bool, int, float, Base, None]) -> str:
+    """Serialize a primitive type.
+
+    Args:
+        value: The number to serialize.
+
+    Returns:
+        The serialized number.
+    """
+    return format.json_dumps(value)
+
+
+@serializer
+def serialize_list(value: Union[List, Tuple, Set]) -> str:
+    """Serialize a list to a JSON string.
+
+    Args:
+        value: The list to serialize.
+
+    Returns:
+        The serialized list.
+    """
+    from reflex.vars import Var
+
+    # Convert any var values to strings.
+    fprop = format.json_dumps([str(v) if isinstance(v, Var) else v for v in value])
+
+    # Unwrap var values.
+    return format.unwrap_vars(fprop)
+
+
 @serializer
 @serializer
 def serialize_dict(prop: Dict[str, Any]) -> str:
 def serialize_dict(prop: Dict[str, Any]) -> str:
     """Serialize a dictionary to a JSON string.
     """Serialize a dictionary to a JSON string.
@@ -141,7 +173,6 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
     """
     """
     # Import here to avoid circular imports.
     # Import here to avoid circular imports.
     from reflex.event import EventHandler
     from reflex.event import EventHandler
-    from reflex.utils.format import json_dumps, to_snake_case
     from reflex.vars import Var
     from reflex.vars import Var
 
 
     prop_dict = {}
     prop_dict = {}
@@ -150,34 +181,17 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
     for key, value in prop.items():
     for key, value in prop.items():
         if types._issubclass(type(value), Callable):
         if types._issubclass(type(value), Callable):
             raise exceptions.InvalidStylePropError(
             raise exceptions.InvalidStylePropError(
-                f"The style prop `{to_snake_case(key)}` cannot have "  # type: ignore
+                f"The style prop `{format.to_snake_case(key)}` cannot have "  # type: ignore
                 f"`{value.fn.__qualname__ if isinstance(value, EventHandler) else value.__qualname__ if isinstance(value, builtin_types.FunctionType) else value}`, "
                 f"`{value.fn.__qualname__ if isinstance(value, EventHandler) else value.__qualname__ if isinstance(value, builtin_types.FunctionType) else value}`, "
                 f"an event handler or callable as its value"
                 f"an event handler or callable as its value"
             )
             )
         prop_dict[key] = str(value) if isinstance(value, Var) else value
         prop_dict[key] = str(value) if isinstance(value, Var) else value
 
 
     # Dump the dict to a string.
     # Dump the dict to a string.
-    fprop = json_dumps(prop_dict)
-
-    def unescape_double_quotes_in_var(m: re.Match) -> str:
-        # Since the outer quotes are removed, the inner escaped quotes must be unescaped.
-        return re.sub('\\\\"', '"', m.group(1))
-
-    # This substitution is necessary to unwrap var values.
-    fprop = re.sub(
-        pattern=r"""
-            (?<!\\)      # must NOT start with a backslash
-            "            # match opening double quote of JSON value
-            {(.*?)}      # extract the value between curly braces (non-greedy)
-            "            # match must end with an unescaped double quote
-        """,
-        repl=unescape_double_quotes_in_var,
-        string=fprop,
-        flags=re.VERBOSE,
-    )
-
-    # Return the formatted dict.
-    return fprop
+    fprop = format.json_dumps(prop_dict)
+
+    # Unwrap var values.
+    return format.unwrap_vars(fprop)
 
 
 
 
 @serializer
 @serializer

+ 1 - 1
reflex/utils/types.py

@@ -104,7 +104,7 @@ def _issubclass(cls: GenericType, cls_check: GenericType) -> bool:
     # Special check for Any.
     # Special check for Any.
     if cls_check == Any:
     if cls_check == Any:
         return True
         return True
-    if cls in [Any, Callable]:
+    if cls in [Any, Callable, None]:
         return False
         return False
 
 
     # Get the base classes.
     # Get the base classes.

+ 104 - 16
reflex/vars.py

@@ -122,19 +122,14 @@ class Var(ABC):
         if isinstance(value, Var):
         if isinstance(value, Var):
             return value
             return value
 
 
-        type_ = type(value)
-
         # Try to serialize the value.
         # Try to serialize the value.
-        serialized = serialize(value)
-        if serialized is not None:
-            value = serialized
-
-        try:
-            name = value if isinstance(value, str) else json.dumps(value)
-        except TypeError as e:
+        type_ = type(value)
+        name = serialize(value)
+        if name is None:
             raise TypeError(
             raise TypeError(
                 f"No JSON serializer found for var {value} of type {type_}."
                 f"No JSON serializer found for var {value} of type {type_}."
-            ) from e
+            )
+        name = name if isinstance(name, str) else format.json_dumps(name)
 
 
         return BaseVar(name=name, type_=type_, is_local=is_local, is_string=is_string)
         return BaseVar(name=name, type_=type_, is_local=is_local, is_string=is_string)
 
 
@@ -202,13 +197,17 @@ class Var(ABC):
             and self.is_local == other.is_local
             and self.is_local == other.is_local
         )
         )
 
 
-    def to_string(self) -> Var:
+    def to_string(self, json: bool = True) -> Var:
         """Convert a var to a string.
         """Convert a var to a string.
 
 
+        Args:
+            json: Whether to convert to a JSON string.
+
         Returns:
         Returns:
             The stringified var.
             The stringified var.
         """
         """
-        return self.operation(fn="JSON.stringify", type_=str)
+        fn = "JSON.stringify" if json else "String"
+        return self.operation(fn=fn, type_=str)
 
 
     def __hash__(self) -> int:
     def __hash__(self) -> int:
         """Define a hash function for a var.
         """Define a hash function for a var.
@@ -945,9 +944,7 @@ class Var(ABC):
         Returns:
         Returns:
             A var representing the contain check.
             A var representing the contain check.
         """
         """
-        if self.type_ is None or not (
-            types._issubclass(self.type_, Union[dict, list, tuple, str])
-        ):
+        if not (types._issubclass(self.type_, Union[dict, list, tuple, str])):
             raise TypeError(
             raise TypeError(
                 f"Var {self.full_name} of type {self.type_} does not support contains check."
                 f"Var {self.full_name} of type {self.type_} does not support contains check."
             )
             )
@@ -987,7 +984,7 @@ class Var(ABC):
         Returns:
         Returns:
             A var with the reversed list.
             A var with the reversed list.
         """
         """
-        if self.type_ is None or not types._issubclass(self.type_, list):
+        if not types._issubclass(self.type_, list):
             raise TypeError(f"Cannot reverse non-list var {self.full_name}.")
             raise TypeError(f"Cannot reverse non-list var {self.full_name}.")
 
 
         return BaseVar(
         return BaseVar(
@@ -996,6 +993,97 @@ class Var(ABC):
             is_local=self.is_local,
             is_local=self.is_local,
         )
         )
 
 
+    def lower(self) -> Var:
+        """Convert a string var to lowercase.
+
+        Returns:
+            A var with the lowercase string.
+
+        Raises:
+            TypeError: If the var is not a string.
+        """
+        if not types._issubclass(self.type_, str):
+            raise TypeError(
+                f"Cannot convert non-string var {self.full_name} to lowercase."
+            )
+
+        return BaseVar(
+            name=f"{self.full_name}.toLowerCase()",
+            type_=str,
+            is_local=self.is_local,
+        )
+
+    def upper(self) -> Var:
+        """Convert a string var to uppercase.
+
+        Returns:
+            A var with the uppercase string.
+
+        Raises:
+            TypeError: If the var is not a string.
+        """
+        if not types._issubclass(self.type_, str):
+            raise TypeError(
+                f"Cannot convert non-string var {self.full_name} to uppercase."
+            )
+
+        return BaseVar(
+            name=f"{self.full_name}.toUpperCase()",
+            type_=str,
+            is_local=self.is_local,
+        )
+
+    def split(self, other: str | Var[str] = " ") -> Var:
+        """Split a string var into a list.
+
+        Args:
+            other: The string to split the var with.
+
+        Returns:
+            A var with the list.
+
+        Raises:
+            TypeError: If the var is not a string.
+        """
+        if not types._issubclass(self.type_, str):
+            raise TypeError(f"Cannot split non-string var {self.full_name}.")
+
+        other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
+
+        return BaseVar(
+            name=f"{self.full_name}.split({other.full_name})",
+            type_=list[str],
+            is_local=self.is_local,
+        )
+
+    def join(self, other: str | Var[str] | None = None) -> Var:
+        """Join a list var into a string.
+
+        Args:
+            other: The string to join the list with.
+
+        Returns:
+            A var with the string.
+
+        Raises:
+            TypeError: If the var is not a list.
+        """
+        if not types._issubclass(self.type_, list):
+            raise TypeError(f"Cannot join non-list var {self.full_name}.")
+
+        if other is None:
+            other = Var.create_safe("")
+        if isinstance(other, str):
+            other = Var.create_safe(json.dumps(other))
+        else:
+            other = Var.create_safe(other)
+
+        return BaseVar(
+            name=f"{self.full_name}.join({other.full_name})",
+            type_=str,
+            is_local=self.is_local,
+        )
+
     def foreach(self, fn: Callable) -> Var:
     def foreach(self, fn: Callable) -> Var:
         """Return a list of components. after doing a foreach on this var.
         """Return a list of components. after doing a foreach on this var.
 
 

+ 1 - 1
scripts/pyi_generator.py

@@ -233,7 +233,7 @@ class PyiGenerator:
         local_variables = [
         local_variables = [
             (name, obj)
             (name, obj)
             for name, obj in vars(self.current_module).items()
             for name, obj in vars(self.current_module).items()
-            if not name.startswith("__")
+            if not name.startswith("_")
             and not inspect.isclass(obj)
             and not inspect.isclass(obj)
             and not inspect.isfunction(obj)
             and not inspect.isfunction(obj)
         ]
         ]

+ 0 - 0
tests/components/typography/__init__.py


+ 59 - 0
tests/components/typography/test_markdown.py

@@ -0,0 +1,59 @@
+import pytest
+
+import reflex as rx
+from reflex.components.typography.markdown import Markdown
+
+
+@pytest.mark.parametrize(
+    "tag,expected",
+    [
+        ("h1", "Heading"),
+        ("h2", "Heading"),
+        ("h3", "Heading"),
+        ("h4", "Heading"),
+        ("h5", "Heading"),
+        ("h6", "Heading"),
+        ("p", "Text"),
+        ("ul", "UnorderedList"),
+        ("ol", "OrderedList"),
+        ("li", "ListItem"),
+        ("a", "Link"),
+        ("code", "Code"),
+    ],
+)
+def test_get_component(tag, expected):
+    """Test getting a component from the component map.
+
+    Args:
+        tag: The tag to get.
+        expected: The expected component.
+    """
+    md = Markdown.create("# Hello")
+    assert tag in md.component_map  # type: ignore
+    assert md.get_component(tag).tag == expected  # type: ignore
+
+
+def test_set_component_map():
+    """Test setting the component map."""
+    component_map = {
+        "h1": lambda value: rx.box(
+            rx.heading(value, as_="h1", size="2xl"), padding="1em"
+        ),
+        "p": lambda value: rx.box(rx.text(value), padding="1em"),
+    }
+    md = Markdown.create("# Hello", component_map=component_map)
+
+    # Check that the new tags have been added.
+    assert md.get_component("h1").tag == "Box"  # type: ignore
+    assert md.get_component("p").tag == "Box"  # type: ignore
+
+    # Make sure the old tags are still there.
+    assert md.get_component("h2").tag == "Heading"  # type: ignore
+
+
+def test_pass_custom_styles():
+    """Test that passing custom styles works."""
+    md = Markdown.create("# Hello", custom_styles={"h1": {"color": "red"}})
+
+    comp = md.get_component("h1")  # type: ignore
+    assert comp.style == {"color": "red"}

+ 46 - 14
tests/utils/test_serializers.py

@@ -1,9 +1,10 @@
 import datetime
 import datetime
-from typing import Any, Dict, Type
+from typing import Any, Dict, List, Type
 
 
 import pytest
 import pytest
 
 
 from reflex.utils import serializers
 from reflex.utils import serializers
+from reflex.vars import Var
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -29,12 +30,19 @@ def test_has_serializer(type_: Type, expected: bool):
     "type_,expected",
     "type_,expected",
     [
     [
         (str, serializers.serialize_str),
         (str, serializers.serialize_str),
+        (list, serializers.serialize_list),
+        (tuple, serializers.serialize_list),
+        (set, serializers.serialize_list),
         (dict, serializers.serialize_dict),
         (dict, serializers.serialize_dict),
+        (List[str], serializers.serialize_list),
         (Dict[int, int], serializers.serialize_dict),
         (Dict[int, int], serializers.serialize_dict),
         (datetime.datetime, serializers.serialize_datetime),
         (datetime.datetime, serializers.serialize_datetime),
         (datetime.date, serializers.serialize_datetime),
         (datetime.date, serializers.serialize_datetime),
         (datetime.time, serializers.serialize_datetime),
         (datetime.time, serializers.serialize_datetime),
         (datetime.timedelta, serializers.serialize_datetime),
         (datetime.timedelta, serializers.serialize_datetime),
+        (int, serializers.serialize_primitive),
+        (float, serializers.serialize_primitive),
+        (bool, serializers.serialize_primitive),
     ],
     ],
 )
 )
 def test_get_serializer(type_: Type, expected: serializers.Serializer):
 def test_get_serializer(type_: Type, expected: serializers.Serializer):
@@ -51,8 +59,14 @@ def test_get_serializer(type_: Type, expected: serializers.Serializer):
 def test_add_serializer():
 def test_add_serializer():
     """Test that adding a serializer works."""
     """Test that adding a serializer works."""
 
 
-    def serialize_test(value: int) -> str:
-        """Serialize an int to a string.
+    class Foo:
+        """A test class."""
+
+        def __init__(self, name: str):
+            self.name = name
+
+    def serialize_foo(value: Foo) -> str:
+        """Serialize an foo to a string.
 
 
         Args:
         Args:
             value: The value to serialize.
             value: The value to serialize.
@@ -60,35 +74,53 @@ def test_add_serializer():
         Returns:
         Returns:
             The serialized value.
             The serialized value.
         """
         """
-        return str(value)
+        return value.name
 
 
     # Initially there should be no serializer for int.
     # Initially there should be no serializer for int.
-    assert not serializers.has_serializer(int)
-    assert serializers.serialize(5) is None
+    assert not serializers.has_serializer(Foo)
+    assert serializers.serialize(Foo("hi")) is None
 
 
     # Register the serializer.
     # Register the serializer.
-    assert serializers.serializer(serialize_test) == serialize_test
+    assert serializers.serializer(serialize_foo) == serialize_foo
 
 
     # There should now be a serializer for int.
     # There should now be a serializer for int.
-    assert serializers.has_serializer(int)
-    assert serializers.get_serializer(int) == serialize_test
-    assert serializers.serialize(5) == "5"
+    assert serializers.has_serializer(Foo)
+    assert serializers.get_serializer(Foo) == serialize_foo
+    assert serializers.serialize(Foo("hi")) == "hi"
 
 
     # Remove the serializer.
     # Remove the serializer.
-    serializers.SERIALIZERS.pop(int)
+    serializers.SERIALIZERS.pop(Foo)
+    assert not serializers.has_serializer(Foo)
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "value,expected",
     "value,expected",
     [
     [
         ("test", "test"),
         ("test", "test"),
+        (1, "1"),
+        (1.0, "1.0"),
+        (True, "true"),
+        (False, "false"),
+        (None, "null"),
+        ([1, 2, 3], "[1, 2, 3]"),
+        ([1, "2", 3.0], '[1, "2", 3.0]'),
+        (
+            [1, Var.create_safe("hi"), Var.create_safe("bye", is_local=False)],
+            '[1, "hi", bye]',
+        ),
+        (
+            (1, Var.create_safe("hi"), Var.create_safe("bye", is_local=False)),
+            '[1, "hi", bye]',
+        ),
+        ({1: 2, 3: 4}, '{"1": 2, "3": 4}'),
+        (
+            {1: Var.create_safe("hi"), 3: Var.create_safe("bye", is_local=False)},
+            '{"1": "hi", "3": bye}',
+        ),
         (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001"),
         (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001"),
         (datetime.date(2021, 1, 1), "2021-01-01"),
         (datetime.date(2021, 1, 1), "2021-01-01"),
         (datetime.time(1, 1, 1, 1), "01:01:01.000001"),
         (datetime.time(1, 1, 1, 1), "01:01:01.000001"),
         (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
         (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
-        (5, None),
-        (None, None),
-        ([], None),
     ],
     ],
 )
 )
 def test_serialize(value: Any, expected: str):
 def test_serialize(value: Any, expected: str):