scratch.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. """Utilities for launching and interacting with scratch EC2 instances."""
  10. import asyncio
  11. import csv
  12. import datetime
  13. import os
  14. import shlex
  15. import subprocess
  16. import sys
  17. from subprocess import CalledProcessError
  18. from typing import NamedTuple, cast
  19. import boto3
  20. from botocore.exceptions import ClientError
  21. from mypy_boto3_ec2.literals import InstanceTypeType
  22. from mypy_boto3_ec2.service_resource import Instance
  23. from mypy_boto3_ec2.type_defs import (
  24. FilterTypeDef,
  25. InstanceNetworkInterfaceSpecificationTypeDef,
  26. InstanceTypeDef,
  27. RunInstancesRequestServiceResourceCreateInstancesTypeDef,
  28. )
  29. from prettytable import PrettyTable
  30. from pydantic import BaseModel
  31. from materialize import MZ_ROOT, git, spawn, ui, util
  32. # Sane defaults for internal Materialize use in the scratch account
  33. DEFAULT_SECURITY_GROUP_NAME = "scratch-security-group"
  34. DEFAULT_INSTANCE_PROFILE_NAME = "admin-instance"
  35. SSH_COMMAND = ["mssh", "-o", "StrictHostKeyChecking=off"]
  36. SFTP_COMMAND = ["msftp", "-o", "StrictHostKeyChecking=off"]
  37. say = ui.speaker("scratch> ")
  38. def tags(i: Instance) -> dict[str, str]:
  39. if not i.tags:
  40. return {}
  41. return {t["Key"]: t["Value"] for t in i.tags}
  42. def instance_typedef_tags(i: InstanceTypeDef) -> dict[str, str]:
  43. return {t["Key"]: t["Value"] for t in i.get("Tags", [])}
  44. def name(tags: dict[str, str]) -> str | None:
  45. return tags.get("Name")
  46. def launched_by(tags: dict[str, str]) -> str | None:
  47. return tags.get("LaunchedBy")
  48. def ami_user(tags: dict[str, str]) -> str | None:
  49. return tags.get("ami-user", "ubuntu")
  50. def delete_after(tags: dict[str, str]) -> datetime.datetime | None:
  51. unix = tags.get("scratch-delete-after")
  52. if not unix:
  53. return None
  54. unix = int(float(unix))
  55. return datetime.datetime.fromtimestamp(unix)
  56. def instance_host(instance: Instance, user: str | None = None) -> str:
  57. if user is None:
  58. user = ami_user(tags(instance))
  59. return f"{user}@{instance.id}"
  60. def print_instances(ists: list[Instance], format: str) -> None:
  61. field_names = [
  62. "Name",
  63. "Instance ID",
  64. "Public IP Address",
  65. "Private IP Address",
  66. "Launched By",
  67. "Delete After",
  68. "State",
  69. ]
  70. rows = [
  71. [
  72. name(tags),
  73. i.instance_id,
  74. i.public_ip_address,
  75. i.private_ip_address,
  76. launched_by(tags),
  77. delete_after(tags),
  78. i.state["Name"],
  79. ]
  80. for (i, tags) in [(i, tags(i)) for i in ists]
  81. ]
  82. if format == "table":
  83. pt = PrettyTable()
  84. pt.field_names = field_names
  85. pt.add_rows(rows)
  86. print(pt)
  87. elif format == "csv":
  88. w = csv.writer(sys.stdout)
  89. w.writerow(field_names)
  90. w.writerows(rows)
  91. else:
  92. raise RuntimeError("Unknown format passed to print_instances")
  93. def launch(
  94. *,
  95. key_name: str | None,
  96. instance_type: str,
  97. ami: str,
  98. ami_user: str,
  99. tags: dict[str, str],
  100. display_name: str | None = None,
  101. size_gb: int,
  102. security_group_name: str,
  103. instance_profile: str | None,
  104. nonce: str,
  105. delete_after: datetime.datetime,
  106. ) -> Instance:
  107. """Launch and configure an ec2 instance with the given properties."""
  108. if display_name:
  109. tags["Name"] = display_name
  110. tags["scratch-delete-after"] = str(delete_after.timestamp())
  111. tags["nonce"] = nonce
  112. tags["git_ref"] = git.describe()
  113. tags["ami-user"] = ami_user
  114. ec2 = boto3.client("ec2")
  115. groups = ec2.describe_security_groups()
  116. security_group_id = None
  117. for group in groups["SecurityGroups"]:
  118. if group["GroupName"] == security_group_name:
  119. security_group_id = group["GroupId"]
  120. break
  121. if security_group_id is None:
  122. vpcs = ec2.describe_vpcs()
  123. vpc_id = None
  124. for vpc in vpcs["Vpcs"]:
  125. if vpc["IsDefault"] == True:
  126. vpc_id = vpc["VpcId"]
  127. break
  128. if vpc_id is None:
  129. default_vpc = ec2.create_default_vpc()
  130. vpc_id = default_vpc["Vpc"]["VpcId"]
  131. securitygroup = ec2.create_security_group(
  132. GroupName=security_group_name,
  133. Description="Allows all.",
  134. VpcId=vpc_id,
  135. )
  136. security_group_id = securitygroup["GroupId"]
  137. ec2.authorize_security_group_ingress(
  138. GroupId=security_group_id,
  139. CidrIp="0.0.0.0/0",
  140. IpProtocol="tcp",
  141. FromPort=22,
  142. ToPort=22,
  143. )
  144. network_interface: InstanceNetworkInterfaceSpecificationTypeDef = {
  145. "AssociatePublicIpAddress": True,
  146. "DeviceIndex": 0,
  147. "Groups": [security_group_id],
  148. }
  149. say(f"launching instance {display_name or '(unnamed)'}")
  150. with open(MZ_ROOT / "misc" / "scratch" / "provision.bash") as f:
  151. provisioning_script = f.read()
  152. kwargs: RunInstancesRequestServiceResourceCreateInstancesTypeDef = {
  153. "MinCount": 1,
  154. "MaxCount": 1,
  155. "ImageId": ami,
  156. "InstanceType": cast(InstanceTypeType, instance_type),
  157. "UserData": provisioning_script,
  158. "TagSpecifications": [
  159. {
  160. "ResourceType": "instance",
  161. "Tags": [{"Key": k, "Value": v} for (k, v) in tags.items()],
  162. }
  163. ],
  164. "NetworkInterfaces": [network_interface],
  165. "BlockDeviceMappings": [
  166. {
  167. "DeviceName": "/dev/sda1",
  168. "Ebs": {
  169. "VolumeSize": size_gb,
  170. "VolumeType": "gp3",
  171. },
  172. }
  173. ],
  174. "MetadataOptions": {
  175. # Allow Docker containers to access IMDSv2.
  176. "HttpPutResponseHopLimit": 2,
  177. },
  178. }
  179. if key_name:
  180. kwargs["KeyName"] = key_name
  181. if instance_profile:
  182. kwargs["IamInstanceProfile"] = {"Name": instance_profile}
  183. i = boto3.resource("ec2").create_instances(**kwargs)[0]
  184. return i
  185. class CommandResult(NamedTuple):
  186. status: str
  187. stdout: str
  188. stderr: str
  189. async def setup(
  190. i: Instance,
  191. git_rev: str,
  192. ) -> None:
  193. def is_ready(i: Instance) -> bool:
  194. return bool(
  195. i.public_ip_address and i.state and i.state.get("Name") == "running"
  196. )
  197. done = False
  198. async for remaining in ui.async_timeout_loop(60, 5):
  199. say(f"Waiting for instance to become ready: {remaining:0.0f}s remaining")
  200. try:
  201. i.reload()
  202. if is_ready(i):
  203. done = True
  204. break
  205. except ClientError:
  206. pass
  207. if not done:
  208. raise RuntimeError(
  209. f"Instance {i} did not become ready in a reasonable amount of time"
  210. )
  211. done = False
  212. async for remaining in ui.async_timeout_loop(300, 5):
  213. say(f"Checking whether setup has completed: {remaining:0.0f}s remaining")
  214. try:
  215. mssh(i, "[[ -f /opt/provision/done ]]")
  216. done = True
  217. break
  218. except CalledProcessError:
  219. continue
  220. if not done:
  221. raise RuntimeError(
  222. "Instance did not finish setup in a reasonable amount of time"
  223. )
  224. mkrepo(i, git_rev)
  225. def mkrepo(i: Instance, rev: str, init: bool = True, force: bool = False) -> None:
  226. if init:
  227. mssh(
  228. i,
  229. "git clone https://github.com/MaterializeInc/materialize.git --recurse-submodules",
  230. )
  231. rev = git.rev_parse(rev)
  232. cmd: list[str] = [
  233. "git",
  234. "push",
  235. "--no-verify",
  236. f"{instance_host(i)}:materialize/.git",
  237. # Explicit refspec is required if the host repository is in detached
  238. # HEAD mode.
  239. f"{rev}:refs/heads/scratch",
  240. "--no-recurse-submodules",
  241. ]
  242. if force:
  243. cmd.append("--force")
  244. spawn.runv(
  245. cmd,
  246. cwd=MZ_ROOT,
  247. env=dict(os.environ, GIT_SSH_COMMAND=" ".join(SSH_COMMAND)),
  248. )
  249. mssh(
  250. i,
  251. f"cd materialize && git config core.bare false && git checkout {rev} && git submodule sync --recursive && git submodule update --recursive",
  252. )
  253. class MachineDesc(BaseModel):
  254. name: str
  255. launch_script: str | None
  256. instance_type: str
  257. ami: str
  258. tags: dict[str, str] = {}
  259. size_gb: int
  260. checkout: bool = True
  261. ami_user: str = "ubuntu"
  262. def launch_cluster(
  263. descs: list[MachineDesc],
  264. *,
  265. nonce: str | None = None,
  266. key_name: str | None = None,
  267. security_group_name: str = DEFAULT_SECURITY_GROUP_NAME,
  268. instance_profile: str | None = DEFAULT_INSTANCE_PROFILE_NAME,
  269. extra_tags: dict[str, str] = {},
  270. delete_after: datetime.datetime,
  271. git_rev: str = "HEAD",
  272. extra_env: dict[str, str] = {},
  273. ) -> list[Instance]:
  274. """Launch a cluster of instances with a given nonce"""
  275. if not nonce:
  276. nonce = util.nonce(8)
  277. instances = [
  278. launch(
  279. key_name=key_name,
  280. instance_type=d.instance_type,
  281. ami=d.ami,
  282. ami_user=d.ami_user,
  283. tags={**d.tags, **extra_tags},
  284. display_name=f"{nonce}-{d.name}",
  285. size_gb=d.size_gb,
  286. security_group_name=security_group_name,
  287. instance_profile=instance_profile,
  288. nonce=nonce,
  289. delete_after=delete_after,
  290. )
  291. for d in descs
  292. ]
  293. loop = asyncio.get_event_loop()
  294. loop.run_until_complete(
  295. asyncio.gather(
  296. *(
  297. setup(i, git_rev if d.checkout else "HEAD")
  298. for (i, d) in zip(instances, descs)
  299. )
  300. )
  301. )
  302. hosts_str = "".join(
  303. f"{i.private_ip_address}\t{d.name}\n" for (i, d) in zip(instances, descs)
  304. )
  305. for i in instances:
  306. mssh(i, "sudo tee -a /etc/hosts", input=hosts_str.encode())
  307. env = " ".join(f"{k}={shlex.quote(v)}" for k, v in extra_env.items())
  308. for i, d in zip(instances, descs):
  309. if d.launch_script:
  310. mssh(
  311. i,
  312. f"(cd materialize && {env} nohup bash -c {shlex.quote(d.launch_script)}) &> mzscratch.log &",
  313. )
  314. return instances
  315. def whoami() -> str:
  316. return boto3.client("sts").get_caller_identity()["UserId"].split(":")[1]
  317. def get_instance(name: str) -> Instance:
  318. """
  319. Get an instance by instance id. The special name 'mine' resolves to a
  320. unique running owned instance, if there is one; otherwise the name is
  321. assumed to be an instance id.
  322. :param name: The instance id or the special case 'mine'.
  323. :return: The instance to which the name refers.
  324. """
  325. if name == "mine":
  326. filters: list[FilterTypeDef] = [
  327. {"Name": "tag:LaunchedBy", "Values": [whoami()]},
  328. {"Name": "instance-state-name", "Values": ["pending", "running"]},
  329. ]
  330. instances = [i for i in boto3.resource("ec2").instances.filter(Filters=filters)]
  331. if not instances:
  332. raise RuntimeError("can't understand 'mine': no owned instance?")
  333. if len(instances) > 1:
  334. raise RuntimeError(
  335. f"can't understand 'mine': too many owned instances ({', '.join(i.id for i in instances)})"
  336. )
  337. instance = instances[0]
  338. say(f"understanding 'mine' as unique owned instance {instance.id}")
  339. return instance
  340. return boto3.resource("ec2").Instance(name)
  341. def get_instances_by_tag(k: str, v: str) -> list[InstanceTypeDef]:
  342. return [
  343. i
  344. for r in boto3.client("ec2").describe_instances()["Reservations"]
  345. for i in r["Instances"]
  346. if instance_typedef_tags(i).get(k) == v
  347. ]
  348. def get_old_instances() -> list[InstanceTypeDef]:
  349. def exists(i: InstanceTypeDef) -> bool:
  350. return i["State"]["Name"] != "terminated"
  351. def is_old(i: InstanceTypeDef) -> bool:
  352. delete_after = instance_typedef_tags(i).get("scratch-delete-after")
  353. if delete_after is None:
  354. return False
  355. delete_after = float(delete_after)
  356. return datetime.datetime.utcnow().timestamp() > delete_after
  357. return [
  358. i
  359. for r in boto3.client("ec2").describe_instances()["Reservations"]
  360. for i in r["Instances"]
  361. if exists(i) and is_old(i)
  362. ]
  363. def mssh(
  364. instance: Instance,
  365. command: str,
  366. *,
  367. extra_ssh_args: list[str] = [],
  368. input: bytes | None = None,
  369. ) -> None:
  370. """Runs a command over SSH via EC2 Instance Connect."""
  371. host = instance_host(instance)
  372. if command:
  373. print(f"{host}$ {command}", file=sys.stderr)
  374. # Quote to work around:
  375. # https://github.com/aws/aws-ec2-instance-connect-cli/pull/26
  376. command = shlex.quote(command)
  377. else:
  378. print(f"$ mssh {host}")
  379. subprocess.run(
  380. [
  381. *SSH_COMMAND,
  382. *extra_ssh_args,
  383. host,
  384. command,
  385. ],
  386. check=True,
  387. input=input,
  388. )
  389. def msftp(
  390. instance: Instance,
  391. ) -> None:
  392. """Connects over SFTP via EC2 Instance Connect."""
  393. host = instance_host(instance)
  394. spawn.runv([*SFTP_COMMAND, host])