Forráskód Böngészése

feat: add _ScenarioManager._get_by_creation_time() to filter scenarios by creation time

trgiangdo 10 hónapja
szülő
commit
eedab9b934

+ 24 - 2
taipy/core/scenario/_scenario_manager.py

@@ -9,7 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 # specific language governing permissions and limitations under the License.
 
 
-import datetime
+from datetime import datetime
 from functools import partial
 from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 
@@ -122,7 +122,7 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
     def _create(
     def _create(
         cls,
         cls,
         config: ScenarioConfig,
         config: ScenarioConfig,
-        creation_date: Optional[datetime.datetime] = None,
+        creation_date: Optional[datetime] = None,
         name: Optional[str] = None,
         name: Optional[str] = None,
     ) -> Scenario:
     ) -> Scenario:
         _task_manager = _TaskManagerFactory._build_manager()
         _task_manager = _TaskManagerFactory._build_manager()
@@ -468,3 +468,25 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
         for fil in filters:
         for fil in filters:
             fil.update({"config_id": config_id})
             fil.update({"config_id": config_id})
         return cls._repository._load_all(filters)
         return cls._repository._load_all(filters)
+
+    @classmethod
+    def _get_by_creation_time(
+        cls, start_time: datetime, end_time: datetime, version_number: Optional[str] = None
+    ) -> List[Scenario]:
+        """
+        Get all scenarios by a given creation time period.
+        The time period is inclusive.
+
+        Parameters:
+            start_time (datetime): Start time of the period.
+            end_time (datetime): End time of the period.
+
+        Returns:
+            List[Scenario]: List of scenarios created in the given time period.
+        """
+        filters = cls._build_filters_with_version(version_number)
+        if not filters:
+            filters = [{}]
+
+        scenarios = cls._repository._load_all(filters)
+        return [scenario for scenario in scenarios if start_time <= scenario.creation_date <= end_time]

+ 25 - 0
tests/core/scenario/test_scenario_manager.py

@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
 from typing import Callable, Iterable, Optional
 from typing import Callable, Iterable, Optional
 from unittest.mock import ANY, patch
 from unittest.mock import ANY, patch
 
 
+import freezegun
 import pytest
 import pytest
 
 
 from taipy.config.common.frequency import Frequency
 from taipy.config.common.frequency import Frequency
@@ -1481,3 +1482,27 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment():
 
 
     assert len(_ScenarioManager._get_by_config_id(scenario_config_1.id)) == 3
     assert len(_ScenarioManager._get_by_config_id(scenario_config_1.id)) == 3
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
+
+
+def test_get_scenarios_by_creation_datetime():
+    scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
+
+    with freezegun.freeze_time("2024-01-01"):
+        s_1_1 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-01-03"):
+        s_1_2 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-02-01"):
+        s_1_3 = _ScenarioManager._create(scenario_config_1)
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 2))
+    assert len(filtered_scenarios) == 1
+    assert [s_1_1] == filtered_scenarios
+
+    # The time period is inclusive
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 3))
+    assert len(filtered_scenarios) == 2
+    assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2023, 1, 1), datetime(2025, 1, 1))
+    assert len(filtered_scenarios) == 3
+    assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])

+ 25 - 0
tests/core/scenario/test_scenario_manager_with_sql_repo.py

@@ -11,6 +11,7 @@
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
+import freezegun
 import pytest
 import pytest
 
 
 from taipy.config.common.frequency import Frequency
 from taipy.config.common.frequency import Frequency
@@ -435,3 +436,27 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment(init_sql_re
 
 
     assert len(_ScenarioManager._get_by_config_id(scenario_config_1.id)) == 3
     assert len(_ScenarioManager._get_by_config_id(scenario_config_1.id)) == 3
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
+
+
+def test_get_scenarios_by_creation_datetime(init_sql_repo):
+    scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
+
+    with freezegun.freeze_time("2024-01-01"):
+        s_1_1 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-01-03"):
+        s_1_2 = _ScenarioManager._create(scenario_config_1)
+    with freezegun.freeze_time("2024-02-01"):
+        s_1_3 = _ScenarioManager._create(scenario_config_1)
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 2))
+    assert len(filtered_scenarios) == 1
+    assert [s_1_1] == filtered_scenarios
+
+    # The time period is inclusive
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 3))
+    assert len(filtered_scenarios) == 2
+    assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
+
+    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2023, 1, 1), datetime(2025, 1, 1))
+    assert len(filtered_scenarios) == 3
+    assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])