浏览代码

feature: allow serialize Enum object (#2457)

Đỗ Trường Giang 2 月之前
父节点
当前提交
93d2706f2d
共有 2 个文件被更改,包括 23 次插入6 次删除
  1. 14 6
      taipy/common/config/_serializer/_base_serializer.py
  2. 9 0
      taipy/common/config/common/_template_handler.py

+ 14 - 6
taipy/common/config/_serializer/_base_serializer.py

@@ -14,6 +14,7 @@ import re
 import types
 from abc import abstractmethod
 from datetime import datetime, timedelta
+from enum import Enum
 from typing import Any, Dict, Optional
 
 from .._config import _Config
@@ -38,6 +39,7 @@ class _BaseSerializer(object):
         "timedelta",
         "function",
         "class",
+        "enum",
         "SECTION",
     ]
     _section_class = {_GLOBAL_NODE_NAME: GlobalAppConfig}
@@ -70,14 +72,14 @@ class _BaseSerializer(object):
             config_as_dict[u_sect_name] = u_sect._to_dict()
         for sect_name, sections in configuration._sections.items():
             config_as_dict[sect_name] = cls.__to_dict(sections)
-        return cls.__stringify(config_as_dict)
+        return cls._stringify(config_as_dict)
 
     @classmethod
     def __to_dict(cls, sections: Dict[str, Any]):
         return {section_id: section._to_dict() for section_id, section in sections.items()}
 
     @classmethod
-    def __stringify(cls, as_dict):
+    def _stringify(cls, as_dict):
         if as_dict is None:
             return None
         if hasattr(as_dict, "_stringify") and callable(as_dict._stringify):
@@ -96,14 +98,16 @@ class _BaseSerializer(object):
             return f"{as_dict.__module__}.{as_dict.__name__}:function"
         if inspect.isclass(as_dict):
             return f"{as_dict.__module__}.{as_dict.__qualname__}:class"
+        if isinstance(as_dict, Enum):
+            return f"{as_dict.__module__}.{as_dict.__class__.__qualname__}.{as_dict.name}:enum"
         if isinstance(as_dict, dict):
-            return {str(key): cls.__stringify(val) for key, val in as_dict.items()}
+            return {str(key): cls._stringify(val) for key, val in as_dict.items()}
         if isinstance(as_dict, list):
-            return [cls.__stringify(val) for val in as_dict]
+            return [cls._stringify(val) for val in as_dict]
         if isinstance(as_dict, tuple):
-            return [cls.__stringify(val) for val in as_dict]
+            return [cls._stringify(val) for val in as_dict]
         if isinstance(as_dict, set):
-            return [cls.__stringify(val) for val in as_dict]
+            return [cls._stringify(val) for val in as_dict]
         return as_dict
 
     @classmethod
@@ -155,12 +159,16 @@ class _BaseSerializer(object):
                         return _TemplateHandler._to_function(actual_val)
                     elif dynamic_type == "class":
                         return _TemplateHandler._to_class(actual_val)
+                    elif dynamic_type == "enum":
+                        return _TemplateHandler._to_enum(actual_val)
                     elif dynamic_type == "str":
                         return actual_val
                     else:
                         error_msg = f"Error loading toml configuration at {val}. {dynamic_type} type is not supported."
                         raise LoadingError(error_msg)
             if isinstance(val, dict):
+                if len(val) == 1 and list(val.keys())[0] in cls._registered_types.keys():
+                    return cls._registered_types[list(val.keys())[0]]._pythonify(list(val.values())[0])
                 return {str(k): cls._pythonify(v) for k, v in val.items()}
             if isinstance(val, list):
                 return [cls._pythonify(v) for v in val]

+ 9 - 0
taipy/common/config/common/_template_handler.py

@@ -128,3 +128,12 @@ class _TemplateHandler:
             return locate(val)
         except Exception:
             raise InconsistentEnvVariableError(f"{val} is not a valid class.") from None
+
+    @staticmethod
+    def _to_enum(val: str):
+        enum_class, enum_value = val.rsplit(".", 1)
+        try:
+            enum = locate(enum_class)
+            return enum[enum_value]  # type: ignore[index]
+        except Exception:
+            raise InconsistentEnvVariableError(f"{val} is not a valid enum.") from None