|
@@ -9,11 +9,11 @@ import copy
|
|
|
import dataclasses
|
|
|
import functools
|
|
|
import inspect
|
|
|
-import io
|
|
|
import json
|
|
|
import multiprocessing
|
|
|
import platform
|
|
|
import sys
|
|
|
+import tempfile
|
|
|
import traceback
|
|
|
from datetime import datetime
|
|
|
from pathlib import Path
|
|
@@ -29,6 +29,7 @@ from typing import (
|
|
|
List,
|
|
|
MutableMapping,
|
|
|
Optional,
|
|
|
+ Sequence,
|
|
|
Set,
|
|
|
Type,
|
|
|
Union,
|
|
@@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
|
|
|
return JSONResponse(content=health_status, status_code=status_code)
|
|
|
|
|
|
|
|
|
+def _handle_temporary_upload_file(
|
|
|
+ upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
|
|
|
+) -> tempfile.SpooledTemporaryFile:
|
|
|
+ temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
|
|
|
+ temp_file.write(upload_file.file.read())
|
|
|
+ temp_file.seek(0)
|
|
|
+ return temp_file
|
|
|
+
|
|
|
+
|
|
|
+async def temporary_upload_tree(
|
|
|
+ token: str, files: List[UploadFile]
|
|
|
+) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
|
|
|
+ """Write the uploaded files to a temporary directory structure.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ token: The token to use for the temporary directory.
|
|
|
+ files: The files to write to the temporary directory.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ A list of the temporary files.
|
|
|
+ """
|
|
|
+ upload_dir = get_upload_dir()
|
|
|
+ upload_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
|
|
|
+ temp_files = []
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ temp_files = [
|
|
|
+ await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
|
|
|
+ for f in files
|
|
|
+ ]
|
|
|
+ try:
|
|
|
+ yield temp_files
|
|
|
+ finally:
|
|
|
+
|
|
|
+ def _cleanup():
|
|
|
+ for temp_file in temp_files:
|
|
|
+ temp_file.close()
|
|
|
+ temp_root.cleanup()
|
|
|
+
|
|
|
+ await loop.run_in_executor(None, _cleanup)
|
|
|
+
|
|
|
+
|
|
|
def upload(app: App):
|
|
|
"""Upload a file.
|
|
|
|
|
@@ -1563,24 +1606,21 @@ def upload(app: App):
|
|
|
# AsyncExitStack was removed from the request scope and is now
|
|
|
# part of the routing function which closes this before the
|
|
|
# event is handled.
|
|
|
- file_copies = []
|
|
|
- for file in files:
|
|
|
- content_copy = io.BytesIO()
|
|
|
- content_copy.write(await file.read())
|
|
|
- content_copy.seek(0)
|
|
|
- file_copies.append(
|
|
|
- UploadFile(
|
|
|
- file=content_copy,
|
|
|
- filename=file.filename,
|
|
|
- size=file.size,
|
|
|
- headers=file.headers,
|
|
|
- )
|
|
|
+ file_ctx = temporary_upload_tree(token, files)
|
|
|
+ temp_files = [
|
|
|
+ UploadFile(
|
|
|
+ file=tmp, # pyright: ignore[reportArgumentType]
|
|
|
+ filename=file.filename,
|
|
|
+ size=file.size,
|
|
|
+ headers=file.headers,
|
|
|
)
|
|
|
+ for file, tmp in zip(files, await anext(file_ctx), strict=True)
|
|
|
+ ]
|
|
|
|
|
|
event = Event(
|
|
|
token=token,
|
|
|
name=handler,
|
|
|
- payload={handler_upload_param[0]: file_copies},
|
|
|
+ payload={handler_upload_param[0]: temp_files},
|
|
|
)
|
|
|
|
|
|
async def _ndjson_updates():
|
|
@@ -1595,6 +1635,9 @@ def upload(app: App):
|
|
|
# Postprocess the event.
|
|
|
update = await app._postprocess(state, event, update)
|
|
|
yield update.json() + "\n"
|
|
|
+ # Clean up the temporary files.
|
|
|
+ async for _ in file_ctx:
|
|
|
+ pass
|
|
|
|
|
|
# Stream updates to client
|
|
|
return StreamingResponse(
|