imports.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """Import operations."""
  2. from __future__ import annotations
  3. import dataclasses
  4. from collections import defaultdict
  5. from collections.abc import Mapping, Sequence
  6. def merge_imports(
  7. *imports: ImportDict | ParsedImportDict | ParsedImportTuple,
  8. ) -> ParsedImportDict:
  9. """Merge multiple import dicts together.
  10. Args:
  11. *imports: The list of import dicts to merge.
  12. Returns:
  13. The merged import dicts.
  14. """
  15. all_imports: defaultdict[str, list[ImportVar]] = defaultdict(list)
  16. for import_dict in imports:
  17. for lib, fields in (
  18. import_dict if isinstance(import_dict, tuple) else import_dict.items()
  19. ):
  20. # If the lib is an absolute path, we need to prefix it with a $
  21. lib = (
  22. "$" + lib
  23. if lib.startswith(("/utils/", "/components/", "/styles/", "/public/"))
  24. else lib
  25. )
  26. if isinstance(fields, (list, tuple, set)):
  27. all_imports[lib].extend(
  28. ImportVar(field) if isinstance(field, str) else field
  29. for field in fields
  30. )
  31. else:
  32. all_imports[lib].append(
  33. ImportVar(fields) if isinstance(fields, str) else fields
  34. )
  35. return all_imports
  36. def parse_imports(
  37. imports: ImmutableImportDict | ImmutableParsedImportDict,
  38. ) -> ParsedImportDict:
  39. """Parse the import dict into a standard format.
  40. Args:
  41. imports: The import dict to parse.
  42. Returns:
  43. The parsed import dict.
  44. """
  45. def _make_list(
  46. value: ImmutableImportTypes,
  47. ) -> list[str | ImportVar] | list[ImportVar]:
  48. if isinstance(value, (str, ImportVar)):
  49. return [value]
  50. return list(value)
  51. return {
  52. package: [
  53. ImportVar(tag=tag) if isinstance(tag, str) else tag
  54. for tag in _make_list(maybe_tags)
  55. ]
  56. for package, maybe_tags in imports.items()
  57. }
  58. def collapse_imports(
  59. imports: ParsedImportDict | ParsedImportTuple,
  60. ) -> ParsedImportDict:
  61. """Remove all duplicate ImportVar within an ImportDict.
  62. Args:
  63. imports: The import dict to collapse.
  64. Returns:
  65. The collapsed import dict.
  66. """
  67. return {
  68. lib: (
  69. list(set(import_vars))
  70. if isinstance(import_vars, list)
  71. else list(import_vars)
  72. )
  73. for lib, import_vars in (
  74. imports if isinstance(imports, tuple) else imports.items()
  75. )
  76. }
  77. @dataclasses.dataclass(frozen=True)
  78. class ImportVar:
  79. """An import var."""
  80. # The name of the import tag.
  81. tag: str | None
  82. # whether the import is default or named.
  83. is_default: bool | None = False
  84. # The tag alias.
  85. alias: str | None = None
  86. # Whether this import need to install the associated lib
  87. install: bool | None = True
  88. # whether this import should be rendered or not
  89. render: bool | None = True
  90. # The path of the package to import from.
  91. package_path: str = "/"
  92. # whether this import package should be added to transpilePackages in next.config.js
  93. # https://nextjs.org/docs/app/api-reference/next-config-js/transpilePackages
  94. transpile: bool | None = False
  95. @property
  96. def name(self) -> str:
  97. """The name of the import.
  98. Returns:
  99. The name(tag name with alias) of tag.
  100. """
  101. if self.alias:
  102. return (
  103. self.alias if self.is_default else " as ".join([self.tag, self.alias]) # pyright: ignore [reportCallIssue,reportArgumentType]
  104. )
  105. else:
  106. return self.tag or ""
# Type aliases describing the import collections handled by this module.

# Fields for one lib: a single tag string, a single ImportVar, or a list of them.
ImportTypes = str | ImportVar | list[str | ImportVar] | list[ImportVar]
# Same as ImportTypes, but any Sequence is accepted instead of only a list.
ImmutableImportTypes = str | ImportVar | Sequence[str | ImportVar]
# Maps a lib name to its (not yet normalized) fields.
ImportDict = dict[str, ImportTypes]
# Mapping counterpart of ImportDict; accepted as input by parse_imports.
ImmutableImportDict = Mapping[str, ImmutableImportTypes]
# Normalized form: every lib maps to a list of ImportVar.
ParsedImportDict = dict[str, list[ImportVar]]
# Mapping/Sequence counterpart of ParsedImportDict.
ImmutableParsedImportDict = Mapping[str, Sequence[ImportVar]]
# Tuple form of parsed imports: a tuple of (lib, tuple of ImportVar) pairs.
ParsedImportTuple = tuple[tuple[str, tuple[ImportVar, ...]], ...]