瀏覽代碼

feat: add _ScenarioManager._get_by_creation_time() to filter scenarios by creation time

trgiangdo 10 月之前
父節點
當前提交
eedab9b934

+ 24 - 2
taipy/core/scenario/_scenario_manager.py

@@ -9,7 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
-import datetime
+from datetime import datetime
 from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
@@ -122,7 +122,7 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
     def _create(
         cls,
         config: ScenarioConfig,
-        creation_date: Optional[datetime.datetime] = None,
+        creation_date: Optional[datetime] = None,
         name: Optional[str] = None,
     ) -> Scenario:
         _task_manager = _TaskManagerFactory._build_manager()
@@ -468,3 +468,25 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
         for fil in filters:
             fil.update({"config_id": config_id})
         return cls._repository._load_all(filters)
+
+    @classmethod
+    def _get_by_creation_time(
+        cls, start_time: datetime, end_time: datetime, version_number: Optional[str] = None
+    ) -> List[Scenario]:
+        """
+        Get all scenarios by a given creation time period.
+        The time period is inclusive.
+
+        Parameters:
+            start_time (datetime): Start time of the period.
+            end_time (datetime): End time of the period.
+
+        Returns:
+            List[Scenario]: List of scenarios created in the given time period.
+        """
+        filters = cls._build_filters_with_version(version_number)
+        if not filters:
+            filters = [{}]
+
+        scenarios = cls._repository._load_all(filters)
+        return [scenario for scenario in scenarios if start_time <= scenario.creation_date <= end_time]

+ 25 - 0
tests/core/scenario/test_scenario_manager.py

@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
 from typing import Callable, Iterable, Optional
 from unittest.mock import ANY, patch
 
+import freezegun
 import pytest
 
 from taipy.config.common.frequency import Frequency
@@ -1481,3 +1482,27 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment():
 
     assert len(_ScenarioManager._get_by_config_id(scenario_config_1.id)) == 3
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
+
+
+def test_get_scenarios_by_creation_datetime():
+    scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
+
+    with freezegun.freeze_time("2024-01-01"):
+        s_1_1 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-01-03"):
+        s_1_2 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-02-01"):
+        s_1_3 = _ScenarioManager._create(scenario_config_1)
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 2))
+    assert len(filtered_scenarios) == 1
+    assert [s_1_1] == filtered_scenarios
+
+    # The time period is inclusive
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 3))
+    assert len(filtered_scenarios) == 2
+    assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2023, 1, 1), datetime(2025, 1, 1))
+    assert len(filtered_scenarios) == 3
+    assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])

+ 25 - 0
tests/core/scenario/test_scenario_manager_with_sql_repo.py

@@ -11,6 +11,7 @@
 
 from datetime import datetime, timedelta
 
+import freezegun
 import pytest
 
 from taipy.config.common.frequency import Frequency
@@ -435,3 +436,27 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment(init_sql_re
 
     assert len(_ScenarioManager._get_by_config_id(scenario_config_1.id)) == 3
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
+
+
+def test_get_scenarios_by_creation_datetime(init_sql_repo):
+    scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
+
+    with freezegun.freeze_time("2024-01-01"):
+        s_1_1 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-01-03"):
+        s_1_2 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-02-01"):
+        s_1_3 = _ScenarioManager._create(scenario_config_1)
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 2))
+    assert len(filtered_scenarios) == 1
+    assert [s_1_1] == filtered_scenarios
+
+    # The time period is inclusive
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 3))
+    assert len(filtered_scenarios) == 2
+    assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2023, 1, 1), datetime(2025, 1, 1))
+    assert len(filtered_scenarios) == 3
+    assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])