소스 검색

Write uploaded file bytes to a SpooledTemporaryFile

Masen Furer 3 달 전
부모
커밋
7f0efa13e9
1개의 변경된 파일57개의 추가작업 그리고 14개의 파일을 삭제
  1. 57 14
      reflex/app.py

+ 57 - 14
reflex/app.py

@@ -9,11 +9,11 @@ import copy
 import dataclasses
 import dataclasses
 import functools
 import functools
 import inspect
 import inspect
-import io
 import json
 import json
 import multiprocessing
 import multiprocessing
 import platform
 import platform
 import sys
 import sys
+import tempfile
 import traceback
 import traceback
 from datetime import datetime
 from datetime import datetime
 from pathlib import Path
 from pathlib import Path
@@ -29,6 +29,7 @@ from typing import (
     List,
     List,
     MutableMapping,
     MutableMapping,
     Optional,
     Optional,
+    Sequence,
     Set,
     Set,
     Type,
     Type,
     Union,
     Union,
@@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
     return JSONResponse(content=health_status, status_code=status_code)
     return JSONResponse(content=health_status, status_code=status_code)
 
 
 
 
+def _handle_temporary_upload_file(
+    upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
+) -> tempfile.SpooledTemporaryFile:
+    temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
+    temp_file.write(upload_file.file.read())
+    temp_file.seek(0)
+    return temp_file
+
+
+async def temporary_upload_tree(
+    token: str, files: List[UploadFile]
+) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
+    """Write the uploaded files to a temporary directory structure.
+
+    Args:
+        token: The token to use for the temporary directory.
+        files: The files to write to the temporary directory.
+
+    Yields:
+        A list of the temporary files.
+    """
+    upload_dir = get_upload_dir()
+    upload_dir.mkdir(parents=True, exist_ok=True)
+    temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
+    temp_files = []
+    loop = asyncio.get_running_loop()
+    temp_files = [
+        await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
+        for f in files
+    ]
+    try:
+        yield temp_files
+    finally:
+
+        def _cleanup():
+            for temp_file in temp_files:
+                temp_file.close()
+            temp_root.cleanup()
+
+        await loop.run_in_executor(None, _cleanup)
+
+
 def upload(app: App):
 def upload(app: App):
     """Upload a file.
     """Upload a file.
 
 
@@ -1563,24 +1606,21 @@ def upload(app: App):
         # AsyncExitStack was removed from the request scope and is now
         # AsyncExitStack was removed from the request scope and is now
         # part of the routing function which closes this before the
         # part of the routing function which closes this before the
         # event is handled.
         # event is handled.
-        file_copies = []
-        for file in files:
-            content_copy = io.BytesIO()
-            content_copy.write(await file.read())
-            content_copy.seek(0)
-            file_copies.append(
-                UploadFile(
-                    file=content_copy,
-                    filename=file.filename,
-                    size=file.size,
-                    headers=file.headers,
-                )
+        file_ctx = temporary_upload_tree(token, files)
+        temp_files = [
+            UploadFile(
+                file=tmp,  # pyright: ignore[reportArgumentType]
+                filename=file.filename,
+                size=file.size,
+                headers=file.headers,
             )
             )
+            for file, tmp in zip(files, await anext(file_ctx), strict=True)
+        ]
 
 
         event = Event(
         event = Event(
             token=token,
             token=token,
             name=handler,
             name=handler,
-            payload={handler_upload_param[0]: file_copies},
+            payload={handler_upload_param[0]: temp_files},
         )
         )
 
 
         async def _ndjson_updates():
         async def _ndjson_updates():
@@ -1595,6 +1635,9 @@ def upload(app: App):
                     # Postprocess the event.
                     # Postprocess the event.
                     update = await app._postprocess(state, event, update)
                     update = await app._postprocess(state, event, update)
                     yield update.json() + "\n"
                     yield update.json() + "\n"
+            # Clean up the temporary files.
+            async for _ in file_ctx:
+                pass
 
 
         # Stream updates to client
         # Stream updates to client
         return StreamingResponse(
         return StreamingResponse(