Forráskód Böngészése

add special handling for infinity and nan (#3943)

* add special handling for infinity and nan

* use custom exception

* add test for inf and nan
Khaleel Al-Adhami 8 hónapja
szülő
commit
6ae66987b6
3 módosított fájl, 38 hozzáadás és 2 törlés
  1. 4 0
      reflex/utils/exceptions.py
  2. 17 2
      reflex/vars/number.py
  3. 17 0
      tests/test_var.py

+ 4 - 0
reflex/utils/exceptions.py

@@ -107,3 +107,7 @@ class EventHandlerShadowsBuiltInStateMethod(ReflexError, NameError):
 
 
 class GeneratedCodeHasNoFunctionDefs(ReflexError):
 class GeneratedCodeHasNoFunctionDefs(ReflexError):
     """Raised when refactored code generated with flexgen has no functions defined."""
     """Raised when refactored code generated with flexgen has no functions defined."""
+
+
+class PrimitiveUnserializableToJSON(ReflexError, ValueError):
+    """Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""

+ 17 - 2
reflex/vars/number.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 
 
 import dataclasses
 import dataclasses
 import json
 import json
+import math
 import sys
 import sys
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
@@ -18,7 +19,7 @@ from typing import (
 )
 )
 
 
 from reflex.constants.base import Dirs
 from reflex.constants.base import Dirs
-from reflex.utils.exceptions import VarTypeError
+from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.imports import ImportDict, ImportVar
 
 
 from .base import (
 from .base import (
@@ -1040,7 +1041,14 @@ class LiteralNumberVar(LiteralVar, NumberVar):
 
 
         Returns:
         Returns:
             The JSON representation of the var.
             The JSON representation of the var.
+
+        Raises:
+            PrimitiveUnserializableToJSON: If the var is unserializable to JSON.
         """
         """
+        if math.isinf(self._var_value) or math.isnan(self._var_value):
+            raise PrimitiveUnserializableToJSON(
+                f"No valid JSON representation for {self}"
+            )
         return json.dumps(self._var_value)
         return json.dumps(self._var_value)
 
 
     def __hash__(self) -> int:
     def __hash__(self) -> int:
@@ -1062,8 +1070,15 @@ class LiteralNumberVar(LiteralVar, NumberVar):
         Returns:
         Returns:
             The number var.
             The number var.
         """
         """
+        if math.isinf(value):
+            js_expr = "Infinity" if value > 0 else "-Infinity"
+        elif math.isnan(value):
+            js_expr = "NaN"
+        else:
+            js_expr = str(value)
+
         return cls(
         return cls(
-            _js_expr=str(value),
+            _js_expr=js_expr,
             _var_type=type(value),
             _var_type=type(value),
             _var_data=_var_data,
             _var_data=_var_data,
             _var_value=value,
             _var_value=value,

+ 17 - 0
tests/test_var.py

@@ -9,6 +9,7 @@ from pandas import DataFrame
 from reflex.base import Base
 from reflex.base import Base
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
 from reflex.state import BaseState
 from reflex.state import BaseState
+from reflex.utils.exceptions import PrimitiveUnserializableToJSON
 from reflex.utils.imports import ImportVar
 from reflex.utils.imports import ImportVar
 from reflex.vars import VarData
 from reflex.vars import VarData
 from reflex.vars.base import (
 from reflex.vars.base import (
@@ -974,6 +975,22 @@ def test_index_operation():
     assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"
     assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"
 
 
 
 
+@pytest.mark.parametrize(
+    "var, expected_js",
+    [
+        (Var.create(float("inf")), "Infinity"),
+        (Var.create(-float("inf")), "-Infinity"),
+        (Var.create(float("nan")), "NaN"),
+    ],
+)
+def test_inf_and_nan(var, expected_js):
+    assert str(var) == expected_js
+    assert isinstance(var, NumberVar)
+    assert isinstance(var, LiteralVar)
+    with pytest.raises(PrimitiveUnserializableToJSON):
+        var.json()
+
+
 def test_array_operations():
 def test_array_operations():
     array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])
     array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])