git.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. """Git utilities."""
  10. import functools
  11. import subprocess
  12. import sys
  13. from pathlib import Path
  14. from typing import TypeVar
  15. from materialize import spawn
  16. from materialize.mz_version import MzVersion, TypedVersionBase
  17. from materialize.util import YesNoOnce
# Type variable for version classes; bound to TypedVersionBase so that the
# generic helpers below return the same concrete version type they are given.
VERSION_TYPE = TypeVar("VERSION_TYPE", bound=TypedVersionBase)

# Default repository URL used by the helpers below to look up the matching
# remote name.
MATERIALIZE_REMOTE_URL = "https://github.com/MaterializeInc/materialize"

# Remotes for which `fetch` has already fetched tags in this process.
# `fetch` adds the remote it was given (None when no remote was specified)
# and "*" when it ran with `all_remotes`.
fetched_tags_in_remotes: set[str | None] = set()
  21. def rev_count(rev: str) -> int:
  22. """Count the commits up to a revision.
  23. Args:
  24. rev: A Git revision in any format know to the Git CLI.
  25. Returns:
  26. count: The number of commits in the Git repository starting from the
  27. initial commit and ending with the specified commit, inclusive.
  28. """
  29. return int(spawn.capture(["git", "rev-list", "--count", rev, "--"]).strip())
  30. def rev_parse(rev: str, *, abbrev: bool = False) -> str:
  31. """Compute the hash for a revision.
  32. Args:
  33. rev: A Git revision in any format known to the Git CLI.
  34. abbrev: Return a branch or tag name instead of a git sha
  35. Returns:
  36. ref: A 40 character hex-encoded SHA-1 hash representing the ID of the
  37. named revision in Git's object database.
  38. With "abbrev=True" this will return an abbreviated ref, or throw an
  39. error if there is no abbrev.
  40. """
  41. a = ["--abbrev-ref"] if abbrev else []
  42. out = spawn.capture(["git", "rev-parse", *a, "--verify", rev]).strip()
  43. if not out:
  44. raise RuntimeError(f"No parsed rev for {rev}")
  45. return out
  46. @functools.cache
  47. def expand_globs(root: Path, *specs: Path | str) -> set[str]:
  48. """Find unignored files within the specified paths."""
  49. # The goal here is to find all files in the working tree that are not
  50. # ignored by .gitignore. Naively using `git ls-files` doesn't work, because
  51. # it reports files that have been deleted in the working tree if they are
  52. # still present in the index. Using `os.walkdir` doesn't work because there
  53. # is no good way to evaluate .gitignore rules from Python. So we use a
  54. # combination of `git diff` and `git ls-files`.
  55. # `git diff` against the empty tree surfaces all tracked files that have
  56. # not been deleted.
  57. empty_tree = (
  58. "4b825dc642cb6eb9a060e54bf8d69288fbee4904" # git hash-object -t tree /dev/null
  59. )
  60. diff_files = spawn.capture(
  61. ["git", "diff", "--name-only", "-z", "--relative", empty_tree, "--", *specs],
  62. cwd=root,
  63. )
  64. # `git ls-files --others --exclude-standard` surfaces any non-ignored,
  65. # untracked files, which are not included in the `git diff` output above.
  66. ls_files = spawn.capture(
  67. ["git", "ls-files", "--others", "--exclude-standard", "-z", "--", *specs],
  68. cwd=root,
  69. )
  70. return set(f for f in (diff_files + ls_files).split("\0") if f.strip() != "")
  71. def get_version_tags(
  72. *,
  73. version_type: type[VERSION_TYPE],
  74. newest_first: bool = True,
  75. fetch: bool = True,
  76. remote_url: str = MATERIALIZE_REMOTE_URL,
  77. ) -> list[VERSION_TYPE]:
  78. """List all the version-like tags in the repo
  79. Args:
  80. fetch: If false, don't automatically run `git fetch --tags`.
  81. prefix: A prefix to strip from each tag before attempting to parse the
  82. tag as a version.
  83. """
  84. if fetch:
  85. _fetch(
  86. remote=get_remote(remote_url),
  87. include_tags=YesNoOnce.ONCE,
  88. force=True,
  89. only_tags=True,
  90. )
  91. tags = []
  92. for t in spawn.capture(["git", "tag"]).splitlines():
  93. if not t.startswith(version_type.get_prefix()):
  94. continue
  95. try:
  96. tags.append(version_type.parse(t))
  97. except ValueError as e:
  98. print(f"WARN: {e}", file=sys.stderr)
  99. return sorted(tags, reverse=newest_first)
  100. def get_latest_version(
  101. version_type: type[VERSION_TYPE],
  102. excluded_versions: set[VERSION_TYPE] | None = None,
  103. current_version: VERSION_TYPE | None = None,
  104. ) -> VERSION_TYPE:
  105. all_version_tags: list[VERSION_TYPE] = get_version_tags(
  106. version_type=version_type, fetch=True
  107. )
  108. if excluded_versions is not None:
  109. all_version_tags = [
  110. v
  111. for v in all_version_tags
  112. if v not in excluded_versions
  113. and (not current_version or v < current_version)
  114. ]
  115. return max(all_version_tags)
  116. def get_tags_of_current_commit(include_tags: YesNoOnce = YesNoOnce.ONCE) -> list[str]:
  117. if include_tags:
  118. fetch(include_tags=include_tags, only_tags=True)
  119. result = spawn.capture(["git", "tag", "--points-at", "HEAD"])
  120. if len(result) == 0:
  121. return []
  122. return result.splitlines()
  123. def is_ancestor(earlier: str, later: str) -> bool:
  124. """True if earlier is in an ancestor of later"""
  125. try:
  126. spawn.capture(["git", "merge-base", "--is-ancestor", earlier, later])
  127. except subprocess.CalledProcessError:
  128. return False
  129. return True
  130. def is_dirty() -> bool:
  131. """Check if the working directory has modifications to tracked files"""
  132. proc = subprocess.run("git diff --no-ext-diff --quiet --exit-code".split())
  133. idx = subprocess.run("git diff --cached --no-ext-diff --quiet --exit-code".split())
  134. return proc.returncode != 0 or idx.returncode != 0
  135. def first_remote_matching(pattern: str) -> str | None:
  136. """Get the name of the remote that matches the pattern"""
  137. remotes = spawn.capture(["git", "remote", "-v"])
  138. for remote in remotes.splitlines():
  139. if pattern in remote:
  140. return remote.split()[0]
  141. return None
  142. def describe() -> str:
  143. """Describe the relationship between the current commit and the most recent tag"""
  144. return spawn.capture(["git", "describe"]).strip()
  145. def fetch(
  146. remote: str | None = None,
  147. all_remotes: bool = False,
  148. include_tags: YesNoOnce = YesNoOnce.NO,
  149. force: bool = False,
  150. branch: str | None = None,
  151. only_tags: bool = False,
  152. include_submodules: bool = False,
  153. ) -> str:
  154. """Fetch from remotes"""
  155. if remote is not None and all_remotes:
  156. raise RuntimeError("all_remotes must be false when a remote is specified")
  157. if branch is not None and remote is None:
  158. raise RuntimeError("remote must be specified when a branch is specified")
  159. if branch is not None and only_tags:
  160. raise RuntimeError("branch must not be specified if only_tags is set")
  161. command = ["git", "fetch"]
  162. if remote:
  163. command.append(remote)
  164. if branch:
  165. command.append(branch)
  166. if all_remotes:
  167. command.append("--all")
  168. # explicitly specify both cases to be independent of the git config
  169. if include_submodules:
  170. command.append("--recurse-submodules")
  171. else:
  172. command.append("--no-recurse-submodules")
  173. fetch_tags = (
  174. include_tags == YesNoOnce.YES
  175. # fetch tags again if used with force (tags might have changed)
  176. or (include_tags == YesNoOnce.ONCE and force)
  177. or (
  178. include_tags == YesNoOnce.ONCE
  179. and remote not in fetched_tags_in_remotes
  180. and "*" not in fetched_tags_in_remotes
  181. )
  182. )
  183. if fetch_tags:
  184. command.append("--tags")
  185. if force:
  186. command.append("--force")
  187. if not fetch_tags and only_tags:
  188. return ""
  189. output = spawn.capture(command).strip()
  190. if fetch_tags:
  191. fetched_tags_in_remotes.add(remote)
  192. if all_remotes:
  193. fetched_tags_in_remotes.add("*")
  194. return output
# Alias for `fetch`, for use inside functions whose own `fetch` parameter
# shadows the function name (get_version_tags, contains_commit).
_fetch = fetch
  196. def try_get_remote_name_by_url(url: str) -> str | None:
  197. result = spawn.capture(["git", "remote", "--verbose"])
  198. for line in result.splitlines():
  199. remote, desc = line.split("\t")
  200. if desc.lower() in (f"{url} (fetch)".lower(), f"{url}.git (fetch)".lower()):
  201. return remote
  202. return None
  203. def get_remote(
  204. url: str = MATERIALIZE_REMOTE_URL,
  205. default_remote_name: str = "origin",
  206. ) -> str:
  207. # Alternative syntax
  208. remote = try_get_remote_name_by_url(url) or try_get_remote_name_by_url(
  209. url.replace("https://github.com/", "git@github.com:")
  210. )
  211. if not remote:
  212. remote = default_remote_name
  213. print(f"Remote for URL {url} not found, using {remote}")
  214. return remote
  215. def get_common_ancestor_commit(remote: str, branch: str, fetch_branch: bool) -> str:
  216. if fetch_branch:
  217. fetch(remote=remote, branch=branch)
  218. command = ["git", "merge-base", "HEAD", f"{remote}/{branch}"]
  219. return spawn.capture(command).strip()
  220. def is_on_release_version() -> bool:
  221. git_tags = get_tags_of_current_commit()
  222. return any(MzVersion.is_valid_version_string(git_tag) for git_tag in git_tags)
  223. def contains_commit(
  224. commit_sha: str,
  225. target: str = "HEAD",
  226. fetch: bool = False,
  227. remote_url: str = MATERIALIZE_REMOTE_URL,
  228. ) -> bool:
  229. if fetch:
  230. remote = get_remote(remote_url)
  231. _fetch(remote=remote)
  232. target = f"{remote}/{target}"
  233. command = ["git", "merge-base", "--is-ancestor", commit_sha, target]
  234. return_code = spawn.run_and_get_return_code(command)
  235. return return_code == 0
  236. def get_tagged_release_version(version_type: type[VERSION_TYPE]) -> VERSION_TYPE | None:
  237. """
  238. This returns the release version if exactly this commit is tagged.
  239. If multiple release versions are present, the highest one will be returned.
  240. None will be returned if the commit is not tagged.
  241. """
  242. git_tags = get_tags_of_current_commit()
  243. versions: list[VERSION_TYPE] = []
  244. for git_tag in git_tags:
  245. if version_type.is_valid_version_string(git_tag):
  246. versions.append(version_type.parse(git_tag))
  247. if len(versions) == 0:
  248. return None
  249. if len(versions) > 1:
  250. print(
  251. "Warning! Commit is tagged with multiple release versions! Returning the highest."
  252. )
  253. return max(versions)
  254. def get_commit_message(commit_sha: str) -> str | None:
  255. try:
  256. command = ["git", "log", "-1", "--pretty=format:%s", commit_sha]
  257. return spawn.capture(command, stderr=subprocess.DEVNULL).strip()
  258. except subprocess.CalledProcessError:
  259. # Sometimes mz_version() will report a Git SHA that is not available
  260. # in the current repository
  261. return None
  262. def get_branch_name() -> str:
  263. """This may not work on Buildkite; consider using the same function from build_context."""
  264. command = ["git", "branch", "--show-current"]
  265. return spawn.capture(command).strip()
  266. # Work tree mutation
  267. def create_branch(name: str) -> None:
  268. spawn.runv(["git", "checkout", "-b", name])
  269. def checkout(rev: str, path: str | None = None) -> None:
  270. """Git checkout the rev"""
  271. cmd = ["git", "checkout", rev]
  272. if path:
  273. cmd.extend(["--", path])
  274. spawn.runv(cmd)
  275. def add_file(file: str) -> None:
  276. """Git add a file"""
  277. spawn.runv(["git", "add", file])
  278. def commit_all_changed(message: str) -> None:
  279. """Commit all changed files with the given message"""
  280. spawn.runv(["git", "commit", "-a", "-m", message])
  281. def tag_annotated(tag: str) -> None:
  282. """Create an annotated tag on HEAD"""
  283. spawn.runv(["git", "tag", "-a", "-m", tag, tag])