Browse Source

validate classname (#5204)

* validate classname

* add subclass check

* has_var = bool???

what did i do

Co-authored-by: Masen Furer <m_github@0x26.net>

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Khaleel Al-Adhami 4 tuần trước cách đây
mục cha
commit
70ab9dc0dc
1 tập tin đã thay đổi với 35 bổ sung9 xóa
  1. 35 9
      reflex/components/component.py

+ 35 - 9
reflex/components/component.py

@@ -50,14 +50,15 @@ from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_
 from reflex.vars import VarData
 from reflex.vars.base import (
     CachedVarOperation,
+    LiteralNoneVar,
     LiteralVar,
     Var,
     cached_property_no_lock,
 )
 from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
 from reflex.vars.number import ternary_operation
-from reflex.vars.object import ObjectVar
-from reflex.vars.sequence import LiteralArrayVar
+from reflex.vars.object import LiteralObjectVar, ObjectVar
+from reflex.vars.sequence import LiteralArrayVar, LiteralStringVar, StringVar
 
 
 class BaseComponent(Base, ABC):
@@ -598,13 +599,36 @@ class Component(BaseComponent, ABC):
         # Convert class_name to str if it's list
         class_name = kwargs.get("class_name", "")
         if isinstance(class_name, (list, tuple)):
-            if any(isinstance(c, Var) for c in class_name):
+            has_var = False
+            for c in class_name:
+                if isinstance(c, str):
+                    continue
+                if isinstance(c, Var):
+                    if not isinstance(c, StringVar) and not issubclass(
+                        c._var_type, str
+                    ):
+                        raise TypeError(
+                            f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c._js_expr} of type {c._var_type}."
+                        )
+                    has_var = True
+                else:
+                    raise TypeError(
+                        f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {c} of type {type(c)}."
+                    )
+            if has_var:
                 kwargs["class_name"] = LiteralArrayVar.create(
                     class_name, _var_type=list[str]
                 ).join(" ")
             else:
                 kwargs["class_name"] = " ".join(class_name)
-
+        elif (
+            isinstance(class_name, Var)
+            and not isinstance(class_name, StringVar)
+            and not issubclass(class_name._var_type, str)
+        ):
+            raise TypeError(
+                f"Invalid class_name passed for prop {type(self).__name__}.class_name, expected type str, got value {class_name._js_expr} of type {class_name._var_type}."
+            )
         # Construct the component.
         for key, value in kwargs.items():
             setattr(self, key, value)
@@ -1146,7 +1170,7 @@ class Component(BaseComponent, ABC):
                 vars.append(comp_prop)
             elif isinstance(comp_prop, str):
                 # Collapse VarData encoded in f-strings.
-                var = LiteralVar.create(comp_prop)
+                var = LiteralStringVar.create(comp_prop)
                 if var._get_all_var_data() is not None:
                     vars.append(var)
 
@@ -2494,7 +2518,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
         return Var.create(tag)
 
     if "iterable" in tag:
-        function_return = Var.create(
+        function_return = LiteralArrayVar.create(
             [
                 render_dict_to_var(child.render(), imported_names)
                 for child in tag["children"]
@@ -2537,7 +2561,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
             render_dict_to_var(tag["true_value"], imported_names),
             render_dict_to_var(tag["false_value"], imported_names)
             if tag["false_value"] is not None
-            else Var.create(None),
+            else LiteralNoneVar.create(),
         )
 
     props = {}
@@ -2553,7 +2577,9 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
         value = prop_str[prop + 2 : -1]
         props[key] = value
 
-    props = Var.create({Var.create(k): Var(v) for k, v in props.items()})
+    props = LiteralObjectVar.create(
+        {LiteralStringVar.create(k): Var(v) for k, v in props.items()}
+    )
 
     for prop in special_props:
         props = props.merge(prop)
@@ -2564,7 +2590,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
     tag_name = Var(raw_tag_name or "Fragment")
 
     tag_name = (
-        Var.create(raw_tag_name)
+        LiteralStringVar.create(raw_tag_name)
         if raw_tag_name
         and raw_tag_name.split(".")[0] not in imported_names
         and raw_tag_name.lower() == raw_tag_name