Browse Source

add ndigits to round dunder method (#5019)

* add ndigits to round dunder method

* add integration test
Khaleel Al-Adhami 1 month ago
parent
commit
e182afebc0
3 changed files with 46 additions and 13 deletions
  1. 8 3
      reflex/vars/base.py
  2. 22 5
      reflex/vars/number.py
  3. 16 5
      tests/integration/test_var_operations.py

+ 8 - 3
reflex/vars/base.py

@@ -3317,10 +3317,15 @@ class Field(Generic[FIELD_TYPE]):
 
     @overload
     def __get__(
-        self: Field[int]
-        | Field[float]
+        self: Field[int] | Field[int | None],
+        instance: None,
+        owner: Any,
+    ) -> NumberVar[int]: ...
+
+    @overload
+    def __get__(
+        self: Field[float]
         | Field[int | float]
-        | Field[int | None]
         | Field[float | None]
         | Field[int | float | None],
         instance: None,

+ 22 - 5
reflex/vars/number.py

@@ -16,6 +16,8 @@ from typing import (
     overload,
 )
 
+from typing_extensions import TypeVar as TypeVarExt
+
 from reflex.constants.base import Dirs
 from reflex.utils.exceptions import (
     PrimitiveUnserializableToJSONError,
@@ -35,7 +37,9 @@ from .base import (
     var_operation_return,
 )
 
-NUMBER_T = TypeVar("NUMBER_T", int, float, bool)
+NUMBER_T = TypeVarExt(
+    "NUMBER_T", bound=(int | float), default=(int | float), covariant=True
+)
 
 if TYPE_CHECKING:
     from .sequence import ArrayVar
@@ -313,13 +317,19 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         return self
 
-    def __round__(self):
+    def __round__(self, ndigits: int | NumberVar = 0) -> NumberVar:
         """Round the number.
 
+        Args:
+            ndigits: The number of digits to round.
+
         Returns:
             The number round operation.
         """
-        return number_round_operation(self)
+        if not isinstance(ndigits, NUMBER_TYPES):
+            raise_unsupported_operand_types("round", (type(self), type(ndigits)))
+
+        return number_round_operation(self, +ndigits)
 
     def __ceil__(self):
         """Ceil the number.
@@ -653,16 +663,23 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
 
 
 @var_operation
-def number_round_operation(value: NumberVar):
+def number_round_operation(value: NumberVar, ndigits: NumberVar | int):
     """Round the number.
 
     Args:
         value: The number.
+        ndigits: The number of digits.
 
     Returns:
         The number round operation.
     """
-    return var_operation_return(js_expression=f"Math.round({value})", var_type=int)
+    if (isinstance(ndigits, LiteralNumberVar) and ndigits._var_value == 0) or (
+        isinstance(ndigits, int) and ndigits == 0
+    ):
+        return var_operation_return(js_expression=f"Math.round({value})", var_type=int)
+    return var_operation_return(
+        js_expression=f"(+{value}.toFixed({ndigits}))", var_type=float
+    )
 
 
 @var_operation

+ 16 - 5
tests/integration/test_var_operations.py

@@ -31,6 +31,7 @@ def VarOperations():
         int_var3: rx.Field[int] = rx.field(7)
         float_var1: rx.Field[float] = rx.field(10.5)
         float_var2: rx.Field[float] = rx.field(5.5)
+        long_float: rx.Field[float] = rx.field(13212312312.1231231)
         list1: rx.Field[list] = rx.field([1, 2])
         list2: rx.Field[list] = rx.field([3, 4])
         list3: rx.Field[list] = rx.field(["first", "second", "third"])
@@ -718,25 +719,33 @@ def VarOperations():
             ),
             # ObjectVar
             rx.box(
-                rx.text(VarOperationState.obj.name),
+                rx.text.span(VarOperationState.obj.name),
                 id="obj_name",
             ),
             rx.box(
-                rx.text(VarOperationState.obj.optional_none),
+                rx.text.span(VarOperationState.obj.optional_none),
                 id="obj_optional_none",
             ),
             rx.box(
-                rx.text(VarOperationState.obj.optional_str),
+                rx.text.span(VarOperationState.obj.optional_str),
                 id="obj_optional_str",
             ),
             rx.box(
-                rx.text(VarOperationState.obj.get("optional_none")),
+                rx.text.span(VarOperationState.obj.get("optional_none")),
                 id="obj_optional_none_get_none",
             ),
             rx.box(
-                rx.text(VarOperationState.obj.get("optional_none", "foo")),
+                rx.text.span(VarOperationState.obj.get("optional_none", "foo")),
                 id="obj_optional_none_get_foo",
             ),
+            rx.box(
+                rx.text.span(round(VarOperationState.long_float)),
+                id="float_round",
+            ),
+            rx.box(
+                rx.text.span(round(VarOperationState.long_float, 2)),
+                id="float_round_2",
+            ),
         )
 
 
@@ -965,6 +974,8 @@ def test_var_operations(driver, var_operations: AppHarness):
         ("obj_optional_str", "hello"),
         ("obj_optional_none_get_none", ""),
         ("obj_optional_none_get_foo", "foo"),
+        ("float_round", "13212312312"),
+        ("float_round_2", "13212312312.12"),
     ]
 
     for tag, expected in tests: