1
0

serializers.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. """Serializers used to convert Var types to JSON strings."""
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import decimal
  6. import functools
  7. import inspect
  8. import json
  9. import warnings
  10. from collections.abc import Callable, Sequence
  11. from datetime import date, datetime, time, timedelta
  12. from enum import Enum
  13. from pathlib import Path
  14. from typing import Any, Literal, TypeVar, get_type_hints, overload
  15. from uuid import UUID
  16. from pydantic import BaseModel as BaseModelV2
  17. from pydantic.v1 import BaseModel as BaseModelV1
  18. from reflex.base import Base
  19. from reflex.constants.colors import Color, format_color
  20. from reflex.utils import console, types
# Any value a serializer may return: JSON-compatible Python primitives and containers.
SerializedType = str | bool | int | float | list | dict | None
# A serializer takes a single value and returns its JSON-compatible representation.
Serializer = Callable[[Any], SerializedType]
# Mapping from a registered type to the function that serializes instances of it.
# The serializer should convert the value to a JSON-compatible object.
SERIALIZERS: dict[type, Serializer] = {}
# Mapping from a registered type to the type its serializer produces
# (taken from `to=` or the serializer's return annotation in `serializer`).
SERIALIZER_TYPES: dict[type, type] = {}
# TypeVar bound to Serializer so the `serializer` decorator returns the decorated function's own type.
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
  28. @overload
  29. def serializer(
  30. fn: None = None,
  31. to: type[SerializedType] | None = None,
  32. overwrite: bool | None = None,
  33. ) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
  34. @overload
  35. def serializer(
  36. fn: SERIALIZED_FUNCTION,
  37. to: type[SerializedType] | None = None,
  38. overwrite: bool | None = None,
  39. ) -> SERIALIZED_FUNCTION: ...
  40. def serializer(
  41. fn: SERIALIZED_FUNCTION | None = None,
  42. to: Any = None,
  43. overwrite: bool | None = None,
  44. ) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
  45. """Decorator to add a serializer for a given type.
  46. Args:
  47. fn: The function to decorate.
  48. to: The type returned by the serializer. If this is `str`, then any Var created from this type will be treated as a string.
  49. overwrite: Whether to overwrite the existing serializer.
  50. Returns:
  51. The decorated function.
  52. """
  53. def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
  54. # Check the type hints to get the type of the argument.
  55. type_hints = get_type_hints(fn)
  56. args = [arg for arg in type_hints if arg != "return"]
  57. # Make sure the function takes a single argument.
  58. if len(args) != 1:
  59. raise ValueError("Serializer must take a single argument.")
  60. # Get the type of the argument.
  61. type_ = type_hints[args[0]]
  62. # Make sure the type is not already registered.
  63. registered_fn = SERIALIZERS.get(type_)
  64. if registered_fn is not None and registered_fn != fn and overwrite is not True:
  65. message = f"Overwriting serializer for type {type_} from {registered_fn.__module__}:{registered_fn.__qualname__} to {fn.__module__}:{fn.__qualname__}."
  66. if overwrite is False:
  67. raise ValueError(message)
  68. caller_frame = next(
  69. filter(
  70. lambda frame: frame.filename != __file__,
  71. inspect.getouterframes(inspect.currentframe()),
  72. ),
  73. None,
  74. )
  75. file_info = (
  76. f"(at {caller_frame.filename}:{caller_frame.lineno})"
  77. if caller_frame
  78. else ""
  79. )
  80. console.warn(
  81. f"{message} Call rx.serializer with `overwrite=True` if this is intentional. {file_info}"
  82. )
  83. to_type = to or type_hints.get("return")
  84. # Apply type transformation if requested
  85. if to_type:
  86. SERIALIZER_TYPES[type_] = to_type
  87. get_serializer_type.cache_clear()
  88. # Register the serializer.
  89. SERIALIZERS[type_] = fn
  90. get_serializer.cache_clear()
  91. # Return the function.
  92. return fn
  93. if fn is not None:
  94. return wrapper(fn)
  95. return wrapper
  96. @overload
  97. def serialize(
  98. value: Any, get_type: Literal[True]
  99. ) -> tuple[SerializedType | None, types.GenericType | None]: ...
  100. @overload
  101. def serialize(value: Any, get_type: Literal[False]) -> SerializedType | None: ...
  102. @overload
  103. def serialize(value: Any) -> SerializedType | None: ...
  104. def serialize(
  105. value: Any, get_type: bool = False
  106. ) -> SerializedType | None | tuple[SerializedType | None, types.GenericType | None]:
  107. """Serialize the value to a JSON string.
  108. Args:
  109. value: The value to serialize.
  110. get_type: Whether to return the type of the serialized value.
  111. Returns:
  112. The serialized value, or None if a serializer is not found.
  113. """
  114. # Get the serializer for the type.
  115. serializer = get_serializer(type(value))
  116. # If there is no serializer, return None.
  117. if serializer is None:
  118. if dataclasses.is_dataclass(value) and not isinstance(value, type):
  119. return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)}
  120. if get_type:
  121. return None, None
  122. return None
  123. # Serialize the value.
  124. serialized = serializer(value)
  125. # Return the serialized value and the type.
  126. if get_type:
  127. return serialized, get_serializer_type(type(value))
  128. else:
  129. return serialized
  130. @functools.lru_cache
  131. def get_serializer(type_: type) -> Serializer | None:
  132. """Get the serializer for the type.
  133. Args:
  134. type_: The type to get the serializer for.
  135. Returns:
  136. The serializer for the type, or None if there is no serializer.
  137. """
  138. # First, check if the type is registered.
  139. serializer = SERIALIZERS.get(type_)
  140. if serializer is not None:
  141. return serializer
  142. # If the type is not registered, check if it is a subclass of a registered type.
  143. for registered_type, serializer in reversed(SERIALIZERS.items()):
  144. if types._issubclass(type_, registered_type):
  145. return serializer
  146. # If there is no serializer, return None.
  147. return None
  148. @functools.lru_cache
  149. def get_serializer_type(type_: type) -> type | None:
  150. """Get the converted type for the type after serializing.
  151. Args:
  152. type_: The type to get the serializer type for.
  153. Returns:
  154. The serialized type for the type, or None if there is no type conversion registered.
  155. """
  156. # First, check if the type is registered.
  157. serializer = SERIALIZER_TYPES.get(type_)
  158. if serializer is not None:
  159. return serializer
  160. # If the type is not registered, check if it is a subclass of a registered type.
  161. for registered_type, serializer in reversed(SERIALIZER_TYPES.items()):
  162. if types._issubclass(type_, registered_type):
  163. return serializer
  164. # If there is no serializer, return None.
  165. return None
  166. def has_serializer(type_: type, into_type: type | None = None) -> bool:
  167. """Check if there is a serializer for the type.
  168. Args:
  169. type_: The type to check.
  170. into_type: The type to serialize into.
  171. Returns:
  172. Whether there is a serializer for the type.
  173. """
  174. serializer_for_type = get_serializer(type_)
  175. return serializer_for_type is not None and (
  176. into_type is None or get_serializer_type(type_) == into_type
  177. )
  178. def can_serialize(type_: type, into_type: type | None = None) -> bool:
  179. """Check if there is a serializer for the type.
  180. Args:
  181. type_: The type to check.
  182. into_type: The type to serialize into.
  183. Returns:
  184. Whether there is a serializer for the type.
  185. """
  186. return has_serializer(type_, into_type) or (
  187. isinstance(type_, type)
  188. and dataclasses.is_dataclass(type_)
  189. and (into_type is None or into_type is dict)
  190. )
  191. @serializer(to=str)
  192. def serialize_type(value: type) -> str:
  193. """Serialize a python type.
  194. Args:
  195. value: the type to serialize.
  196. Returns:
  197. The serialized type.
  198. """
  199. return value.__name__
  200. @serializer(to=dict)
  201. def serialize_base(value: Base) -> dict:
  202. """Serialize a Base instance.
  203. Args:
  204. value : The Base to serialize.
  205. Returns:
  206. The serialized Base.
  207. """
  208. from reflex.vars.base import Var
  209. return {
  210. k: v for k, v in value.dict().items() if isinstance(v, Var) or not callable(v)
  211. }
  212. @serializer(to=dict)
  213. def serialize_base_model_v1(model: BaseModelV1) -> dict:
  214. """Serialize a pydantic v1 BaseModel instance.
  215. Args:
  216. model: The BaseModel to serialize.
  217. Returns:
  218. The serialized BaseModel.
  219. """
  220. return model.dict()
  221. if BaseModelV1 is not BaseModelV2:
  222. @serializer(to=dict)
  223. def serialize_base_model_v2(model: BaseModelV2) -> dict:
  224. """Serialize a pydantic v2 BaseModel instance.
  225. Args:
  226. model: The BaseModel to serialize.
  227. Returns:
  228. The serialized BaseModel.
  229. """
  230. return model.model_dump()
  231. @serializer
  232. def serialize_set(value: set) -> list:
  233. """Serialize a set to a JSON serializable list.
  234. Args:
  235. value: The set to serialize.
  236. Returns:
  237. The serialized list.
  238. """
  239. return list(value)
  240. @serializer
  241. def serialize_sequence(value: Sequence) -> list:
  242. """Serialize a sequence to a JSON serializable list.
  243. Args:
  244. value: The sequence to serialize.
  245. Returns:
  246. The serialized list.
  247. """
  248. return list(value)
  249. @serializer(to=str)
  250. def serialize_datetime(dt: date | datetime | time | timedelta) -> str:
  251. """Serialize a datetime to a JSON string.
  252. Args:
  253. dt: The datetime to serialize.
  254. Returns:
  255. The serialized datetime.
  256. """
  257. return str(dt)
  258. @serializer(to=str)
  259. def serialize_path(path: Path) -> str:
  260. """Serialize a pathlib.Path to a JSON string.
  261. Args:
  262. path: The path to serialize.
  263. Returns:
  264. The serialized path.
  265. """
  266. return str(path.as_posix())
  267. @serializer
  268. def serialize_enum(en: Enum) -> str:
  269. """Serialize a enum to a JSON string.
  270. Args:
  271. en: The enum to serialize.
  272. Returns:
  273. The serialized enum.
  274. """
  275. return en.value
  276. @serializer(to=str)
  277. def serialize_uuid(uuid: UUID) -> str:
  278. """Serialize a UUID to a JSON string.
  279. Args:
  280. uuid: The UUID to serialize.
  281. Returns:
  282. The serialized UUID.
  283. """
  284. return str(uuid)
  285. @serializer(to=float)
  286. def serialize_decimal(value: decimal.Decimal) -> float:
  287. """Serialize a Decimal to a float.
  288. Args:
  289. value: The Decimal to serialize.
  290. Returns:
  291. The serialized Decimal as a float.
  292. """
  293. return float(value)
  294. @serializer(to=str)
  295. def serialize_color(color: Color) -> str:
  296. """Serialize a color.
  297. Args:
  298. color: The color to serialize.
  299. Returns:
  300. The serialized color.
  301. """
  302. return format_color(color.color, color.shade, color.alpha)
  303. with contextlib.suppress(ImportError):
  304. from pandas import DataFrame
  305. def format_dataframe_values(df: DataFrame) -> list[list[Any]]:
  306. """Format dataframe values to a list of lists.
  307. Args:
  308. df: The dataframe to format.
  309. Returns:
  310. The dataframe as a list of lists.
  311. """
  312. return [
  313. [str(d) if isinstance(d, (list, tuple)) else d for d in data]
  314. for data in list(df.values.tolist())
  315. ]
  316. @serializer
  317. def serialize_dataframe(df: DataFrame) -> dict:
  318. """Serialize a pandas dataframe.
  319. Args:
  320. df: The dataframe to serialize.
  321. Returns:
  322. The serialized dataframe.
  323. """
  324. return {
  325. "columns": df.columns.tolist(),
  326. "data": format_dataframe_values(df),
  327. }
  328. with contextlib.suppress(ImportError):
  329. from plotly.graph_objects import Figure, layout
  330. from plotly.io import to_json
  331. @serializer
  332. def serialize_figure(figure: Figure) -> dict:
  333. """Serialize a plotly figure.
  334. Args:
  335. figure: The figure to serialize.
  336. Returns:
  337. The serialized figure.
  338. """
  339. return json.loads(str(to_json(figure)))
  340. @serializer
  341. def serialize_template(template: layout.Template) -> dict:
  342. """Serialize a plotly template.
  343. Args:
  344. template: The template to serialize.
  345. Returns:
  346. The serialized template.
  347. """
  348. return {
  349. "data": json.loads(str(to_json(template.data))),
  350. "layout": json.loads(str(to_json(template.layout))),
  351. }
  352. with contextlib.suppress(ImportError):
  353. import base64
  354. import io
  355. from PIL.Image import MIME
  356. from PIL.Image import Image as Img
  357. @serializer
  358. def serialize_image(image: Img) -> str:
  359. """Serialize a plotly figure.
  360. Args:
  361. image: The image to serialize.
  362. Returns:
  363. The serialized image.
  364. """
  365. buff = io.BytesIO()
  366. image_format = getattr(image, "format", None) or "PNG"
  367. image.save(buff, format=image_format)
  368. image_bytes = buff.getvalue()
  369. base64_image = base64.b64encode(image_bytes).decode("utf-8")
  370. try:
  371. # Newer method to get the mime type, but does not always work.
  372. mime_type = image.get_format_mimetype() # pyright: ignore [reportAttributeAccessIssue]
  373. except AttributeError:
  374. try:
  375. # Fallback method
  376. mime_type = MIME[image_format]
  377. except KeyError:
  378. # Unknown mime_type: warn and return image/png and hope the browser can sort it out.
  379. warnings.warn( # noqa: B028
  380. f"Unknown mime type for {image} {image_format}. Defaulting to image/png"
  381. )
  382. mime_type = "image/png"
  383. return f"data:{mime_type};base64,{base64_image}"