Bladeren bron

fix upload on drop (#5160)

Khaleel Al-Adhami 3 weken geleden
bovenliggende
commit
f4364afd11
3 gewijzigde bestanden met toevoegingen van 98 en 18 verwijderingen
  1. 8 0
      reflex/components/core/upload.py
  2. 81 9
      tests/integration/test_upload.py
  3. 9 9
      tests/units/states/upload.py

+ 8 - 0
reflex/components/core/upload.py

@@ -273,6 +273,14 @@ class Upload(MemoizationLeaf):
             elif isinstance(on_drop, Callable):
                 # Call the lambda to get the event chain.
                 on_drop = call_event_fn(on_drop, _on_drop_spec)
+            if isinstance(on_drop, EventSpec):
+                # Update the provided args for direct use with on_drop.
+                on_drop = on_drop.with_args(
+                    args=tuple(
+                        cls._update_arg_tuple_for_on_drop(arg_value)
+                        for arg_value in on_drop.args
+                    ),
+                )
             upload_props["on_drop"] = on_drop
 
         input_props_unique_name = get_unique_variable_name()

+ 81 - 9
tests/integration/test_upload.py

@@ -26,27 +26,32 @@ def UploadFile():
     class UploadState(rx.State):
         _file_data: dict[str, str] = {}
         event_order: rx.Field[list[str]] = rx.field([])
-        progress_dicts: list[dict] = []
-        disabled: bool = False
-        large_data: str = ""
+        progress_dicts: rx.Field[list[dict]] = rx.field([])
+        disabled: rx.Field[bool] = rx.field(False)
+        large_data: rx.Field[str] = rx.field("")
+        quaternary_names: rx.Field[list[str]] = rx.field([])
 
+        @rx.event
         async def handle_upload(self, files: list[rx.UploadFile]):
             for file in files:
                 upload_data = await file.read()
-                self._file_data[file.filename or ""] = upload_data.decode("utf-8")
+                self._file_data[file.name or ""] = upload_data.decode("utf-8")
 
+        @rx.event
         async def handle_upload_secondary(self, files: list[rx.UploadFile]):
             for file in files:
                 upload_data = await file.read()
-                self._file_data[file.filename or ""] = upload_data.decode("utf-8")
+                self._file_data[file.name or ""] = upload_data.decode("utf-8")
                 self.large_data = LARGE_DATA
                 yield UploadState.chain_event
 
+        @rx.event
         def upload_progress(self, progress):
             assert progress
             self.event_order.append("upload_progress")
             self.progress_dicts.append(progress)
 
+        @rx.event
         def chain_event(self):
             assert self.large_data == LARGE_DATA
             self.large_data = ""
@@ -55,10 +60,14 @@ def UploadFile():
         @rx.event
         async def handle_upload_tertiary(self, files: list[rx.UploadFile]):
             for file in files:
-                (rx.get_upload_dir() / (file.filename or "INVALID")).write_bytes(
+                (rx.get_upload_dir() / (file.name or "INVALID")).write_bytes(
                     await file.read()
                 )
 
+        @rx.event
+        async def handle_upload_quaternary(self, files: list[rx.UploadFile]):
+            self.quaternary_names = [file.name for file in files if file.name]
+
         @rx.event
         def do_download(self):
             return rx.download(rx.get_upload_url("test.txt"))
@@ -80,7 +89,7 @@ def UploadFile():
             ),
             rx.button(
                 "Upload",
-                on_click=lambda: UploadState.handle_upload(rx.upload_files()),  # pyright: ignore [reportCallIssue]
+                on_click=lambda: UploadState.handle_upload(rx.upload_files()),  # pyright: ignore [reportArgumentType]
                 id="upload_button",
             ),
             rx.box(
@@ -105,8 +114,8 @@ def UploadFile():
             ),
             rx.button(
                 "Upload",
-                on_click=UploadState.handle_upload_secondary(  # pyright: ignore [reportCallIssue]
-                    rx.upload_files(
+                on_click=UploadState.handle_upload_secondary(
+                    rx.upload_files(  # pyright: ignore [reportArgumentType]
                         upload_id="secondary",
                         on_upload_progress=UploadState.upload_progress,
                     ),
@@ -163,6 +172,22 @@ def UploadFile():
                 on_click=UploadState.do_download,
                 id="download-backend",
             ),
+            rx.upload.root(
+                rx.vstack(
+                    rx.button("Select File"),
+                    rx.text("Drag and drop files here or click to select files"),
+                ),
+                on_drop=UploadState.handle_upload_quaternary(
+                    rx.upload_files(  # pyright: ignore [reportArgumentType]
+                        upload_id="quaternary",
+                    ),
+                ),
+                id="quaternary",
+            ),
+            rx.text(
+                UploadState.quaternary_names.to_string(),
+                id="quaternary_files",
+            ),
             rx.text(UploadState.event_order.to_string(), id="event-order"),
         )
 
@@ -501,3 +526,50 @@ async def test_upload_download_file(
         download_backend.click()
     assert urlsplit(driver.current_url).path == f"/{Endpoint.UPLOAD.value}/test.txt"
     assert driver.find_element(by=By.TAG_NAME, value="body").text == exp_contents
+
+
+@pytest.mark.asyncio
+async def test_on_drop(
+    tmp_path,
+    upload_file: AppHarness,
+    driver: WebDriver,
+):
+    """Test the on_drop event handler.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+        upload_file: harness for UploadFile app.
+        driver: WebDriver instance.
+    """
+    assert upload_file.app_instance is not None
+    token = poll_for_token(driver, upload_file)
+    full_state_name = upload_file.get_full_state_name(["_upload_state"])
+    state_name = upload_file.get_state_name("_upload_state")
+    substate_token = f"{token}_{full_state_name}"
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
+        3
+    ]  # quaternary upload
+    assert upload_box
+
+    exp_name = "drop_test.txt"
+    exp_contents = "dropped file contents!"
+    target_file = tmp_path / exp_name
+    target_file.write_text(exp_contents)
+
+    # Simulate file drop by directly setting the file input
+    upload_box.send_keys(str(target_file))
+
+    # Wait for the on_drop event to be processed
+    await asyncio.sleep(0.5)
+
+    async def exp_name_in_quaternary():
+        state = await upload_file.get_state(substate_token)
+        return exp_name in state.substates[state_name].quaternary_names
+
+    # Poll until the file names appear in the display
+    await AppHarness._poll_for_async(exp_name_in_quaternary)
+
+    # Verify through state that the file names were captured correctly
+    state = await upload_file.get_state(substate_token)
+    assert exp_name in state.substates[state_name].quaternary_names

+ 9 - 9
tests/units/states/upload.py

@@ -59,14 +59,14 @@ class FileUploadState(State):
         """
         for file in files:
             upload_data = await file.read()
-            assert file.filename is not None
-            outfile = self._tmp_path / file.filename
+            assert file.name is not None
+            outfile = self._tmp_path / file.name
 
             # Save the file.
             outfile.write_bytes(upload_data)
 
             # Update the img var.
-            self.img_list.append(file.filename)
+            self.img_list.append(file.name)
 
     @rx.event(background=True)
     async def bg_upload(self, files: list[rx.UploadFile]):
@@ -106,14 +106,14 @@ class ChildFileUploadState(FileStateBase1):
         """
         for file in files:
             upload_data = await file.read()
-            assert file.filename is not None
-            outfile = self._tmp_path / file.filename
+            assert file.name is not None
+            outfile = self._tmp_path / file.name
 
             # Save the file.
             outfile.write_bytes(upload_data)
 
             # Update the img var.
-            self.img_list.append(file.filename)
+            self.img_list.append(file.name)
 
     @rx.event(background=True)
     async def bg_upload(self, files: list[rx.UploadFile]):
@@ -153,14 +153,14 @@ class GrandChildFileUploadState(FileStateBase2):
         """
         for file in files:
             upload_data = await file.read()
-            assert file.filename is not None
-            outfile = self._tmp_path / file.filename
+            assert file.name is not None
+            outfile = self._tmp_path / file.name
 
             # Save the file.
             outfile.write_bytes(upload_data)
 
             # Update the img var.
-            self.img_list.append(file.filename)
+            self.img_list.append(file.name)
 
     @rx.event(background=True)
     async def bg_upload(self, files: list[rx.UploadFile]):