소스 검색

auto enable /_upload endpoint only if Upload component is used (#2430)

benedikt-bartscher 1 년 전
부모
커밋
0b1b8ee639
4개의 변경된 파일43개의 추가작업 그리고 12개의 파일을 삭제
  1. 5 2
      reflex/app.py
  2. 19 2
      reflex/components/core/upload.py
  3. 6 7
      reflex/components/core/upload.pyi
  4. 13 1
      scripts/pyi_generator.py

+ 5 - 2
reflex/app.py

@@ -43,6 +43,7 @@ from reflex.components.core.client_side_routing import (
     Default404Page,
     Default404Page,
     wait_for_client_redirect,
     wait_for_client_redirect,
 )
 )
+from reflex.components.core.upload import UploadFilesProvider
 from reflex.config import get_config
 from reflex.config import get_config
 from reflex.event import Event, EventHandler, EventSpec
 from reflex.event import Event, EventHandler, EventSpec
 from reflex.middleware import HydrateMiddleware, Middleware
 from reflex.middleware import HydrateMiddleware, Middleware
@@ -180,7 +181,6 @@ class App(Base):
         # Set up the API.
         # Set up the API.
         self.api = FastAPI()
         self.api = FastAPI()
         self.add_cors()
         self.add_cors()
-        self.add_default_endpoints()
 
 
         if self.state:
         if self.state:
             # Set up the state manager.
             # Set up the state manager.
@@ -242,7 +242,8 @@ class App(Base):
         self.api.get(str(constants.Endpoint.PING))(ping)
         self.api.get(str(constants.Endpoint.PING))(ping)
 
 
         # To upload files.
         # To upload files.
-        self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
+        if UploadFilesProvider.is_used:
+            self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
 
 
     def add_cors(self):
     def add_cors(self):
         """Add CORS middleware to the app."""
         """Add CORS middleware to the app."""
@@ -800,6 +801,8 @@ class App(Base):
             for future in concurrent.futures.as_completed(write_page_futures):
             for future in concurrent.futures.as_completed(write_page_futures):
                 future.result()
                 future.result()
 
 
+        self.add_default_endpoints()
+
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
     async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
         """Modify the state out of band.
         """Modify the state out of band.

+ 19 - 2
reflex/components/core/upload.py

@@ -1,7 +1,7 @@
 """A file upload component."""
 """A file upload component."""
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, ClassVar, Dict, List, Optional, Union
 
 
 from reflex import constants
 from reflex import constants
 from reflex.components.chakra.forms.input import Input
 from reflex.components.chakra.forms.input import Input
@@ -98,6 +98,23 @@ class UploadFilesProvider(Component):
     library = f"/{Dirs.CONTEXTS_PATH}"
     library = f"/{Dirs.CONTEXTS_PATH}"
     tag = "UploadFilesProvider"
     tag = "UploadFilesProvider"
 
 
+    is_used: ClassVar[bool] = False
+
+    @classmethod
+    def create(cls, *children, **props) -> Component:
+        """Create an UploadFilesProvider component.
+
+        Args:
+            *children: The children of the component.
+            **props: The properties of the component.
+
+        Returns:
+            The UploadFilesProvider component.
+        """
+        cls.is_used = True
+
+        return super().create(*children, **props)
+
 
 
 class Upload(Component):
 class Upload(Component):
     """A file upload component."""
     """A file upload component."""
@@ -192,5 +209,5 @@ class Upload(Component):
     @staticmethod
     @staticmethod
     def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
     def _get_app_wrap_components() -> dict[tuple[int, str], Component]:
         return {
         return {
-            (5, "UploadFilesProvider"): UploadFilesProvider(),
+            (5, "UploadFilesProvider"): UploadFilesProvider.create(),
         }
         }

+ 6 - 7
reflex/components/core/upload.pyi

@@ -7,7 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, ClassVar, Dict, List, Optional, Union
 from reflex import constants
 from reflex import constants
 from reflex.components.chakra.forms.input import Input
 from reflex.components.chakra.forms.input import Input
 from reflex.components.chakra.layout.box import Box
 from reflex.components.chakra.layout.box import Box
@@ -29,6 +29,8 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
 def cancel_upload(upload_id: str) -> EventSpec: ...
 def cancel_upload(upload_id: str) -> EventSpec: ...
 
 
 class UploadFilesProvider(Component):
 class UploadFilesProvider(Component):
+    is_used: ClassVar[bool] = False
+
     @overload
     @overload
     @classmethod
     @classmethod
     def create(  # type: ignore
     def create(  # type: ignore
@@ -87,7 +89,7 @@ class UploadFilesProvider(Component):
         ] = None,
         ] = None,
         **props
         **props
     ) -> "UploadFilesProvider":
     ) -> "UploadFilesProvider":
-        """Create the component.
+        """Create an UploadFilesProvider component.
 
 
         Args:
         Args:
             *children: The children of the component.
             *children: The children of the component.
@@ -97,13 +99,10 @@ class UploadFilesProvider(Component):
             class_name: The class name for the component.
             class_name: The class name for the component.
             autofocus: Whether the component should take the focus once the page is loaded
             autofocus: Whether the component should take the focus once the page is loaded
             custom_attrs: custom attribute
             custom_attrs: custom attribute
-            **props: The props of the component.
+            **props: The properties of the component.
 
 
         Returns:
         Returns:
-            The component.
-
-        Raises:
-            TypeError: If an invalid child is passed.
+            The UploadFilesProvider component.
         """
         """
         ...
         ...
 
 

+ 13 - 1
scripts/pyi_generator.py

@@ -245,7 +245,12 @@ def _extract_class_props_as_ast_nodes(
         # Import from the target class to ensure type hints are resolvable.
         # Import from the target class to ensure type hints are resolvable.
         exec(f"from {target_class.__module__} import *", type_hint_globals)
         exec(f"from {target_class.__module__} import *", type_hint_globals)
         for name, value in target_class.__annotations__.items():
         for name, value in target_class.__annotations__.items():
-            if name in spec.kwonlyargs or name in EXCLUDED_PROPS or name in all_props:
+            if (
+                name in spec.kwonlyargs
+                or name in EXCLUDED_PROPS
+                or name in all_props
+                or (isinstance(value, str) and "ClassVar" in value)
+            ):
                 continue
                 continue
             all_props.append(name)
             all_props.append(name)
 
 
@@ -559,6 +564,13 @@ class StubGenerator(ast.NodeTransformer):
         Returns:
         Returns:
             The modified AnnAssign node (or None).
             The modified AnnAssign node (or None).
         """
         """
+        # skip ClassVars
+        if (
+            isinstance(node.annotation, ast.Subscript)
+            and isinstance(node.annotation.value, ast.Name)
+            and node.annotation.value.id == "ClassVar"
+        ):
+            return node
         if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
         if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
             return None
             return None
         if self.current_class in self.classes:
         if self.current_class in self.classes: