pandas_data_accessor.py

# Copyright 2021-2025 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import os
import typing as t
from datetime import datetime
from importlib import util
from tempfile import mkstemp

import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype

from .._warnings import _warn
from ..gui import Gui
from ..types import PropertyType
from ..utils import _RE_PD_TYPE, _get_date_col_str_name
from .comparison import _compare_function
from .data_accessor import _DataAccessor
from .data_format import _DataFormat

_has_arrow_module = False
if util.find_spec("pyarrow"):
    _has_arrow_module = True
    import pyarrow as pa

_ORIENT_TYPE = t.Literal["records", "list"]


class _PandasDataAccessor(_DataAccessor):
    __types = (pd.DataFrame, pd.Series)

    __INDEX_COL = "_tp_index"

    __AGGREGATE_FUNCTIONS: t.List[str] = ["count", "sum", "mean", "median", "min", "max", "std", "first", "last"]

    @staticmethod
    def get_supported_classes() -> t.List[t.Type]:
        return list(_PandasDataAccessor.__types)

    def to_pandas(self, value: t.Union[pd.DataFrame, pd.Series]) -> t.Union[t.List[pd.DataFrame], pd.DataFrame]:
        return self._to_dataframe(value)

    def _to_dataframe(self, value: t.Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
        if isinstance(value, pd.Series):
            return pd.DataFrame(value)
        return t.cast(pd.DataFrame, value)

    def _from_pandas(self, value: pd.DataFrame, data_type: t.Type) -> t.Any:
        if data_type is pd.Series:
            return value.iloc[:, 0]
        return value
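
    # Note: pd.Series values round-trip through these two helpers (illustrative sketch):
    #   s = pd.Series([1, 2, 3], name="x")
    #   df = pd.DataFrame(s)   # what _to_dataframe() produces for a Series
    #   s2 = df.iloc[:, 0]     # what _from_pandas() returns when data_type is pd.Series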

    @staticmethod
    def __user_function(
        row: pd.Series, gui: Gui, column_name: t.Optional[str], user_function: t.Callable, function_name: str
    ) -> str:  # pragma: no cover
        args = []
        if column_name:
            args.append(row[column_name])
        args.extend((row.name, row))  # type: ignore[arg-type]
        if column_name:
            args.append(column_name)  # type: ignore[arg-type]
        try:
            return str(gui._call_function_with_state(user_function, args))
        except Exception as e:
            _warn(f"Exception raised when calling user function {function_name}()", e)
        return ""
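
    # Calling convention implied by the argument list above (sketch): a user-provided
    # style/tooltip/format function attached to a column receives (value, index, row,
    # column_name), plus whatever _call_function_with_state prepends (presumably the
    # state). A hypothetical styling function could therefore look like:
    #   def highlight_negative(state, value, index, row, column_name):
    #       return "red" if value < 0 else ""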

    def __get_column_names(self, df: pd.DataFrame, *cols: str):
        col_names = [t for t in df.columns if str(t) in cols]
        return (col_names[0] if len(cols) == 1 else col_names) if col_names else None

    def get_dataframe_with_cols(self, df: pd.DataFrame, cols: t.List[str]) -> pd.DataFrame:
        return df.loc[:, df.dtypes[df.columns.astype(str).isin(cols)].index]  # type: ignore[index]

    def __build_transferred_cols(  # noqa: C901
        self,
        payload_cols: t.Any,
        dataframe: pd.DataFrame,
        styles: t.Optional[t.Dict[str, str]] = None,
        tooltips: t.Optional[t.Dict[str, str]] = None,
        is_copied: t.Optional[bool] = False,
        new_indexes: t.Optional[np.ndarray] = None,
        handle_nan: t.Optional[bool] = False,
        formats: t.Optional[t.Dict[str, str]] = None,
    ) -> pd.DataFrame:
        dataframe = dataframe.iloc[new_indexes] if new_indexes is not None else dataframe
        if isinstance(payload_cols, list) and len(payload_cols):
            cols_description = {k: v for k, v in self.get_cols_description("", dataframe).items() if k in payload_cols}
        else:
            cols_description = self.get_cols_description("", dataframe)
        cols = list(cols_description.keys())
        new_cols = {}
        if styles:
            for k, v in styles.items():
                col_applied = ""
                new_data = None
                func = self._gui._get_user_function(v)
                if callable(func):
                    col_applied, new_data = self.__apply_user_function(
                        func, k if k in cols else None, v, dataframe, "tps__"
                    )
                new_cols[col_applied or v] = new_data if col_applied else v
        if tooltips:
            for k, v in tooltips.items():
                func = self._gui._get_user_function(v)
                if callable(func):
                    col_applied, new_data = self.__apply_user_function(
                        func, k if k in cols else None, v, dataframe, "tpt__"
                    )
                    if col_applied:
                        new_cols[col_applied] = new_data
        if formats:
            for k, v in formats.items():
                func = self._gui._get_user_function(v)
                if callable(func):
                    col_applied, new_data = self.__apply_user_function(
                        func, k if k in cols else None, v, dataframe, "tpf__"
                    )
                    if col_applied:
                        new_cols[col_applied] = new_data
        # deal with dates
        date_cols = [c for c, d in cols_description.items() if d.get("type", "").startswith("datetime")]
        if len(date_cols) != 0:
            if not is_copied:
                # copy the df so that we don't "mess" with the user's data
                dataframe = dataframe.copy()
            tz = Gui._get_timezone()
            for col in date_cols:
                col_name = self.__get_column_names(dataframe, col)
                new_col = _get_date_col_str_name(cols, col)
                re_type = _RE_PD_TYPE.match(cols_description[col].get("type", ""))
                groups = re_type.groups() if re_type else ()
                if len(groups) > 4 and groups[4]:
                    new_cols[new_col] = (
                        dataframe[col_name]
                        .dt.tz_convert("UTC")
                        .dt.strftime(_DataAccessor._WS_DATE_FORMAT)
                        .astype(str)
                        .replace("nan", "NaT" if handle_nan else None)
                    )
                else:
                    new_cols[new_col] = (
                        dataframe[col_name]
                        .dt.tz_localize(tz)
                        .dt.tz_convert("UTC")
                        .dt.strftime(_DataAccessor._WS_DATE_FORMAT)
                        .astype(str)
                        .replace("nan", "NaT" if handle_nan else None)
                    )
            # remove the date columns from the list of columns
            cols = list(set(cols) - set(date_cols))
        if new_cols:
            dataframe = dataframe.assign(**new_cols)
            cols += list(new_cols.keys())
        return self.get_dataframe_with_cols(dataframe, cols)
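
    # Dates are always shipped as UTC strings: the regex group check above distinguishes
    # tz-aware dtypes (converted only) from naive ones (localized to the server timezone
    # first). The transferred column gets a derived string name from _get_date_col_str_name(),
    # and the original datetime columns are removed from the transferred set.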

    def __apply_user_function(
        self,
        user_function: t.Callable,
        column_name: t.Optional[str],
        function_name: str,
        data: pd.DataFrame,
        prefix: t.Optional[str],
    ):
        try:
            new_col_name = f"{prefix}{column_name}__{function_name}" if column_name else function_name
            return new_col_name, data.apply(
                _PandasDataAccessor.__user_function,
                axis=1,
                args=(
                    self._gui,
                    self.__get_column_names(data, column_name) if column_name else column_name,
                    user_function,
                    function_name,
                ),
            )
        except Exception as e:
            _warn(f"Exception raised when invoking user function {function_name}()", e)
        return "", data

    def _format_data(
        self,
        data: pd.DataFrame,
        data_format: _DataFormat,
        orient: _ORIENT_TYPE,
        start: t.Optional[int] = None,
        rowcount: t.Optional[int] = None,
        data_extraction: t.Optional[bool] = None,
        handle_nan: t.Optional[bool] = False,
        fullrowcount: t.Optional[int] = None,
    ) -> t.Dict[str, t.Any]:
        ret: t.Dict[str, t.Any] = {
            "format": str(data_format.value),
        }
        if rowcount is not None:
            ret["rowcount"] = rowcount
            if fullrowcount is not None and fullrowcount != rowcount:
                ret["fullrowcount"] = fullrowcount
        if start is not None:
            ret["start"] = start
        if data_extraction is not None:
            ret["dataExtraction"] = data_extraction  # Extract data out of dictionary on front-end
        if data_format is _DataFormat.APACHE_ARROW:
            if not _has_arrow_module:
                raise RuntimeError("Cannot use Arrow as pyarrow package is not installed")
            # Convert from pandas to Arrow
            table = pa.Table.from_pandas(data)  # type: ignore[reportPossiblyUnboundVariable]
            # Create sink buffer stream
            sink = pa.BufferOutputStream()  # type: ignore[reportPossiblyUnboundVariable]
            # Create Stream writer
            writer = pa.ipc.new_stream(sink, table.schema)  # type: ignore[reportPossiblyUnboundVariable]
            # Write data to table
            writer.write_table(table)
            writer.close()
            # End buffer stream
            buf = sink.getvalue()
            # Convert buffer to Python bytes and return
            ret["data"] = buf.to_pybytes()
            ret["orient"] = orient
        else:
            # Workaround for Python built in JSON encoder that does not yet support ignore_nan
            ret["data"] = self.get_json_ready_dict(data.replace([np.nan, pd.NA], [None, None]), orient)
        return ret
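
    # The Arrow branch writes a standard IPC stream, so a consumer could decode it
    # symmetrically with pyarrow (sketch; the front-end does the equivalent in JS):
    #   reader = pa.ipc.open_stream(ret["data"])
    #   table = reader.read_all()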

    def get_json_ready_dict(self, df: pd.DataFrame, orient: _ORIENT_TYPE) -> t.Dict[t.Hashable, t.Any]:
        return df.to_dict(orient=orient)  # type: ignore[return-value]

    def get_cols_description(self, var_name: str, value: t.Any) -> t.Dict[str, t.Dict[str, str]]:
        if isinstance(value, list):
            ret_dict: t.Dict[str, t.Dict[str, str]] = {}
            for i, v in enumerate(value):
                res = self.get_cols_description("", v)
                if res:
                    ret_dict.update({f"{i}/{k}": desc for k, desc in res.items()})
            return ret_dict
        df = self._to_dataframe(value)
        return {str(k): {"type": v} for k, v in df.dtypes.apply(lambda x: x.name.lower()).items()}
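
    # Illustrative output: for pd.DataFrame({"x": [1], "when": [pd.Timestamp("2024-01-01")]}),
    # this returns {"x": {"type": "int64"}, "when": {"type": "datetime64[ns]"}}; for a list
    # of frames, keys are prefixed with the list index, e.g. "0/x".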

    def add_optional_columns(self, df: pd.DataFrame, columns: t.List[str]) -> t.Tuple[pd.DataFrame, t.List[str]]:
        return df, []

    def is_dataframe_supported(self, df: pd.DataFrame) -> bool:
        return not isinstance(df.columns, pd.MultiIndex)
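
    # For example, a frame whose columns are pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
    # is reported as unsupported here, and __get_data() returns an explicit error payload for it.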

    def __get_data(  # noqa: C901
        self,
        var_name: str,
        df: pd.DataFrame,
        payload: t.Dict[str, t.Any],
        data_format: _DataFormat,
        col_prefix: t.Optional[str] = "",
    ) -> t.Dict[str, t.Any]:
        ret_payload = {"pagekey": payload.get("pagekey", "unknown page")}
        if not self.is_dataframe_supported(df):
            ret_payload["value"] = {}
            ret_payload["error"] = "MultiIndex columns are not supported."
            _warn("MultiIndex columns are not supported.")
            return ret_payload
        columns = payload.get("columns", [])
        if col_prefix:
            columns = [c[len(col_prefix) :] if c.startswith(col_prefix) else c for c in columns]
        paged = not payload.get("alldata", False)
        is_copied = False

        orig_df = df
        # add index if not chart
        if paged:
            if _PandasDataAccessor.__INDEX_COL not in df.columns:
                is_copied = True
                df = df.assign(**{_PandasDataAccessor.__INDEX_COL: df.index.to_numpy()})
            if columns and _PandasDataAccessor.__INDEX_COL not in columns:
                columns.append(_PandasDataAccessor.__INDEX_COL)

        # optional columns
        df, optional_columns = self.add_optional_columns(df, columns)
        is_copied = is_copied or bool(optional_columns)

        fullrowcount = len(df)
        # filtering
        filters = payload.get("filters")
        if isinstance(filters, list) and len(filters) > 0:
            query = ""
            vars = []
            cols_description = self.get_cols_description(var_name, df)
            for fd in filters:
                col = fd.get("col")
                val = fd.get("value")
                action = fd.get("action")
                match_case = fd.get("matchCase", False) is not False  # Ensure it's a boolean
                right = None
                col_expr = f"`{col}`"

                if isinstance(val, str):
                    if cols_description.get(col, {}).get("type", "").startswith("datetime"):
                        val = datetime.fromisoformat(val[:-1])
                    elif not match_case:
                        if action != "contains":
                            col_expr = f"{col_expr}.str.lower()"
                        val = val.lower()
                    vars.append(val)
                    val_var = f"@vars[{len(vars) - 1}]"
                    if action == "contains":
                        right = f".str.contains({val_var}{'' if match_case else ', case=False'})"
                else:
                    vars.append(val)
                    val_var = f"@vars[{len(vars) - 1}]"

                if right is None:
                    right = f" {action} {val_var}"

                if query:
                    query += " and "
                query += f"{col_expr}{right}"

            # Apply filters using df.query()
            try:
                if query:
                    df = df.query(query)
                    is_copied = True
            except Exception as e:
                _warn(f"Dataframe filtering: invalid query '{query}' on {df.head()}", e)
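
        # Illustrative query produced above: the filter
        # {"col": "name", "action": "==", "value": "Bob", "matchCase": False}
        # yields the expression "`name`.str.lower() == @vars[0]" with vars == ["bob"],
        # which df.query() resolves against the local `vars` list via the @ notation.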
        dict_ret: t.Optional[t.Dict[str, t.Any]]
        if paged:
            aggregates = payload.get("aggregates")
            applies = payload.get("applies")
            if isinstance(aggregates, list) and len(aggregates) and isinstance(applies, dict):
                applies_with_fn = {
                    self.__get_column_names(df, k): v
                    if v in _PandasDataAccessor.__AGGREGATE_FUNCTIONS
                    else self._gui._get_user_function(v)
                    for k, v in applies.items()
                }

                for col in df.columns:
                    if col not in applies_with_fn:
                        applies_with_fn[col] = "first"
                try:
                    col_names = self.__get_column_names(df, *aggregates)
                    if col_names:
                        df = t.cast(pd.DataFrame, df).groupby(aggregates).agg(applies_with_fn)
                    else:
                        raise Exception()
                except Exception:
                    _warn(f"Cannot aggregate {var_name} with groupby {aggregates} and aggregates {applies}.")
            inf = payload.get("infinite")
            if inf is not None:
                ret_payload["infinite"] = inf
            # real number of rows is needed to calculate the number of pages
            rowcount = len(df)
            # here we'll deal with start and end values from payload if present
            if isinstance(payload.get("start", 0), int):
                start = int(payload.get("start", 0))
            else:
                try:
                    start = int(str(payload["start"]), base=10)
                except Exception:
                    _warn(f'start should be an int value {payload["start"]}.')
                    start = 0
            if isinstance(payload.get("end", -1), int):
                end = int(payload.get("end", -1))
            else:
                try:
                    end = int(str(payload["end"]), base=10)
                except Exception:
                    end = -1
            if start < 0 or start >= rowcount:
                start = 0
            if end < 0 or end >= rowcount:
                end = rowcount - 1
            if payload.get("reverse", False):
                diff = end - start
                end = rowcount - 1 - start
                if end < 0:
                    end = rowcount - 1
                start = end - diff
                if start < 0:
                    start = 0
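
            # Worked example: with rowcount == 100, a requested window start == 0, end == 9
            # under "reverse" becomes start == 90, end == 99, i.e. the same page counted
            # from the end of the dataset.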

            # deal with sort
            order_by = payload.get("orderby")
            if isinstance(order_by, str) and len(order_by):
                try:
                    col_name = self.__get_column_names(df, order_by)
                    if col_name:
                        new_indexes = t.cast(pd.DataFrame, df)[col_name].values.argsort(axis=0)
                        if payload.get("sort") == "desc":
                            # reverse order
                            new_indexes = new_indexes[::-1]
                        new_indexes = new_indexes[slice(start, end + 1)]
                    else:
                        raise Exception()
                except Exception:
                    _warn(f"Cannot sort {var_name} on columns {order_by}.")
                    new_indexes = slice(start, end + 1)  # type: ignore
            else:
                new_indexes = slice(start, end + 1)  # type: ignore
            df = self.__build_transferred_cols(
                columns + optional_columns,
                t.cast(pd.DataFrame, df),
                styles=payload.get("styles"),
                tooltips=payload.get("tooltips"),
                is_copied=is_copied,
                new_indexes=t.cast(np.ndarray, new_indexes),
                handle_nan=payload.get("handlenan", False),
                formats=payload.get("formats"),
            )
            dict_ret = self._format_data(
                df,
                data_format,
                "records",
                start,
                rowcount,
                handle_nan=payload.get("handlenan", False),
                fullrowcount=fullrowcount,
            )
            compare = payload.get("compare")
            if isinstance(compare, str):
                comp_df = _compare_function(
                    self._gui, compare, var_name, t.cast(pd.DataFrame, orig_df), payload.get("compare_datas", "")
                )
                if isinstance(comp_df, pd.DataFrame) and not comp_df.empty:
                    try:
                        if isinstance(comp_df.columns[0], tuple):
                            cols: t.List[t.Hashable] = [c for c in comp_df.columns if c[1] == "other"]
                            comp_df = t.cast(pd.DataFrame, comp_df.get(cols))
                            comp_df.columns = t.cast(pd.Index, [t.cast(tuple, c)[0] for c in cols])
                        comp_df.dropna(axis=1, how="all", inplace=True)
                        comp_df = self.__build_transferred_cols(
                            columns, comp_df, new_indexes=t.cast(np.ndarray, new_indexes)
                        )
                        dict_ret["comp"] = self._format_data(comp_df, data_format, "records").get("data")
                    except Exception as e:
                        _warn("Pandas accessor compare raised an exception", e)
        else:
            ret_payload["alldata"] = True
            decimator_payload: t.Dict[str, t.Any] = payload.get("decimatorPayload", {})
            decimators = decimator_payload.get("decimators", [])
            decimated_dfs: t.List[pd.DataFrame] = []
            for decimator_pl in decimators:
                if decimator_pl is None:
                    continue
                decimator = decimator_pl.get("decimator")
                if decimator is None:
                    x_column = decimator_pl.get("xAxis", "")
                    y_column = decimator_pl.get("yAxis", "")
                    z_column = decimator_pl.get("zAxis", "")
                    filtered_columns = [x_column, y_column, z_column] if z_column else [x_column, y_column]
                    decimated_df = df.copy().filter(filtered_columns, axis=1)
                    decimated_dfs.append(decimated_df)
                    continue
                decimator_instance = (
                    self._gui._get_user_instance(decimator, PropertyType.decimator.value)
                    if decimator is not None
                    else None
                )
                if isinstance(decimator_instance, PropertyType.decimator.value):
                    # Run the on_decimate method -> check if the decimator should be applied
                    # -> apply the decimator
                    decimated_df, is_decimator_applied, is_copied = decimator_instance._on_decimate(
                        df, decimator_pl, decimator_payload, is_copied
                    )
                    # add decimated dataframe to the list of decimated
                    decimated_dfs.append(decimated_df)
                    if is_decimator_applied:
                        self._gui._call_on_change(f"{var_name}.{decimator}.nb_rows", len(decimated_df))
            # merge the decimated dataFrames
            if len(decimated_dfs) > 1:
                # get the unique columns from all decimated dataFrames
                decimated_columns = pd.Index([])
                for _df in decimated_dfs:
                    decimated_columns = decimated_columns.append(_df.columns)
                # find the columns that are duplicated across dataFrames
                overlapping_columns = decimated_columns[decimated_columns.duplicated()].unique()
                # concatenate the dataFrames without overwriting columns
                merged_df = pd.concat(decimated_dfs, axis=1)
                # resolve overlapping columns by combining values
                for col in overlapping_columns:
                    # for each overlapping column, combine the values across dataFrames
                    # (e.g., take the first non-null value)
                    cols_to_combine = merged_df.loc[:, col].columns
                    merged_df[col] = merged_df[cols_to_combine].bfill(axis=1).iloc[:, 0]
                # drop duplicated col since they are now the same
                df = merged_df.loc[:, ~merged_df.columns.duplicated()]
            elif len(decimated_dfs) == 1:
                df = decimated_dfs[0]
            if data_format is _DataFormat.CSV:
                df = self.__build_transferred_cols(
                    columns,
                    t.cast(pd.DataFrame, df),
                    is_copied=is_copied,
                    handle_nan=payload.get("handlenan", False),
                )
                ret_payload["df"] = df
                dict_ret = None
            else:
                df = self.__build_transferred_cols(
                    columns,
                    t.cast(pd.DataFrame, df),
                    styles=payload.get("styles"),
                    tooltips=payload.get("tooltips"),
                    is_copied=is_copied,
                    handle_nan=payload.get("handlenan", False),
                    formats=payload.get("formats"),
                )
                dict_ret = self._format_data(df, data_format, "list", data_extraction=True)

        ret_payload["value"] = dict_ret
        return ret_payload
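
    # A typical paged request payload (hypothetical values) might look like:
    #   {"pagekey": "...", "start": 0, "end": 99, "orderby": "col", "sort": "desc",
    #    "columns": ["col"], "handlenan": False}
    # and the resulting "value" carries "rowcount", "fullrowcount" (when filtered),
    # "start", "format" and the serialized "data" built by _format_data().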

    def get_data(
        self, var_name: str, value: t.Any, payload: t.Dict[str, t.Any], data_format: _DataFormat
    ) -> t.Dict[str, t.Any]:
        if isinstance(value, list):
            # If is_chart data
            if payload.get("alldata", False):
                ret_payload = {
                    "alldata": True,
                    "value": {"multi": True},
                    "pagekey": payload.get("pagekey", "unknown page"),
                }
                data = []
                for i, v in enumerate(value):
                    ret = (
                        self.__get_data(var_name, self._to_dataframe(v), payload, data_format, f"{i}/")
                        if isinstance(v, _PandasDataAccessor.__types)
                        else {}
                    )
                    ret_val = ret.get("value", {})
                    data.append(ret_val.pop("data", None))
                    ret_payload.get("value", {}).update(ret_val)
                ret_payload["value"]["data"] = data
                return ret_payload
            else:
                value = value[0]
        return self.__get_data(var_name, self._to_dataframe(value), payload, data_format)

    def _get_index_value(self, index: t.Any) -> t.Any:
        return tuple(index) if isinstance(index, list) else index

    def on_edit(self, value: t.Any, payload: t.Dict[str, t.Any]):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame) or not isinstance(payload.get("index"), (int, float)):
            raise ValueError(f"Cannot edit {type(value)} at {payload.get('index')}.")
        df.at[self._get_index_value(payload.get("index", 0)), payload["col"]] = payload["value"]
        return self._from_pandas(df, type(value))

    def on_delete(self, value: t.Any, payload: t.Dict[str, t.Any]):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame) or not isinstance(payload.get("index"), (int, float)):
            raise ValueError(f"Cannot delete a row from {type(value)} at {payload.get('index')}.")
        return self._from_pandas(df.drop(self._get_index_value(payload.get("index", 0))), type(value))

    def on_add(self, value: t.Any, payload: t.Dict[str, t.Any], new_row: t.Optional[t.List[t.Any]] = None):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame) or not isinstance(payload.get("index"), (int, float)):
            raise ValueError(f"Cannot add a row to {type(value)} at {payload.get('index')}.")
        # Save the insertion index
        index = payload.get("index", 0)
        # Create the new row (Column value types must match the original DataFrame's)
        if list(df.columns):
            new_row = [0 if is_numeric_dtype(dt) else "" for dt in df.dtypes] if new_row is None else new_row
            if index > 0:
                # Column names and value types must match the original DataFrame
                new_df = pd.DataFrame([new_row], columns=df.columns.copy())
                # Split the DataFrame
                rows_before = df.iloc[:index]
                rows_after = df.iloc[index:]
                return self._from_pandas(pd.concat([rows_before, new_df, rows_after], ignore_index=True), type(value))
            else:
                df = df.copy()
                # Insert as the new first row
                df.loc[-1] = new_row  # Insert the new row
                df.index = df.index + 1  # Shift index
                return self._from_pandas(df.sort_index(), type(value))
        return value

    def to_csv(self, var_name: str, value: t.Any):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Cannot export {type(value)} to csv.")
        dict_ret = self.__get_data(var_name, df, {"alldata": True}, _DataFormat.CSV)
        if isinstance(dict_ret, dict):
            dfr = dict_ret.get("df")
            if isinstance(dfr, pd.DataFrame):
                fd, temp_path = mkstemp(".csv", var_name, text=True)
                with os.fdopen(fd, "wt", newline="") as csv_file:
                    dfr.to_csv(csv_file, index=False)
                return temp_path
        return None
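

# Overall flow (informal summary): get_data() normalizes the incoming value to a DataFrame
# and delegates to __get_data(), which applies filters, aggregations, sorting and paging
# as instructed by the payload, then serializes through _format_data() as JSON-ready dicts
# or an Arrow IPC stream. on_edit()/on_delete()/on_add() implement table editing, and
# to_csv() reuses __get_data() with {"alldata": True} to write the full dataset to a
# temporary CSV file.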