test_serializers.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import datetime
  2. from typing import Any, Dict, List, Type
  3. import pytest
  4. from reflex.utils import serializers
  5. from reflex.vars import Var
  6. @pytest.mark.parametrize(
  7. "type_,expected",
  8. [
  9. (str, True),
  10. (dict, True),
  11. (Dict[int, int], True),
  12. ],
  13. )
  14. def test_has_serializer(type_: Type, expected: bool):
  15. """Test that has_serializer returns the correct value.
  16. Args:
  17. type_: The type to check.
  18. expected: The expected result.
  19. """
  20. assert serializers.has_serializer(type_) == expected
  21. @pytest.mark.parametrize(
  22. "type_,expected",
  23. [
  24. (str, serializers.serialize_str),
  25. (list, serializers.serialize_list),
  26. (tuple, serializers.serialize_list),
  27. (set, serializers.serialize_list),
  28. (dict, serializers.serialize_dict),
  29. (List[str], serializers.serialize_list),
  30. (Dict[int, int], serializers.serialize_dict),
  31. (datetime.datetime, serializers.serialize_datetime),
  32. (datetime.date, serializers.serialize_datetime),
  33. (datetime.time, serializers.serialize_datetime),
  34. (datetime.timedelta, serializers.serialize_datetime),
  35. (int, serializers.serialize_primitive),
  36. (float, serializers.serialize_primitive),
  37. (bool, serializers.serialize_primitive),
  38. ],
  39. )
  40. def test_get_serializer(type_: Type, expected: serializers.Serializer):
  41. """Test that get_serializer returns the correct value.
  42. Args:
  43. type_: The type to check.
  44. expected: The expected result.
  45. """
  46. assert serializers.get_serializer(type_) == expected
  47. def test_add_serializer():
  48. """Test that adding a serializer works."""
  49. class Foo:
  50. """A test class."""
  51. def __init__(self, name: str):
  52. self.name = name
  53. def serialize_foo(value: Foo) -> str:
  54. """Serialize an foo to a string.
  55. Args:
  56. value: The value to serialize.
  57. Returns:
  58. The serialized value.
  59. """
  60. return value.name
  61. # Initially there should be no serializer for int.
  62. assert not serializers.has_serializer(Foo)
  63. assert serializers.serialize(Foo("hi")) is None
  64. # Register the serializer.
  65. assert serializers.serializer(serialize_foo) == serialize_foo
  66. # There should now be a serializer for int.
  67. assert serializers.has_serializer(Foo)
  68. assert serializers.get_serializer(Foo) == serialize_foo
  69. assert serializers.serialize(Foo("hi")) == "hi"
  70. # Remove the serializer.
  71. serializers.SERIALIZERS.pop(Foo)
  72. assert not serializers.has_serializer(Foo)
  73. @pytest.mark.parametrize(
  74. "value,expected",
  75. [
  76. ("test", "test"),
  77. (1, "1"),
  78. (1.0, "1.0"),
  79. (True, "true"),
  80. (False, "false"),
  81. (None, "null"),
  82. ([1, 2, 3], "[1, 2, 3]"),
  83. ([1, "2", 3.0], '[1, "2", 3.0]'),
  84. (
  85. [1, Var.create_safe("hi"), Var.create_safe("bye", is_local=False)],
  86. '[1, "hi", bye]',
  87. ),
  88. (
  89. (1, Var.create_safe("hi"), Var.create_safe("bye", is_local=False)),
  90. '[1, "hi", bye]',
  91. ),
  92. ({1: 2, 3: 4}, '{"1": 2, "3": 4}'),
  93. (
  94. {1: Var.create_safe("hi"), 3: Var.create_safe("bye", is_local=False)},
  95. '{"1": "hi", "3": bye}',
  96. ),
  97. (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001"),
  98. (datetime.date(2021, 1, 1), "2021-01-01"),
  99. (datetime.time(1, 1, 1, 1), "01:01:01.000001"),
  100. (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
  101. ],
  102. )
  103. def test_serialize(value: Any, expected: str):
  104. """Test that serialize returns the correct value.
  105. Args:
  106. value: The value to serialize.
  107. expected: The expected result.
  108. """
  109. assert serializers.serialize(value) == expected