mz_client.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License in the LICENSE file at the
  6. # root of this repository, or online at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import asyncio
  16. import base64
  17. import decimal
  18. import json
  19. import logging
  20. from importlib.resources import files
  21. from textwrap import dedent
  22. from typing import Any
  23. from uuid import UUID
  24. import aiorwlock
  25. from mcp import Tool
  26. from mcp.types import ToolAnnotations
  27. from psycopg import sql
  28. from psycopg.rows import dict_row
  29. from psycopg_pool import AsyncConnectionPool
  30. logger = logging.getLogger("mz_mcp_server")
  31. logging.basicConfig(
  32. level=logging.INFO,
  33. format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
  34. )
  35. TOOL_QUERY = (files("mcp_materialize.sql") / "tools.sql").read_text()
  36. class MzTool:
  37. def __init__(
  38. self,
  39. name,
  40. database,
  41. schema,
  42. object_name,
  43. cluster,
  44. title,
  45. description,
  46. input_schema,
  47. output_schema,
  48. output_columns,
  49. ):
  50. self.name = name
  51. self.database = database
  52. self.schema = schema
  53. self.object_name = object_name
  54. self.cluster = cluster
  55. self.title = title
  56. self.description = description
  57. self.input_schema = input_schema
  58. self.output_schema = output_schema
  59. self.output_columns = output_columns
  60. def as_tool(self) -> Tool:
  61. return Tool(
  62. name=self.name,
  63. description=self.description,
  64. inputSchema=self.input_schema,
  65. outputSchema=self.output_schema,
  66. annotations=ToolAnnotations(title=self.title, readOnlyHint=True),
  67. )
  68. class MissingTool(Exception):
  69. def __init__(self, message):
  70. super().__init__(message)
  71. class MzClient:
  72. def __init__(self, pool: AsyncConnectionPool) -> None:
  73. self.pool = pool
  74. self.tools: dict[str, MzTool] = {}
  75. self._lock = aiorwlock.RWLock()
  76. self._bg_task: asyncio.Task | None = None
  77. async def __aenter__(self) -> "MzClient":
  78. await self._load_tools()
  79. self._bg_task = asyncio.create_task(self._subscribe())
  80. return self
  81. async def __aexit__(self, exc_type, exc, tb) -> None:
  82. if self._bg_task:
  83. self._bg_task.cancel()
  84. try:
  85. await self._bg_task
  86. except asyncio.CancelledError:
  87. pass
  88. async def _subscribe(self) -> None:
  89. """
  90. Watches materialize for new tools.
  91. We cannot subscribe to the `TOOL` query directly because it relies on
  92. non-materializable functions. Instead, we watch indexes on objects that
  93. have comments as a proxy and then execute the full query.
  94. """
  95. try:
  96. async with self.pool.connection() as conn:
  97. await conn.set_autocommit(True)
  98. async with conn.cursor(row_factory=dict_row) as cur:
  99. logger.info("Starting background tool subscription")
  100. await cur.execute("BEGIN")
  101. await cur.execute(
  102. dedent(
  103. """
  104. DECLARE c CURSOR FOR
  105. SUBSCRIBE (
  106. SELECT count(*) AS eligible_tools
  107. FROM mz_objects o
  108. JOIN mz_indexes i ON o.id = i.on_id
  109. JOIN mz_internal.mz_comments cts ON cts.id = o.id
  110. ) WITH (PROGRESS)
  111. """
  112. )
  113. )
  114. while True:
  115. await cur.execute("FETCH ALL c")
  116. reload = False
  117. async for row in cur:
  118. if not row["mz_progressed"]:
  119. reload = True
  120. elif reload:
  121. logger.info("Reloading catalog of available tools")
  122. await self._load_tools()
  123. reload = False
  124. except asyncio.CancelledError:
  125. logger.info("Stopping background tool subscription")
  126. return
  127. async def _load_tools(self) -> None:
  128. """
  129. Load the catalog of available tools into self.tools under lock.
  130. """
  131. new_tools: dict[str, MzTool] = {}
  132. async with self.pool.connection() as conn:
  133. await conn.set_autocommit(True)
  134. async with conn.cursor(row_factory=dict_row) as cur:
  135. await cur.execute(TOOL_QUERY)
  136. async for row in cur:
  137. tool = MzTool(
  138. name=row["name"],
  139. database=row["database"],
  140. schema=row["schema"],
  141. object_name=row["object_name"],
  142. cluster=row["cluster"],
  143. title=row["title"],
  144. description=row["description"],
  145. input_schema=row["input_schema"],
  146. output_schema=row["output_schema"],
  147. output_columns=row["output_columns"],
  148. )
  149. new_tools[tool.name] = tool
  150. # swap in the fresh catalog
  151. async with self._lock.writer_lock:
  152. self.tools = new_tools
  153. async def list_tools(self) -> list[Tool]:
  154. """
  155. Return the catalog of available tools.
  156. """
  157. async with self._lock.reader_lock:
  158. return [tool.as_tool() for tool in self.tools.values()]
  159. async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
  160. pool = self.pool
  161. async with self._lock.reader_lock:
  162. tool = self.tools.get(name)
  163. if not tool:
  164. raise MissingTool(f"Tool not found: {name}")
  165. async with pool.connection() as conn:
  166. await conn.set_autocommit(True)
  167. async with conn.cursor() as cur:
  168. await cur.execute(
  169. sql.SQL("SET cluster TO {};").format(sql.Identifier(tool.cluster))
  170. )
  171. await cur.execute(
  172. sql.SQL("SELECT {} FROM {} WHERE {};").format(
  173. (
  174. sql.SQL("count(*) > 0 AS exists")
  175. if not tool.output_columns
  176. else sql.SQL(",").join(
  177. sql.Identifier(col) for col in tool.output_columns
  178. )
  179. ),
  180. sql.Identifier(tool.database, tool.schema, tool.object_name),
  181. sql.SQL(" AND ").join(
  182. [
  183. sql.SQL("{} = {}").format(
  184. sql.Identifier(k), sql.Placeholder()
  185. )
  186. for k in arguments.keys()
  187. ]
  188. ),
  189. ),
  190. list(arguments.values()),
  191. )
  192. rows = await cur.fetchall()
  193. columns = [desc.name for desc in cur.description]
  194. raw = [
  195. {k: v for k, v in dict(zip(columns, row)).items()} for row in rows
  196. ]
  197. return serialize({"rows": raw})
  198. def serialize(obj):
  199. """Serialize any Decimal/date/bytes/UUID into JSON-safe primitives."""
  200. # json.dumps will call json_serial for any non-standard type,
  201. # then json.loads turns it back into a Python dict/list of primitives.
  202. # Structured output types require the tool returns dict[str, Any]
  203. # but the json encoder used by the mcp library does not support all
  204. # standard postgres types
  205. return json.loads(json.dumps(obj, default=json_serial))
  206. def json_serial(obj):
  207. """JSON serializer for objects not serializable by default json code"""
  208. from datetime import date, datetime, time, timedelta
  209. if isinstance(obj, datetime | date | time):
  210. return obj.isoformat()
  211. elif isinstance(obj, timedelta):
  212. return obj.total_seconds()
  213. elif isinstance(obj, bytes):
  214. return base64.b64encode(obj).decode("ascii")
  215. elif isinstance(obj, decimal.Decimal):
  216. return float(obj)
  217. elif isinstance(obj, UUID):
  218. return str(obj)
  219. else:
  220. raise TypeError(f"Type {type(obj)} not serializable. This is a bug.")