Prechádzať zdrojové kódy

feat: integreate _ScenarioManager._filter_by_creation_time() to taipy.get_scenarios() and taipy.get_primary_scenarios()

trgiangdo 10 mesiacov pred
rodič
commit
c3fe881b5d

+ 29 - 22
taipy/core/scenario/_scenario_manager.py

@@ -306,6 +306,35 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
             scenarios.sort(key=lambda x: (x.name, x.id), reverse=descending)
             scenarios.sort(key=lambda x: (x.name, x.id), reverse=descending)
         return scenarios
         return scenarios
 
 
+    @classmethod
+    def _filter_by_creation_time(
+        cls,
+        scenarios: List[Scenario],
+        created_start_time: Optional[datetime] = None,
+        created_end_time: Optional[datetime] = None,
+    ) -> List[Scenario]:
+        """
+        Filter a list of scenarios by a given creation time period.
+        The time period is inclusive.
+
+        Parameters:
+            created_start_time (Optional[datetime]): Start time of the period.
+            created_end_time (Optional[datetime]): End time of the period.
+
+        Returns:
+            List[Scenario]: List of scenarios created in the given time period.
+        """
+        if not created_start_time and not created_end_time:
+            return scenarios
+
+        if not created_start_time:
+            return [scenario for scenario in scenarios if scenario.creation_date <= created_end_time]
+
+        if not created_end_time:
+            return [scenario for scenario in scenarios if created_start_time <= scenario.creation_date]
+
+        return [scenario for scenario in scenarios if created_start_time <= scenario.creation_date <= created_end_time]
+
     @classmethod
     @classmethod
     def _is_promotable_to_primary(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
     def _is_promotable_to_primary(cls, scenario: Union[Scenario, ScenarioId]) -> bool:
         if isinstance(scenario, str):
         if isinstance(scenario, str):
@@ -468,25 +497,3 @@ class _ScenarioManager(_Manager[Scenario], _VersionMixin):
         for fil in filters:
         for fil in filters:
             fil.update({"config_id": config_id})
             fil.update({"config_id": config_id})
         return cls._repository._load_all(filters)
         return cls._repository._load_all(filters)
-
-    @classmethod
-    def _get_by_creation_time(
-        cls, start_time: datetime, end_time: datetime, version_number: Optional[str] = None
-    ) -> List[Scenario]:
-        """
-        Get all scenarios by a given creation time period.
-        The time period is inclusive.
-
-        Parameters:
-            start_time (datetime): Start time of the period.
-            end_time (datetime): End time of the period.
-
-        Returns:
-            List[Scenario]: List of scenarios created in the given time period.
-        """
-        filters = cls._build_filters_with_version(version_number)
-        if not filters:
-            filters = [{}]
-
-        scenarios = cls._repository._load_all(filters)
-        return [scenario for scenario in scenarios if start_time <= scenario.creation_date <= end_time]

+ 13 - 0
taipy/core/taipy.py

@@ -510,6 +510,8 @@ def get_scenarios(
     tag: Optional[str] = None,
     tag: Optional[str] = None,
     is_sorted: bool = False,
     is_sorted: bool = False,
     descending: bool = False,
     descending: bool = False,
+    created_start_time: Optional[datetime] = None,
+    created_end_time: Optional[datetime] = None,
     sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
     sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
 ) -> List[Scenario]:
 ) -> List[Scenario]:
     """Retrieve a list of existing scenarios filtered by cycle or tag.
     """Retrieve a list of existing scenarios filtered by cycle or tag.
@@ -526,6 +528,8 @@ def get_scenarios(
             The default value is False.
             The default value is False.
         descending (bool): If True, sort the output list of scenarios in descending order.
         descending (bool): If True, sort the output list of scenarios in descending order.
             The default value is False.
             The default value is False.
+        created_start_time (Optional[datetime]): The optional inclusive start date to filter scenarios by creation date.
+        created_end_time (Optional[datetime]): The optional inclusive end date to filter scenarios by creation date.
         sort_key (Literal["name", "id", "creation_date", "tags"]): The optional sort_key to
         sort_key (Literal["name", "id", "creation_date", "tags"]): The optional sort_key to
             decide upon what key scenarios are sorted. The sorting is in increasing order for
             decide upon what key scenarios are sorted. The sorting is in increasing order for
             dates, in alphabetical order for name and id, and in lexicographical order for tags.
             dates, in alphabetical order for name and id, and in lexicographical order for tags.
@@ -548,6 +552,8 @@ def get_scenarios(
     else:
     else:
         scenarios = []
         scenarios = []
 
 
+    if created_start_time or created_end_time:
+        scenarios = scenario_manager._filter_by_creation_time(scenarios, created_start_time, created_end_time)
     if is_sorted:
     if is_sorted:
         scenario_manager._sort_scenarios(scenarios, descending, sort_key)
         scenario_manager._sort_scenarios(scenarios, descending, sort_key)
     return scenarios
     return scenarios
@@ -569,6 +575,8 @@ def get_primary(cycle: Cycle) -> Optional[Scenario]:
 def get_primary_scenarios(
 def get_primary_scenarios(
     is_sorted: bool = False,
     is_sorted: bool = False,
     descending: bool = False,
     descending: bool = False,
+    created_start_time: Optional[datetime] = None,
+    created_end_time: Optional[datetime] = None,
     sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
     sort_key: Literal["name", "id", "config_id", "creation_date", "tags"] = "name",
 ) -> List[Scenario]:
 ) -> List[Scenario]:
     """Retrieve a list of all primary scenarios.
     """Retrieve a list of all primary scenarios.
@@ -578,6 +586,8 @@ def get_primary_scenarios(
             The default value is False.
             The default value is False.
         descending (bool): If True, sort the output list of scenarios in descending order.
         descending (bool): If True, sort the output list of scenarios in descending order.
             The default value is False.
             The default value is False.
+        created_start_time (Optional[datetime]): The optional inclusive start date to filter scenarios by creation date.
+        created_end_time (Optional[datetime]): The optional inclusive end date to filter scenarios by creation date.
         sort_key (Literal["name", "id", "creation_date", "tags"]): The optional sort_key to
         sort_key (Literal["name", "id", "creation_date", "tags"]): The optional sort_key to
             decide upon what key scenarios are sorted. The sorting is in increasing order for
             decide upon what key scenarios are sorted. The sorting is in increasing order for
             dates, in alphabetical order for name and id, and in lexicographical order for tags.
             dates, in alphabetical order for name and id, and in lexicographical order for tags.
@@ -589,6 +599,9 @@ def get_primary_scenarios(
     """
     """
     scenario_manager = _ScenarioManagerFactory._build_manager()
     scenario_manager = _ScenarioManagerFactory._build_manager()
     scenarios = scenario_manager._get_primary_scenarios()
     scenarios = scenario_manager._get_primary_scenarios()
+
+    if created_start_time or created_end_time:
+        scenarios = scenario_manager._filter_by_creation_time(scenarios, created_start_time, created_end_time)
     if is_sorted:
     if is_sorted:
         scenario_manager._sort_scenarios(scenarios, descending, sort_key)
         scenario_manager._sort_scenarios(scenarios, descending, sort_key)
     return scenarios
     return scenarios

+ 32 - 4
tests/core/scenario/test_scenario_manager.py

@@ -1484,7 +1484,7 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment():
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
 
 
 
 
-def test_get_scenarios_by_creation_datetime():
+def test_filter_scenarios_by_creation_datetime():
     scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
     scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
 
 
     with freezegun.freeze_time("2024-01-01"):
     with freezegun.freeze_time("2024-01-01"):
@@ -1494,15 +1494,43 @@ def test_get_scenarios_by_creation_datetime():
     with freezegun.freeze_time("2024-02-01"):
     with freezegun.freeze_time("2024-02-01"):
         s_1_3 = _ScenarioManager._create(scenario_config_1)
         s_1_3 = _ScenarioManager._create(scenario_config_1)
 
 
-    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 2))
+    all_scenarios = _ScenarioManager._get_all()
+
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2024, 1, 1),
+        created_end_time=datetime(2024, 1, 2),
+    )
     assert len(filtered_scenarios) == 1
     assert len(filtered_scenarios) == 1
     assert [s_1_1] == filtered_scenarios
     assert [s_1_1] == filtered_scenarios
 
 
     # The time period is inclusive
     # The time period is inclusive
-    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 3))
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2024, 1, 1),
+        created_end_time=datetime(2024, 1, 3),
+    )
     assert len(filtered_scenarios) == 2
     assert len(filtered_scenarios) == 2
     assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
     assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
 
 
-    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2023, 1, 1), datetime(2025, 1, 1))
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2023, 1, 1),
+        created_end_time=datetime(2025, 1, 1),
+    )
     assert len(filtered_scenarios) == 3
     assert len(filtered_scenarios) == 3
     assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])
     assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])
+
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2024, 2, 1),
+    )
+    assert len(filtered_scenarios) == 1
+    assert [s_1_3] == filtered_scenarios
+
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_end_time=datetime(2024, 1, 2),
+    )
+    assert len(filtered_scenarios) == 1
+    assert [s_1_1] == filtered_scenarios

+ 32 - 4
tests/core/scenario/test_scenario_manager_with_sql_repo.py

@@ -438,7 +438,7 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment(init_sql_re
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
     assert len(_ScenarioManager._get_by_config_id(scenario_config_2.id)) == 2
 
 
 
 
-def test_get_scenarios_by_creation_datetime(init_sql_repo):
+def test_filter_scenarios_by_creation_datetime(init_sql_repo):
     scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
     scenario_config_1 = Config.configure_scenario("s1", sequence_configs=[])
 
 
     with freezegun.freeze_time("2024-01-01"):
     with freezegun.freeze_time("2024-01-01"):
@@ -448,15 +448,43 @@ def test_get_scenarios_by_creation_datetime(init_sql_repo):
     with freezegun.freeze_time("2024-02-01"):
     with freezegun.freeze_time("2024-02-01"):
         s_1_3 = _ScenarioManager._create(scenario_config_1)
         s_1_3 = _ScenarioManager._create(scenario_config_1)
 
 
-    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 2))
+    all_scenarios = _ScenarioManager._get_all()
+
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2024, 1, 1),
+        created_end_time=datetime(2024, 1, 2),
+    )
     assert len(filtered_scenarios) == 1
     assert len(filtered_scenarios) == 1
     assert [s_1_1] == filtered_scenarios
     assert [s_1_1] == filtered_scenarios
 
 
     # The time period is inclusive
     # The time period is inclusive
-    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2024, 1, 1), datetime(2024, 1, 3))
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2024, 1, 1),
+        created_end_time=datetime(2024, 1, 3),
+    )
     assert len(filtered_scenarios) == 2
     assert len(filtered_scenarios) == 2
     assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
     assert sorted([s_1_1.id, s_1_2.id]) == sorted([scenario.id for scenario in filtered_scenarios])
 
 
-    filtered_scenarios = _ScenarioManager._get_by_creation_time(datetime(2023, 1, 1), datetime(2025, 1, 1))
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2023, 1, 1),
+        created_end_time=datetime(2025, 1, 1),
+    )
     assert len(filtered_scenarios) == 3
     assert len(filtered_scenarios) == 3
     assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])
     assert sorted([s_1_1.id, s_1_2.id, s_1_3.id]) == sorted([scenario.id for scenario in filtered_scenarios])
+
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_start_time=datetime(2024, 2, 1),
+    )
+    assert len(filtered_scenarios) == 1
+    assert [s_1_3] == filtered_scenarios
+
+    filtered_scenarios = _ScenarioManager._filter_by_creation_time(
+        scenarios=all_scenarios,
+        created_end_time=datetime(2024, 1, 2),
+    )
+    assert len(filtered_scenarios) == 1
+    assert [s_1_1] == filtered_scenarios

+ 6 - 0
tests/core/test_taipy.py

@@ -431,6 +431,9 @@ class TestTaipy:
         with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._get_all_by_tag") as mck:
         with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._get_all_by_tag") as mck:
             tp.get_scenarios(tag="tag")
             tp.get_scenarios(tag="tag")
             mck.assert_called_once_with("tag")
             mck.assert_called_once_with("tag")
+        with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._filter_by_creation_time") as mck:
+            tp.get_scenarios(created_start_time=datetime.datetime(2021, 1, 1))
+            mck.assert_called_once_with([], datetime.datetime(2021, 1, 1), None)
 
 
     def test_get_scenarios_sorted(self):
     def test_get_scenarios_sorted(self):
         scenario_1_cfg = Config.configure_scenario(id="scenario_1")
         scenario_1_cfg = Config.configure_scenario(id="scenario_1")
@@ -500,6 +503,9 @@ class TestTaipy:
         with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._get_primary_scenarios") as mck:
         with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._get_primary_scenarios") as mck:
             tp.get_primary_scenarios()
             tp.get_primary_scenarios()
             mck.assert_called_once_with()
             mck.assert_called_once_with()
+        with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._filter_by_creation_time") as mck:
+            tp.get_scenarios(created_end_time=datetime.datetime(2021, 1, 1))
+            mck.assert_called_once_with([], None, datetime.datetime(2021, 1, 1))
 
 
     def test_set_primary(self, scenario):
     def test_set_primary(self, scenario):
         with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._set_primary") as mck:
         with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._set_primary") as mck: