# pandas_data_accessor.py
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import os
import typing as t
from datetime import datetime
from importlib import util
from tempfile import mkstemp

import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype

from .._warnings import _warn
from ..gui import Gui
from ..types import PropertyType
from ..utils import _RE_PD_TYPE, _get_date_col_str_name
from .comparison import _compare_function
from .data_accessor import _DataAccessor
from .data_format import _DataFormat
from .utils import _df_data_filter, _df_relayout

_has_arrow_module = False
if util.find_spec("pyarrow"):
    _has_arrow_module = True
    import pyarrow as pa
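
# Note: pyarrow is an optional dependency. The guarded import above only runs
# when the package is installed, and _has_arrow_module gates every use of the
# Arrow data format below.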


class _PandasDataAccessor(_DataAccessor):
    __types = (pd.DataFrame, pd.Series)

    __INDEX_COL = "_tp_index"

    __AGGREGATE_FUNCTIONS: t.List[str] = ["count", "sum", "mean", "median", "min", "max", "std", "first", "last"]

    def to_pandas(self, value: t.Union[pd.DataFrame, pd.Series]) -> t.Union[t.List[pd.DataFrame], pd.DataFrame]:
        return self.__to_dataframe(value)

    def __to_dataframe(self, value: t.Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
        if isinstance(value, pd.Series):
            return pd.DataFrame(value)
        return t.cast(pd.DataFrame, value)

    def _from_pandas(self, value: pd.DataFrame, data_type: t.Type):
        if data_type is pd.Series:
            return value.iloc[:, 0]
        return value

    @staticmethod
    def get_supported_classes() -> t.List[t.Type]:
        return list(_PandasDataAccessor.__types)
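
    # get_supported_classes() presumably lets the accessor registry route
    # pd.DataFrame and pd.Series values to this accessor.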

    @staticmethod
    def __user_function(
        row: pd.Series, gui: Gui, column_name: t.Optional[str], user_function: t.Callable, function_name: str
    ) -> str:  # pragma: no cover
        args = []
        if column_name:
            args.append(row[column_name])
        args.extend((row.name, row))  # type: ignore[arg-type]
        if column_name:
            args.append(column_name)  # type: ignore[arg-type]
        try:
            return str(gui._call_function_with_state(user_function, args))
        except Exception as e:
            _warn(f"Exception raised when calling user function {function_name}()", e)
        return ""

    def __is_date_column(self, data: pd.DataFrame, col_name: str) -> bool:
        col_types = data.dtypes[data.dtypes.index.astype(str) == col_name]
        return len(col_types[col_types.astype(str).str.startswith("datetime")]) > 0  # type: ignore
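
    # e.g. a column of dtype "datetime64[ns]" or "datetime64[ns, UTC]" matches
    # the "datetime" string-prefix test above.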

    def __build_transferred_cols(
        self,
        payload_cols: t.Any,
        dataframe: pd.DataFrame,
        styles: t.Optional[t.Dict[str, str]] = None,
        tooltips: t.Optional[t.Dict[str, str]] = None,
        is_copied: t.Optional[bool] = False,
        new_indexes: t.Optional[np.ndarray] = None,
        handle_nan: t.Optional[bool] = False,
    ) -> pd.DataFrame:
        if isinstance(payload_cols, list) and len(payload_cols):
            col_types = dataframe.dtypes[dataframe.dtypes.index.astype(str).isin(payload_cols)]
        else:
            col_types = dataframe.dtypes
        cols = col_types.index.astype(str).tolist()
        if styles:
            if not is_copied:
                # copy the df so that we don't "mess" with the user's data
                dataframe = dataframe.copy()
                is_copied = True
            for k, v in styles.items():
                col_applied = False
                func = self._gui._get_user_function(v)
                if callable(func):
                    col_applied = self.__apply_user_function(func, k if k in cols else None, v, dataframe, "tps__")
                if not col_applied:
                    dataframe[v] = v
                cols.append(col_applied or v)
        if tooltips:
            if not is_copied:
                # copy the df so that we don't "mess" with the user's data
                dataframe = dataframe.copy()
                is_copied = True
            for k, v in tooltips.items():
                col_applied = False
                func = self._gui._get_user_function(v)
                if callable(func):
                    col_applied = self.__apply_user_function(func, k if k in cols else None, v, dataframe, "tpt__")
                cols.append(col_applied or v)
        # deal with dates
        datecols = col_types[col_types.astype(str).str.startswith("datetime")].index.tolist()  # type: ignore
        if len(datecols) != 0:
            if not is_copied:
                # copy the df so that we don't "mess" with the user's data
                dataframe = dataframe.copy()
            tz = Gui._get_timezone()
            for col in datecols:
                newcol = _get_date_col_str_name(cols, col)
                cols.append(newcol)
                re_type = _RE_PD_TYPE.match(str(col_types[col]))
                grps = re_type.groups() if re_type else ()
                if len(grps) > 4 and grps[4]:
                    dataframe[newcol] = (
                        dataframe[col]
                        .dt.tz_convert("UTC")
                        .dt.strftime(_DataAccessor._WS_DATE_FORMAT)
                        .astype(str)
                        .replace("nan", "NaT" if handle_nan else None)
                    )
                else:
                    dataframe[newcol] = (
                        dataframe[col]
                        .dt.tz_localize(tz)
                        .dt.tz_convert("UTC")
                        .dt.strftime(_DataAccessor._WS_DATE_FORMAT)
                        .astype(str)
                        .replace("nan", "NaT" if handle_nan else None)
                    )
            # remove the date columns from the list of columns
            cols = list(set(cols) - set(datecols))
        dataframe = dataframe.iloc[new_indexes] if new_indexes is not None else dataframe
        dataframe = dataframe.loc[:, dataframe.dtypes[dataframe.dtypes.index.astype(str).isin(cols)].index]  # type: ignore
        return dataframe
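
    # Note on the date handling above: tz-aware dtypes (the regex captures a
    # timezone group) are converted straight to UTC, while naive datetimes are
    # first localized to the server timezone from Gui._get_timezone().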

    def __apply_user_function(
        self,
        user_function: t.Callable,
        column_name: t.Optional[str],
        function_name: str,
        data: pd.DataFrame,
        prefix: t.Optional[str],
    ):
        try:
            new_col_name = f"{prefix}{column_name}__{function_name}" if column_name else function_name
            data[new_col_name] = data.apply(
                _PandasDataAccessor.__user_function,
                axis=1,
                args=(self._gui, column_name, user_function, function_name),
            )
            return new_col_name
        except Exception as e:
            _warn(f"Exception raised when invoking user function {function_name}()", e)
        return False

    def __format_data(
        self,
        data: pd.DataFrame,
        data_format: _DataFormat,
        orient: str,
        start: t.Optional[int] = None,
        rowcount: t.Optional[int] = None,
        data_extraction: t.Optional[bool] = None,
        handle_nan: t.Optional[bool] = False,
        fullrowcount: t.Optional[int] = None,
    ) -> t.Dict[str, t.Any]:
        ret: t.Dict[str, t.Any] = {
            "format": str(data_format.value),
        }
        if rowcount is not None:
            ret["rowcount"] = rowcount
        if fullrowcount is not None and fullrowcount != rowcount:
            ret["fullrowcount"] = fullrowcount
        if start is not None:
            ret["start"] = start
        if data_extraction is not None:
            ret["dataExtraction"] = data_extraction  # Extract data out of dictionary on front-end
        if data_format is _DataFormat.APACHE_ARROW:
            if not _has_arrow_module:
                raise RuntimeError("Cannot use Arrow as pyarrow package is not installed")
            # Convert from pandas to Arrow
            table = pa.Table.from_pandas(data)
            # Create a buffer stream as the sink
            sink = pa.BufferOutputStream()
            # Create a stream writer for the table's schema
            writer = pa.ipc.new_stream(sink, table.schema)
            # Write the table to the stream
            writer.write_table(table)
            writer.close()
            # Terminate the buffer stream
            buf = sink.getvalue()
            # Convert the buffer to Python bytes and return
            ret["data"] = buf.to_pybytes()
            ret["orient"] = orient
        else:
            # Workaround for the Python built-in JSON encoder that does not yet support ignore_nan
            ret["data"] = data.replace([np.nan, pd.NA], [None, None]).to_dict(orient=orient)  # type: ignore
        return ret
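
    # A minimal sketch of how a consumer could decode the Arrow payload built
    # above (assuming pyarrow is installed; the names are illustrative, not
    # part of this module):
    #
    #     reader = pa.ipc.open_stream(ret["data"])  # bytes -> stream reader
    #     table = reader.read_all()                 # back to a pa.Table
    #     df = table.to_pandas()                    # and back to a DataFrame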

    def get_col_types(self, var_name: str, value: t.Any) -> t.Union[None, t.Dict[str, str]]:  # type: ignore
        if isinstance(value, list):
            ret_dict: t.Dict[str, str] = {}
            for i, v in enumerate(value):
                ret_dict.update(
                    {f"{i}/{k}": v for k, v in self.__to_dataframe(v).dtypes.apply(lambda x: x.name.lower()).items()}
                )
            return ret_dict
        return {str(k): v for k, v in self.__to_dataframe(value).dtypes.apply(lambda x: x.name.lower()).items()}
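
    # For example, pd.DataFrame({"a": [1], "b": [1.5]}) maps to
    # {"a": "int64", "b": "float64"}; for a list of frames, keys are prefixed
    # with the frame's position, e.g. "0/a", "1/a".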

    def __get_data(  # noqa: C901
        self,
        var_name: str,
        df: pd.DataFrame,
        payload: t.Dict[str, t.Any],
        data_format: _DataFormat,
        col_prefix: t.Optional[str] = "",
    ) -> t.Dict[str, t.Any]:
        columns = payload.get("columns", [])
        if col_prefix:
            columns = [c[len(col_prefix) :] if c.startswith(col_prefix) else c for c in columns]
        ret_payload = {"pagekey": payload.get("pagekey", "unknown page")}
        paged = not payload.get("alldata", False)
        is_copied = False
        orig_df = df
        # add index if not chart
        if paged:
            if _PandasDataAccessor.__INDEX_COL not in df.columns:
                df = df.copy()
                is_copied = True
                df[_PandasDataAccessor.__INDEX_COL] = df.index
            if columns and _PandasDataAccessor.__INDEX_COL not in columns:
                columns.append(_PandasDataAccessor.__INDEX_COL)
        fullrowcount = len(df)
        # filtering
        filters = payload.get("filters")
        if isinstance(filters, list) and len(filters) > 0:
            query = ""
            vars = []
            for fd in filters:
                col = fd.get("col")
                val = fd.get("value")
                action = fd.get("action")
                if isinstance(val, str):
                    if self.__is_date_column(t.cast(pd.DataFrame, df), col):
                        val = datetime.fromisoformat(val[:-1])
                    vars.append(val)
                    val = f"@vars[{len(vars) - 1}]" if isinstance(val, (str, datetime)) else val
                right = f".str.contains({val})" if action == "contains" else f" {action} {val}"
                if query:
                    query += " and "
                query += f"`{col}`{right}"
            try:
                df = df.query(query)
                is_copied = True
            except Exception as e:
                _warn(f"Dataframe filtering: invalid query '{query}' on {df.head()}", e)
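        # For instance, a filter payload such as
        #     [{"col": "age", "action": ">", "value": 21},
        #      {"col": "name", "action": "contains", "value": "Doe"}]
        # builds the query "`age` > 21 and `name`.str.contains(@vars[0])",
        # with string (and date) values passed through the local `vars` list.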
        dictret: t.Optional[t.Dict[str, t.Any]]
        if paged:
            aggregates = payload.get("aggregates")
            applies = payload.get("applies")
            if isinstance(aggregates, list) and len(aggregates) and isinstance(applies, dict):
                applies_with_fn = {
                    k: v if v in _PandasDataAccessor.__AGGREGATE_FUNCTIONS else self._gui._get_user_function(v)
                    for k, v in applies.items()
                }
                for col in columns:
                    if col not in applies_with_fn.keys():
                        applies_with_fn[col] = "first"
                try:
                    df = t.cast(pd.DataFrame, df).groupby(aggregates).agg(applies_with_fn)
                except Exception:
                    _warn(f"Cannot aggregate {var_name} with groupby {aggregates} and aggregates {applies}.")
            inf = payload.get("infinite")
            if inf is not None:
                ret_payload["infinite"] = inf
            # real number of rows is needed to calculate the number of pages
            rowcount = len(df)
            # here we'll deal with start and end values from payload if present
            if isinstance(payload["start"], int):
                start = int(payload["start"])
            else:
                try:
                    start = int(str(payload["start"]), base=10)
                except Exception:
                    _warn(f'start should be an int value {payload["start"]}.')
                    start = 0
            if isinstance(payload["end"], int):
                end = int(payload["end"])
            else:
                try:
                    end = int(str(payload["end"]), base=10)
                except Exception:
                    end = -1
            if start < 0 or start >= rowcount:
                start = 0
            if end < 0 or end >= rowcount:
                end = rowcount - 1
            if payload.get("reverse", False):
                diff = end - start
                end = rowcount - 1 - start
                if end < 0:
                    end = rowcount - 1
                start = end - diff
                if start < 0:
                    start = 0
            # deal with sort
            order_by = payload.get("orderby")
            if isinstance(order_by, str) and len(order_by):
                try:
                    if df.columns.dtype.name == "int64":
                        order_by = int(order_by)
                    new_indexes = t.cast(pd.DataFrame, df)[order_by].values.argsort(axis=0)
                    if payload.get("sort") == "desc":
                        # reverse order
                        new_indexes = new_indexes[::-1]
                    new_indexes = new_indexes[slice(start, end + 1)]
                except Exception:
                    _warn(f"Cannot sort {var_name} on columns {order_by}.")
                    new_indexes = slice(start, end + 1)  # type: ignore
            else:
                new_indexes = slice(start, end + 1)  # type: ignore
            df = self.__build_transferred_cols(
                columns,
                t.cast(pd.DataFrame, df),
                styles=payload.get("styles"),
                tooltips=payload.get("tooltips"),
                is_copied=is_copied,
                new_indexes=new_indexes,
                handle_nan=payload.get("handlenan", False),
            )
            dictret = self.__format_data(
                df,
                data_format,
                "records",
                start,
                rowcount,
                handle_nan=payload.get("handlenan", False),
                fullrowcount=fullrowcount,
            )
            compare = payload.get("compare")
            if isinstance(compare, str):
                comp_df = _compare_function(
                    self._gui, compare, var_name, t.cast(pd.DataFrame, orig_df), payload.get("compare_datas", "")
                )
                if isinstance(comp_df, pd.DataFrame) and not comp_df.empty:
                    try:
                        if isinstance(comp_df.columns[0], tuple):
                            cols: t.List[t.Hashable] = [c for c in comp_df.columns if c[1] == "other"]
                            comp_df = t.cast(pd.DataFrame, comp_df.get(cols))
                            comp_df.columns = t.cast(pd.Index, [t.cast(tuple, c)[0] for c in cols])
                        comp_df.dropna(axis=1, how="all", inplace=True)
                        comp_df = self.__build_transferred_cols(columns, comp_df, new_indexes=new_indexes)
                        dictret["comp"] = self.__format_data(comp_df, data_format, "records").get("data")
                    except Exception as e:
                        _warn("Pandas accessor compare raised an exception", e)
        else:
            ret_payload["alldata"] = True
            decimator_payload: t.Dict[str, t.Any] = payload.get("decimatorPayload", {})
            decimators = decimator_payload.get("decimators", [])
            nb_rows_max = decimator_payload.get("width")
            for decimator_pl in decimators:
                decimator = decimator_pl.get("decimator")
                decimator_instance = (
                    self._gui._get_user_instance(decimator, PropertyType.decimator.value)
                    if decimator is not None
                    else None
                )
                if isinstance(decimator_instance, PropertyType.decimator.value):
                    x_column, y_column, z_column = (
                        decimator_pl.get("xAxis", ""),
                        decimator_pl.get("yAxis", ""),
                        decimator_pl.get("zAxis", ""),
                    )
                    chart_mode = decimator_pl.get("chartMode", "")
                    if decimator_instance._zoom and "relayoutData" in decimator_payload and not z_column:
                        relayoutData = decimator_payload.get("relayoutData", {})
                        x0 = relayoutData.get("xaxis.range[0]")
                        x1 = relayoutData.get("xaxis.range[1]")
                        y0 = relayoutData.get("yaxis.range[0]")
                        y1 = relayoutData.get("yaxis.range[1]")
                        df, is_copied = _df_relayout(
                            t.cast(pd.DataFrame, df), x_column, y_column, chart_mode, x0, x1, y0, y1, is_copied
                        )
                    if nb_rows_max and decimator_instance._is_applicable(df, nb_rows_max, chart_mode):
                        try:
                            df, is_copied = _df_data_filter(
                                t.cast(pd.DataFrame, df),
                                x_column,
                                y_column,
                                z_column,
                                decimator=decimator_instance,
                                payload=decimator_payload,
                                is_copied=is_copied,
                            )
                            self._gui._call_on_change(f"{var_name}.{decimator}.nb_rows", len(df))
                        except Exception as e:
                            _warn(f"Limit rows error with {decimator} for Dataframe", e)
            df = self.__build_transferred_cols(columns, t.cast(pd.DataFrame, df), is_copied=is_copied)
            if data_format is _DataFormat.CSV:
                ret_payload["df"] = df
                dictret = None
            else:
                dictret = self.__format_data(df, data_format, "list", data_extraction=True)
        ret_payload["value"] = dictret
        return ret_payload

    def get_data(
        self, var_name: str, value: t.Any, payload: t.Dict[str, t.Any], data_format: _DataFormat
    ) -> t.Dict[str, t.Any]:
        if isinstance(value, list):
            # a list of data: only charts request all data this way
            if payload.get("alldata", False):
                ret_payload = {
                    "alldata": True,
                    "value": {"multi": True},
                    "pagekey": payload.get("pagekey", "unknown page"),
                }
                data = []
                for i, v in enumerate(value):
                    ret = (
                        self.__get_data(var_name, self.__to_dataframe(v), payload, data_format, f"{i}/")
                        if isinstance(v, _PandasDataAccessor.__types)
                        else {}
                    )
                    ret_val = ret.get("value", {})
                    data.append(ret_val.pop("data", None))
                    ret_payload.get("value", {}).update(ret_val)
                ret_payload["value"]["data"] = data
                return ret_payload
            else:
                value = value[0]
        return self.__get_data(var_name, self.__to_dataframe(value), payload, data_format)
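
    # An illustrative paged-table payload (keys as consumed by __get_data):
    #     {"columns": ["name", "age"], "start": 0, "end": 99,
    #      "orderby": "age", "sort": "desc", "pagekey": "<page key>"}
    # Chart requests set "alldata": True instead and may carry a
    # "decimatorPayload" entry.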

    def on_edit(self, value: t.Any, payload: t.Dict[str, t.Any]):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Cannot edit {type(value)}.")
        df.at[payload["index"], payload["col"]] = payload["value"]
        return self._from_pandas(df, type(value))

    def on_delete(self, value: t.Any, payload: t.Dict[str, t.Any]):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Cannot delete a row from {type(value)}.")
        return self._from_pandas(df.drop(payload["index"]), type(value))

    def on_add(self, value: t.Any, payload: t.Dict[str, t.Any], new_row: t.Optional[t.List[t.Any]] = None):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Cannot add a row to {type(value)}.")
        # Save the insertion index
        index = payload["index"]
        # Create the new row (column value types must match the original DataFrame's)
        col_types = self.get_col_types("", df)
        if col_types:
            new_row = [0 if is_numeric_dtype(df[c]) else "" for c in list(col_types)] if new_row is None else new_row
            if index > 0:
                # Column names and value types must match the original DataFrame
                new_df = pd.DataFrame([new_row], columns=list(col_types))
                # Split the DataFrame
                rows_before = df.iloc[:index]
                rows_after = df.iloc[index:]
                return self._from_pandas(pd.concat([rows_before, new_df, rows_after], ignore_index=True), type(value))
            else:
                df = df.copy()
                # Insert as the new first row
                df.loc[-1] = new_row  # Insert the new row
                df.index = df.index + 1  # Shift index
                return self._from_pandas(df.sort_index(), type(value))
        return value

    def to_csv(self, var_name: str, value: t.Any):
        df = self.to_pandas(value)
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Cannot export {type(value)} to csv.")
        dict_ret = self.__get_data(var_name, df, {"alldata": True}, _DataFormat.CSV)
        if isinstance(dict_ret, dict):
            dfr = dict_ret.get("df")
            if isinstance(dfr, pd.DataFrame):
                fd, temp_path = mkstemp(".csv", var_name, text=True)
                with os.fdopen(fd, "wt", newline="") as csv_file:
                    dfr.to_csv(csv_file, index=False)
                return temp_path
        return None