Ver código fonte

Add decimal.Decimal support to serializers and NumberVar (#5226)

* Add serializer for decimal.Decimal type that converts to float

Co-Authored-By: Alek Petuskey <alek@pynecone.io>

* Add tests for decimal.Decimal serializer and NumberVar support

Co-Authored-By: Alek Petuskey <alek@pynecone.io>

* Update NumberVar and related components to support decimal.Decimal

Co-Authored-By: Alek Petuskey <alek@pynecone.io>

* Simplify test_all_number_operations to fix type compatibility with decimal.Decimal

Co-Authored-By: Alek Petuskey <alek@pynecone.io>

* Fix decimal serialization to properly quote string values

Co-Authored-By: Alek Petuskey <alek@pynecone.io>

* Fix decimal serialization functions

Co-Authored-By: Alek Petuskey <alek@pynecone.io>

* Revert "Simplify test_all_number_operations to fix type compatibility with decimal.Decimal"

This reverts commit 758d55f00053c0401cc37bbd70734ffe252502ca.

* revert bad test change

* add overload for Decimal in Var.create

move test_decimal_var to test_var and tweak the expectations

override return type for NumberVar.__neg__

* revert changes in `float_input_event`

needed to add another `.to` overload for proper type checking

* update test_serializers expectation

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Alek Petuskey <alek@pynecone.io>
Co-authored-by: Masen Furer <m_github@0x26.net>
devin-ai-integration[bot] 2 semanas atrás
pai
commit
6eec8e36ae

+ 14 - 0
reflex/utils/serializers.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import contextlib
 import dataclasses
+import decimal
 import functools
 import inspect
 import json
@@ -386,6 +387,19 @@ def serialize_uuid(uuid: UUID) -> str:
     return str(uuid)
 
 
+@serializer(to=float)
+def serialize_decimal(value: decimal.Decimal) -> float:
+    """Serialize a Decimal to a float.
+
+    Args:
+        value: The Decimal to serialize.
+
+    Returns:
+        The serialized Decimal as a float.
+    """
+    return float(value)
+
+
 @serializer(to=str)
 def serialize_color(color: Color) -> str:
     """Serialize a color.

+ 13 - 1
reflex/vars/base.py

@@ -15,6 +15,7 @@ import string
 import uuid
 import warnings
 from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
+from decimal import Decimal
 from types import CodeType, FunctionType
 from typing import (  # noqa: UP035
     TYPE_CHECKING,
@@ -630,6 +631,14 @@ class Var(Generic[VAR_TYPE]):
         _var_data: VarData | None = None,
     ) -> LiteralNumberVar[float]: ...
 
+    @overload
+    @classmethod
+    def create(
+        cls,
+        value: Decimal,
+        _var_data: VarData | None = None,
+    ) -> LiteralNumberVar[Decimal]: ...
+
     @overload
     @classmethod
     def create(  # pyright: ignore [reportOverlappingOverload]
@@ -743,7 +752,10 @@ class Var(Generic[VAR_TYPE]):
     def to(self, output: type[int]) -> NumberVar[int]: ...
 
     @overload
-    def to(self, output: type[int] | type[float]) -> NumberVar: ...
+    def to(self, output: type[float]) -> NumberVar[float]: ...
+
+    @overload
+    def to(self, output: type[Decimal]) -> NumberVar[Decimal]: ...
 
     @overload
     def to(

+ 16 - 8
reflex/vars/number.py

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import dataclasses
+import decimal
 import json
 import math
 from collections.abc import Callable
@@ -30,7 +31,10 @@ from .base import (
 )
 
 NUMBER_T = TypeVarExt(
-    "NUMBER_T", bound=(int | float), default=(int | float), covariant=True
+    "NUMBER_T",
+    bound=(int | float | decimal.Decimal),
+    default=(int | float | decimal.Decimal),
+    covariant=True,
 )
 
 if TYPE_CHECKING:
@@ -54,7 +58,7 @@ def raise_unsupported_operand_types(
     )
 
 
-class NumberVar(Var[NUMBER_T], python_types=(int, float)):
+class NumberVar(Var[NUMBER_T], python_types=(int, float, decimal.Decimal)):
     """Base class for immutable number vars."""
 
     def __add__(self, other: number_types) -> NumberVar:
@@ -285,13 +289,13 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
 
         return number_exponent_operation(+other, self)
 
-    def __neg__(self):
+    def __neg__(self) -> NumberVar:
         """Negate the number.
 
         Returns:
             The number negation operation.
         """
-        return number_negate_operation(self)
+        return number_negate_operation(self)  # pyright: ignore [reportReturnType]
 
     def __invert__(self):
         """Boolean NOT the number.
@@ -943,7 +947,7 @@ def boolean_not_operation(value: BooleanVar):
 class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
     """Base class for immutable literal number vars."""
 
-    _var_value: float | int = dataclasses.field(default=0)
+    _var_value: float | int | decimal.Decimal = dataclasses.field(default=0)
 
     def json(self) -> str:
         """Get the JSON representation of the var.
@@ -954,6 +958,8 @@ class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
         Raises:
             PrimitiveUnserializableToJSONError: If the var is unserializable to JSON.
         """
+        if isinstance(self._var_value, decimal.Decimal):
+            return json.dumps(float(self._var_value))
         if math.isinf(self._var_value) or math.isnan(self._var_value):
             raise PrimitiveUnserializableToJSONError(
                 f"No valid JSON representation for {self}"
@@ -969,7 +975,9 @@ class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
         return hash((type(self).__name__, self._var_value))
 
     @classmethod
-    def create(cls, value: float | int, _var_data: VarData | None = None):
+    def create(
+        cls, value: float | int | decimal.Decimal, _var_data: VarData | None = None
+    ):
         """Create the number var.
 
         Args:
@@ -1039,7 +1047,7 @@ class LiteralBooleanVar(LiteralVar, BooleanVar):
         )
 
 
-number_types = NumberVar | int | float
+number_types = NumberVar | int | float | decimal.Decimal
 boolean_types = BooleanVar | bool
 
 
@@ -1112,4 +1120,4 @@ def ternary_operation(
     return value
 
 
-NUMBER_TYPES = (int, float, NumberVar)
+NUMBER_TYPES = (int, float, decimal.Decimal, NumberVar)

+ 2 - 1
reflex/vars/sequence.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import collections.abc
 import dataclasses
+import decimal
 import inspect
 import json
 import re
@@ -1558,7 +1559,7 @@ def is_tuple_type(t: GenericType) -> bool:
 
 
 def _determine_value_of_array_index(
-    var_type: GenericType, index: int | float | None = None
+    var_type: GenericType, index: int | float | decimal.Decimal | None = None
 ):
     """Determine the value of an array index.
 

+ 41 - 0
tests/units/test_var.py

@@ -1,3 +1,4 @@
+import decimal
 import json
 import math
 import typing
@@ -1920,3 +1921,43 @@ def test_str_var_in_components(mocker):
     rx.vstack(
         str(StateWithVar.field),
     )
+
+
+def test_decimal_number_operations():
+    """Test that decimal.Decimal values work with NumberVar operations."""
+    dec_num = Var.create(decimal.Decimal("123.456"))
+    assert isinstance(dec_num._var_value, decimal.Decimal)
+    assert str(dec_num) == "123.456"
+
+    result = dec_num + 10
+    assert str(result) == "(123.456 + 10)"
+
+    result = dec_num * 2
+    assert str(result) == "(123.456 * 2)"
+
+    result = dec_num / 2
+    assert str(result) == "(123.456 / 2)"
+
+    result = dec_num > 100
+    assert str(result) == "(123.456 > 100)"
+
+    result = dec_num < 200
+    assert str(result) == "(123.456 < 200)"
+
+    assert dec_num.json() == "123.456"
+
+
+def test_decimal_var_type_compatibility():
+    """Test that decimal.Decimal values are compatible with NumberVar type system."""
+    dec_num = Var.create(decimal.Decimal("123.456"))
+    int_num = Var.create(42)
+    float_num = Var.create(3.14)
+
+    result = dec_num + int_num
+    assert str(result) == "(123.456 + 42)"
+
+    result = dec_num * float_num
+    assert str(result) == "(123.456 * 3.14)"
+
+    result = (dec_num + int_num) / float_num
+    assert str(result) == "((123.456 + 42) / 3.14)"

+ 6 - 0
tests/units/utils/test_serializers.py

@@ -1,4 +1,5 @@
 import datetime
+import decimal
 import json
 from enum import Enum
 from pathlib import Path
@@ -188,6 +189,9 @@ class BaseSubclass(Base):
         (Color(color="slate", shade=1), "var(--slate-1)"),
         (Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),
         (Color(color="accent", shade=1, alpha=True), "var(--accent-a1)"),
+        (decimal.Decimal("123.456"), 123.456),
+        (decimal.Decimal("-0.5"), -0.5),
+        (decimal.Decimal("0"), 0.0),
     ],
 )
 def test_serialize(value: Any, expected: str):
@@ -226,6 +230,8 @@ def test_serialize(value: Any, expected: str):
         (Color(color="slate", shade=1), '"var(--slate-1)"', True),
         (BaseSubclass, '"BaseSubclass"', True),
         (Path(), '"."', True),
+        (decimal.Decimal("123.456"), "123.456", True),
+        (decimal.Decimal("-0.5"), "-0.5", True),
     ],
 )
 def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):