create.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. import argparse
  10. import datetime
  11. import json
  12. import sys
  13. from typing import Any
  14. from materialize.cli.scratch import check_required_vars
  15. from materialize.scratch import (
  16. DEFAULT_INSTANCE_PROFILE_NAME,
  17. DEFAULT_SECURITY_GROUP_NAME,
  18. MZ_ROOT,
  19. MachineDesc,
  20. launch_cluster,
  21. mssh,
  22. print_instances,
  23. whoami,
  24. )
  25. MAX_AGE_DAYS = 1.5
  26. def multi_json(s: str) -> list[dict[Any, Any]]:
  27. """Read zero or more JSON objects from a string,
  28. without requiring each of them to be on its own line.
  29. For example:
  30. {
  31. "name": "First Object"
  32. }{"name": "Second Object"}
  33. """
  34. decoder = json.JSONDecoder()
  35. idx = 0
  36. result = []
  37. while idx < len(s):
  38. if s[idx] in " \t\n\r":
  39. idx += 1
  40. else:
  41. (obj, idx) = decoder.raw_decode(s, idx)
  42. result.append(obj)
  43. return result
  44. def configure_parser(parser: argparse.ArgumentParser) -> None:
  45. parser.add_argument(
  46. "--key-name", type=str, required=False, help="Optional EC2 Key Pair name"
  47. )
  48. parser.add_argument(
  49. "--security-group-name",
  50. type=str,
  51. default=DEFAULT_SECURITY_GROUP_NAME,
  52. help="EC2 Security Group name. Defaults to Materialize scratch account.",
  53. )
  54. parser.add_argument(
  55. "--extra-tags",
  56. type=str,
  57. required=False,
  58. help='Additional EC2 tags for created instance. Format: {"key", "value"}',
  59. )
  60. parser.add_argument(
  61. "--instance-profile",
  62. type=str,
  63. default=DEFAULT_INSTANCE_PROFILE_NAME,
  64. help="EC2 instance profile / IAM role. Defaults to `%s`."
  65. % DEFAULT_INSTANCE_PROFILE_NAME,
  66. )
  67. parser.add_argument("--output-format", choices=["table", "csv"], default="table")
  68. parser.add_argument(
  69. "--git-rev",
  70. type=str,
  71. default="HEAD",
  72. help="Git revision of `materialize` codebase to push to scratch instance. Defaults to `HEAD`",
  73. )
  74. parser.add_argument(
  75. "--ssh",
  76. action="store_true",
  77. help=(
  78. "ssh into the machine after the launch script is run. "
  79. "Only works if a single instance was started"
  80. ),
  81. )
  82. parser.add_argument(
  83. "machine",
  84. nargs="?",
  85. const=None,
  86. help=(
  87. "Use a config from {machine}.json in `misc/scratch`. "
  88. "Hint: `dev-box` is a good starter!"
  89. ),
  90. )
  91. parser.add_argument(
  92. "--max-age-days",
  93. type=float,
  94. default=MAX_AGE_DAYS,
  95. help="Maximum age for scratch instance in days. Defaults to 1.5",
  96. )
  97. def run(args: argparse.Namespace) -> None:
  98. extra_tags = {}
  99. if args.extra_tags:
  100. extra_tags = json.loads(args.extra_tags)
  101. if not isinstance(extra_tags, dict) or not all(
  102. isinstance(k, str) and isinstance(v, str) for k, v in extra_tags.items()
  103. ):
  104. raise RuntimeError(
  105. "extra-tags must be a JSON dictionary of strings to strings"
  106. )
  107. check_required_vars()
  108. extra_tags["LaunchedBy"] = whoami()
  109. if args.machine:
  110. with open(MZ_ROOT / "misc" / "scratch" / f"{args.machine}.json") as f:
  111. print(f"Reading machine configs from {f.name}")
  112. descs = [MachineDesc.model_validate(obj) for obj in multi_json(f.read())]
  113. else:
  114. print("Reading machine configs from stdin...")
  115. descs = [
  116. MachineDesc.model_validate(obj) for obj in multi_json(sys.stdin.read())
  117. ]
  118. if args.ssh and len(descs) != 1:
  119. raise RuntimeError(f"Cannot use `--ssh` with {len(descs)} instances")
  120. if args.max_age_days <= 0:
  121. raise RuntimeError(f"max_age_days must be positive, got {args.max_age_days}")
  122. max_age = datetime.timedelta(days=args.max_age_days)
  123. instances = launch_cluster(
  124. descs,
  125. key_name=args.key_name,
  126. security_group_name=args.security_group_name,
  127. instance_profile=args.instance_profile,
  128. extra_tags=extra_tags,
  129. delete_after=datetime.datetime.utcnow() + max_age,
  130. git_rev=args.git_rev,
  131. extra_env={},
  132. )
  133. print("Launched instances:")
  134. print_instances(instances, args.output_format)
  135. if args.ssh:
  136. print(f"ssh-ing into: {instances[0].instance_id}")
  137. mssh(instances[0], "")