benchmark_reflex_size.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """Checks the size of a specific directory and uploads result."""
  2. import argparse
  3. import os
  4. import subprocess
  5. from datetime import datetime
  6. import psycopg2
  7. def get_directory_size(directory):
  8. """Get the size of a directory in bytes.
  9. Args:
  10. directory: The directory to check.
  11. Returns:
  12. The size of the dir in bytes.
  13. """
  14. total_size = 0
  15. for dirpath, _, filenames in os.walk(directory):
  16. for f in filenames:
  17. fp = os.path.join(dirpath, f)
  18. total_size += os.path.getsize(fp)
  19. return total_size
  20. def get_python_version(venv_path, os_name):
  21. """Get the python version of python in a virtual env.
  22. Args:
  23. venv_path: Path to virtual environment.
  24. os_name: Name of os.
  25. Returns:
  26. The python version.
  27. """
  28. python_executable = (
  29. os.path.join(venv_path, "bin", "python")
  30. if "windows" not in os_name
  31. else os.path.join(venv_path, "Scripts", "python.exe")
  32. )
  33. try:
  34. output = subprocess.check_output(
  35. [python_executable, "--version"], stderr=subprocess.STDOUT
  36. )
  37. python_version = output.decode("utf-8").strip().split()[1]
  38. return ".".join(python_version.split(".")[:-1])
  39. except subprocess.CalledProcessError:
  40. return None
  41. def get_package_size(venv_path, os_name):
  42. """Get the size of a specified package.
  43. Args:
  44. venv_path: The path to the venv.
  45. os_name: Name of os.
  46. Returns:
  47. The total size of the package in bytes.
  48. Raises:
  49. ValueError: when venv does not exist or python version is None.
  50. """
  51. python_version = get_python_version(venv_path, os_name)
  52. if python_version is None:
  53. raise ValueError("Error: Failed to determine Python version.")
  54. is_windows = "windows" in os_name
  55. full_path = (
  56. ["lib", f"python{python_version}", "site-packages"]
  57. if not is_windows
  58. else ["Lib", "site-packages"]
  59. )
  60. package_dir = os.path.join(venv_path, *full_path)
  61. if not os.path.exists(package_dir):
  62. raise ValueError(
  63. "Error: Virtual environment does not exist or is not activated."
  64. )
  65. total_size = get_directory_size(package_dir)
  66. return total_size
  67. def insert_benchmarking_data(
  68. db_connection_url: str,
  69. os_type_version: str,
  70. python_version: str,
  71. measurement_type: str,
  72. commit_sha: str,
  73. pr_title: str,
  74. branch_name: str,
  75. pr_id: str,
  76. path: str,
  77. ):
  78. """Insert the benchmarking data into the database.
  79. Args:
  80. db_connection_url: The URL to connect to the database.
  81. os_type_version: The OS type and version to insert.
  82. python_version: The Python version to insert.
  83. measurement_type: The type of metric to measure.
  84. commit_sha: The commit SHA to insert.
  85. pr_title: The PR title to insert.
  86. branch_name: The name of the branch.
  87. pr_id: The id of the PR.
  88. path: The path to the dir or file to check size.
  89. """
  90. if measurement_type == "reflex-package":
  91. size = get_package_size(path, os_type_version)
  92. else:
  93. size = get_directory_size(path)
  94. # Get the current timestamp
  95. current_timestamp = datetime.now()
  96. # Connect to the database and insert the data
  97. with psycopg2.connect(db_connection_url) as conn, conn.cursor() as cursor:
  98. insert_query = """
  99. INSERT INTO size_benchmarks (os, python_version, commit_sha, created_at, pr_title, branch_name, pr_id, measurement_type, size)
  100. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s);
  101. """
  102. cursor.execute(
  103. insert_query,
  104. (
  105. os_type_version,
  106. python_version,
  107. commit_sha,
  108. current_timestamp,
  109. pr_title,
  110. branch_name,
  111. pr_id,
  112. measurement_type,
  113. round(
  114. size / (1024 * 1024), 3
  115. ), # save size in mb and round to 3 places.
  116. ),
  117. )
  118. # Commit the transaction
  119. conn.commit()
  120. def main():
  121. """Runs the benchmarks and inserts the results."""
  122. parser = argparse.ArgumentParser(description="Run benchmarks and process results.")
  123. parser.add_argument(
  124. "--os", help="The OS type and version to insert into the database."
  125. )
  126. parser.add_argument(
  127. "--python-version", help="The Python version to insert into the database."
  128. )
  129. parser.add_argument(
  130. "--commit-sha", help="The commit SHA to insert into the database."
  131. )
  132. parser.add_argument(
  133. "--db-url",
  134. help="The URL to connect to the database.",
  135. required=True,
  136. )
  137. parser.add_argument(
  138. "--pr-title",
  139. help="The PR title to insert into the database.",
  140. )
  141. parser.add_argument(
  142. "--branch-name",
  143. help="The current branch",
  144. required=True,
  145. )
  146. parser.add_argument(
  147. "--pr-id",
  148. help="The pr id",
  149. required=True,
  150. )
  151. parser.add_argument(
  152. "--measurement-type",
  153. help="The type of metric to be checked.",
  154. required=True,
  155. )
  156. parser.add_argument(
  157. "--path",
  158. help="the current path to check size.",
  159. required=True,
  160. )
  161. args = parser.parse_args()
  162. # Get the PR title from env or the args. For the PR merge or push event, there is no PR title, leaving it empty.
  163. pr_title = args.pr_title or os.getenv("PR_TITLE", "")
  164. # Insert the data into the database
  165. insert_benchmarking_data(
  166. db_connection_url=args.db_url,
  167. os_type_version=args.os,
  168. python_version=args.python_version,
  169. measurement_type=args.measurement_type,
  170. commit_sha=args.commit_sha,
  171. pr_title=pr_title,
  172. branch_name=args.branch_name,
  173. pr_id=args.pr_id,
  174. path=args.path,
  175. )
  176. if __name__ == "__main__":
  177. main()