|
@@ -214,7 +214,9 @@ def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
|
|
|
return res
|
|
|
|
|
|
|
|
|
-def _generate_imports(typing_imports: Iterable[str]) -> list[ast.ImportFrom]:
|
|
|
+def _generate_imports(
|
|
|
+ typing_imports: Iterable[str],
|
|
|
+) -> list[ast.ImportFrom | ast.Import]:
|
|
|
"""Generate the import statements for the stub file.
|
|
|
|
|
|
Args:
|
|
@@ -228,6 +230,7 @@ def _generate_imports(typing_imports: Iterable[str]) -> list[ast.ImportFrom]:
|
|
|
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values])
|
|
|
for name, values in DEFAULT_IMPORTS.items()
|
|
|
],
|
|
|
+ ast.Import([ast.alias("reflex")]),
|
|
|
]
|
|
|
|
|
|
|
|
@@ -372,12 +375,13 @@ def _extract_class_props_as_ast_nodes(
|
|
|
return kwargs
|
|
|
|
|
|
|
|
|
-def type_to_ast(typ) -> ast.AST:
|
|
|
+def type_to_ast(typ, cls: type) -> ast.AST:
|
|
|
"""Converts any type annotation into its AST representation.
|
|
|
Handles nested generic types, unions, etc.
|
|
|
|
|
|
Args:
|
|
|
typ: The type annotation to convert.
|
|
|
+ cls: The class where the type annotation is used.
|
|
|
|
|
|
Returns:
|
|
|
The AST representation of the type annotation.
|
|
@@ -390,6 +394,16 @@ def type_to_ast(typ) -> ast.AST:
|
|
|
# Handle plain types (int, str, custom classes, etc.)
|
|
|
if origin is None:
|
|
|
if hasattr(typ, "__name__"):
|
|
|
+ if typ.__module__.startswith("reflex."):
|
|
|
+ typ_parts = typ.__module__.split(".")
|
|
|
+ cls_parts = cls.__module__.split(".")
|
|
|
+
|
|
|
+ zipped = list(zip(typ_parts, cls_parts, strict=False))
|
|
|
+
|
|
|
+ if all(a == b for a, b in zipped) and len(typ_parts) == len(cls_parts):
|
|
|
+ return ast.Name(id=typ.__name__)
|
|
|
+
|
|
|
+ return ast.Name(id=typ.__module__ + "." + typ.__name__)
|
|
|
return ast.Name(id=typ.__name__)
|
|
|
elif hasattr(typ, "_name"):
|
|
|
return ast.Name(id=typ._name)
|
|
@@ -406,7 +420,7 @@ def type_to_ast(typ) -> ast.AST:
|
|
|
return ast.Name(id=base_name)
|
|
|
|
|
|
# Convert all type arguments recursively
|
|
|
- arg_nodes = [type_to_ast(arg) for arg in args]
|
|
|
+ arg_nodes = [type_to_ast(arg, cls) for arg in args]
|
|
|
|
|
|
# Special case for single-argument types (like List[T] or Optional[T])
|
|
|
if len(arg_nodes) == 1:
|
|
@@ -487,7 +501,7 @@ def _generate_component_create_functiondef(
|
|
|
]
|
|
|
|
|
|
# Convert each argument type to its AST representation
|
|
|
- type_args = [type_to_ast(arg) for arg in arguments_without_var]
|
|
|
+ type_args = [type_to_ast(arg, cls=clz) for arg in arguments_without_var]
|
|
|
|
|
|
# Join the type arguments with commas for EventType
|
|
|
args_str = ", ".join(ast.unparse(arg) for arg in type_args)
|