Browse Source

Make better/less use of dict.keys() calls (#3455)

Alexander Morgan 11 months ago
parent
commit
ad3134413b

+ 3 - 2
reflex/.templates/apps/demo/code/webui/state.py

@@ -66,7 +66,8 @@ class State(State):
         del self.chats[self.current_chat]
         if len(self.chats) == 0:
             self.chats = DEFAULT_CHATS
-        self.current_chat = list(self.chats.keys())[0]
+        # set self.current_chat to the first chat.
+        self.current_chat = next(iter(self.chats))
         self.toggle_drawer()
 
     def set_chat(self, chat_name: str):
@@ -85,7 +86,7 @@ class State(State):
         Returns:
             The list of chat names.
         """
-        return list(self.chats.keys())
+        return [*self.chats]
 
     async def process_question(self, form_data: dict[str, str]):
         """Get the response from the API.

+ 2 - 5
reflex/app.py

@@ -707,11 +707,8 @@ class App(LifespanMixin, Base):
         page_imports = {
             i
             for i, tags in imports.items()
-            if i
-            not in [
-                *constants.PackageJson.DEPENDENCIES.keys(),
-                *constants.PackageJson.DEV_DEPENDENCIES.keys(),
-            ]
+            if i not in constants.PackageJson.DEPENDENCIES
+            and i not in constants.PackageJson.DEV_DEPENDENCIES
             and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
             and i != ""
             and any(tag.install for tag in tags)

+ 7 - 4
reflex/components/component.py

@@ -360,7 +360,6 @@ class Component(BaseComponent, ABC):
         # Get the component fields, triggers, and props.
         fields = self.get_fields()
         component_specific_triggers = self.get_event_triggers()
-        triggers = component_specific_triggers.keys()
         props = self.get_props()
 
         # Add any events triggers.
@@ -370,13 +369,17 @@ class Component(BaseComponent, ABC):
 
         # Iterate through the kwargs and set the props.
         for key, value in kwargs.items():
-            if key.startswith("on_") and key not in triggers and key not in props:
+            if (
+                key.startswith("on_")
+                and key not in component_specific_triggers
+                and key not in props
+            ):
                 raise ValueError(
                     f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}"
                     f" is a third party component make sure to add `{key}` to the component's event triggers. "
                     f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info."
                 )
-            if key in triggers:
+            if key in component_specific_triggers:
                 # Event triggers are bound to event chains.
                 field_type = EventChain
             elif key in props:
@@ -436,7 +439,7 @@ class Component(BaseComponent, ABC):
                     )
 
             # Check if the key is an event trigger.
-            if key in triggers:
+            if key in component_specific_triggers:
                 # Temporarily disable full control for event triggers.
                 kwargs["event_triggers"][key] = self._create_event_chain(
                     value=value, args_spec=component_specific_triggers[key]

+ 1 - 1
reflex/utils/pyi_generator.py

@@ -424,7 +424,7 @@ def _generate_component_create_functiondef(
             ),
             ast.Constant(value=None),
         )
-        for trigger in sorted(clz().get_event_triggers().keys())
+        for trigger in sorted(clz().get_event_triggers())
     )
     logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
     create_args = ast.arguments(

+ 1 - 1
reflex/utils/types.py

@@ -509,7 +509,7 @@ def validate_parameter_literals(func):
         annotations = {param[0]: param[1].annotation for param in func_params}
 
         # validate args
-        for param, arg in zip(annotations.keys(), args):
+        for param, arg in zip(annotations, args):
             if annotations[param] is inspect.Parameter.empty:
                 continue
             validate_literal(param, arg, annotations[param], func.__name__)

+ 5 - 5
tests/components/core/test_banner.py

@@ -10,19 +10,19 @@ from reflex.components.radix.themes.typography.text import Text
 def test_websocket_target_url():
     url = WebsocketTargetURL.create()
     _imports = url._get_all_imports(collapse=True)
-    assert list(_imports.keys()) == ["/utils/state", "/env.json"]
+    assert tuple(_imports) == ("/utils/state", "/env.json")
 
 
 def test_connection_banner():
     banner = ConnectionBanner.create()
     _imports = banner._get_all_imports(collapse=True)
-    assert list(_imports.keys()) == [
+    assert tuple(_imports) == (
         "react",
         "/utils/context",
         "/utils/state",
         "@radix-ui/themes@^3.0.0",
         "/env.json",
-    ]
+    )
 
     msg = "Connection error"
     custom_banner = ConnectionBanner.create(Text.create(msg))
@@ -32,13 +32,13 @@ def test_connection_banner():
 def test_connection_modal():
     modal = ConnectionModal.create()
     _imports = modal._get_all_imports(collapse=True)
-    assert list(_imports.keys()) == [
+    assert tuple(_imports) == (
         "react",
         "/utils/context",
         "/utils/state",
         "@radix-ui/themes@^3.0.0",
         "/env.json",
-    ]
+    )
 
     msg = "Connection error"
     custom_modal = ConnectionModal.create(Text.create(msg))

+ 3 - 4
tests/components/core/test_debounce.py

@@ -98,11 +98,10 @@ def test_event_triggers():
             on_change=S.on_change,
         )
     )
-    default_event_triggers = list(rx.Component().get_event_triggers().keys())
-    assert list(debounced_input.get_event_triggers().keys()) == [
-        *default_event_triggers,
+    assert tuple(debounced_input.get_event_triggers()) == (
+        *rx.Component().get_event_triggers(),  # default event triggers
         "on_change",
-    ]
+    )
 
 
 def test_render_child_props_recursive():

+ 1 - 1
tests/components/datadisplay/test_datatable.py

@@ -114,4 +114,4 @@ def test_serialize_dataframe():
     value = serialize(df)
     assert value == serialize_dataframe(df)
     assert isinstance(value, dict)
-    assert list(value.keys()) == ["columns", "data"]
+    assert tuple(value) == ("columns", "data")

+ 1 - 1
tests/components/test_component.py

@@ -566,7 +566,7 @@ def test_get_event_triggers(component1, component2):
         EventTriggers.ON_MOUNT,
         EventTriggers.ON_UNMOUNT,
     }
-    assert set(component1().get_event_triggers().keys()) == default_triggers
+    assert component1().get_event_triggers().keys() == default_triggers
     assert (
         component2().get_event_triggers().keys()
         == {"on_open", "on_close"} | default_triggers

+ 5 - 5
tests/test_app.py

@@ -235,9 +235,9 @@ def test_add_page_default_route(app: App, index_page, about_page):
     """
     assert app.pages == {}
     app.add_page(index_page)
-    assert set(app.pages.keys()) == {"index"}
+    assert app.pages.keys() == {"index"}
     app.add_page(about_page)
-    assert set(app.pages.keys()) == {"index", "about"}
+    assert app.pages.keys() == {"index", "about"}
 
 
 def test_add_page_set_route(app: App, index_page, windows_platform: bool):
@@ -251,7 +251,7 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
     route = "test" if windows_platform else "/test"
     assert app.pages == {}
     app.add_page(index_page, route=route)
-    assert set(app.pages.keys()) == {"test"}
+    assert app.pages.keys() == {"test"}
 
 
 def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
@@ -268,7 +268,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
         route.lstrip("/").replace("/", "\\")
     assert app.pages == {}
     app.add_page(index_page, route=route)
-    assert set(app.pages.keys()) == {"test/[dynamic]"}
+    assert app.pages.keys() == {"test/[dynamic]"}
     assert "dynamic" in app.state.computed_vars
     assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
         constants.ROUTER
@@ -287,7 +287,7 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool)
     route = "test\\nested" if windows_platform else "/test/nested"
     assert app.pages == {}
     app.add_page(index_page, route=route)
-    assert set(app.pages.keys()) == {route.strip(os.path.sep)}
+    assert app.pages.keys() == {route.strip(os.path.sep)}
 
 
 def test_add_page_invalid_api_route(app: App, index_page):

+ 4 - 4
tests/test_state.py

@@ -287,7 +287,7 @@ def test_class_vars(test_state):
         test_state: A state.
     """
     cls = type(test_state)
-    assert set(cls.vars.keys()) == {
+    assert cls.vars.keys() == {
         "router",
         "num1",
         "num2",
@@ -310,7 +310,7 @@ def test_event_handlers(test_state):
     Args:
         test_state: A state.
     """
-    expected = {
+    expected_keys = (
         "do_something",
         "set_array",
         "set_complex",
@@ -320,10 +320,10 @@ def test_event_handlers(test_state):
         "set_num1",
         "set_num2",
         "set_obj",
-    }
+    )
 
     cls = type(test_state)
-    assert set(cls.event_handlers.keys()).intersection(expected) == expected
+    assert all(key in cls.event_handlers for key in expected_keys)
 
 
 def test_default_value(test_state):

+ 1 - 1
tests/utils/test_imports.py

@@ -72,7 +72,7 @@ def test_merge_imports(input_1, input_2, output):
 
     """
     res = merge_imports(input_1, input_2)
-    assert set(res.keys()) == set(output.keys())
+    assert res.keys() == output.keys()
 
     for key in output:
         assert set(res[key]) == set(output[key])