# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License in the LICENSE file at the
# root of this repository, or online at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import base64
import decimal
import json
import logging
from datetime import date, datetime, time, timedelta
from importlib.resources import files
from textwrap import dedent
from typing import Any
from uuid import UUID

import aiorwlock
from mcp import Tool
from mcp.types import ToolAnnotations
from psycopg import sql
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

logger = logging.getLogger("mz_mcp_server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)

# The catalog query that discovers tool definitions ships with the package
# as mcp_materialize/sql/tools.sql and is read once at import time.
TOOL_QUERY = (files("mcp_materialize.sql") / "tools.sql").read_text()


class MzTool:
    """A tool definition discovered in the Materialize catalog."""

    def __init__(
        self,
        name,
        database,
        schema,
        object_name,
        cluster,
        title,
        description,
        input_schema,
        output_schema,
        output_columns,
    ):
        self.name = name
        self.database = database
        self.schema = schema
        self.object_name = object_name
        self.cluster = cluster
        self.title = title
        self.description = description
        self.input_schema = input_schema
        self.output_schema = output_schema
        self.output_columns = output_columns

    def as_tool(self) -> Tool:
        return Tool(
            name=self.name,
            description=self.description,
            inputSchema=self.input_schema,
            outputSchema=self.output_schema,
            annotations=ToolAnnotations(title=self.title, readOnlyHint=True),
        )
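

# Illustrative only: a row returned by tools.sql is expected to provide the
# columns consumed in MzClient._load_tools below. A hypothetical example:
#
#   row = {
#       "name": "lookup_order",
#       "database": "materialize",
#       "schema": "public",
#       "object_name": "orders_by_id",
#       "cluster": "serving",
#       "title": "Order lookup",
#       "description": "Look up an order by id.",
#       "input_schema": {"type": "object", "properties": {"id": {"type": "string"}}},
#       "output_schema": {"type": "object"},
#       "output_columns": ["status", "total"],
#   }
#
# MzTool(**row).as_tool() then yields an MCP Tool whose inputSchema and
# outputSchema mirror those fields, annotated as read-only.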


class MissingTool(Exception):
    def __init__(self, message):
        super().__init__(message)


class MzClient:
    """Serves MCP tools backed by indexed, commented Materialize objects."""

    def __init__(self, pool: AsyncConnectionPool) -> None:
        self.pool = pool
        self.tools: dict[str, MzTool] = {}
        self._lock = aiorwlock.RWLock()
        self._bg_task: asyncio.Task | None = None

    async def __aenter__(self) -> "MzClient":
        await self._load_tools()
        self._bg_task = asyncio.create_task(self._subscribe())
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._bg_task:
            self._bg_task.cancel()
            try:
                await self._bg_task
            except asyncio.CancelledError:
                pass
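
    # A minimal usage sketch, assuming a reachable Materialize instance;
    # `dsn` is a hypothetical connection string:
    #
    #   async def main():
    #       async with AsyncConnectionPool(dsn) as pool:
    #           async with MzClient(pool) as client:
    #               for tool in await client.list_tools():
    #                   print(tool.name)
    #
    #   asyncio.run(main())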

    async def _subscribe(self) -> None:
        """
        Watch Materialize for new tools.

        We cannot subscribe to the tool catalog query (TOOL_QUERY) directly
        because it relies on non-materializable functions. Instead, we
        subscribe to a count of indexed objects that have comments as a
        proxy, and re-run the full query whenever that count changes.
        """
        try:
            async with self.pool.connection() as conn:
                await conn.set_autocommit(True)
                async with conn.cursor(row_factory=dict_row) as cur:
                    logger.info("Starting background tool subscription")
                    await cur.execute("BEGIN")
                    await cur.execute(
                        dedent(
                            """
                            DECLARE c CURSOR FOR
                            SUBSCRIBE (
                                SELECT count(*) AS eligible_tools
                                FROM mz_objects o
                                JOIN mz_indexes i ON o.id = i.on_id
                                JOIN mz_internal.mz_comments cts ON cts.id = o.id
                            ) WITH (PROGRESS)
                            """
                        )
                    )
                    while True:
                        await cur.execute("FETCH ALL c")
                        reload = False
                        async for row in cur:
                            # A data row means the proxy count changed; wait
                            # for the next progress row, which marks a
                            # consistent point at which to reload.
                            if not row["mz_progressed"]:
                                reload = True
                            elif reload:
                                logger.info("Reloading catalog of available tools")
                                await self._load_tools()
                                reload = False
        except asyncio.CancelledError:
            logger.info("Stopping background tool subscription")
            return
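
    # For reference: each SUBSCRIBE ... WITH (PROGRESS) row carries
    # mz_timestamp, mz_progressed, and mz_diff alongside the query's own
    # columns, so the stream consumed above looks roughly like:
    #
    #   mz_timestamp  | mz_progressed | mz_diff | eligible_tools
    #   1700000000000 | f             |  1      | 4        <- count changed
    #   1700000001000 | t             |  null   | null     <- safe to reload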

    async def _load_tools(self) -> None:
        """
        Load the catalog of available tools into self.tools under lock.
        """
        new_tools: dict[str, MzTool] = {}
        async with self.pool.connection() as conn:
            await conn.set_autocommit(True)
            async with conn.cursor(row_factory=dict_row) as cur:
                await cur.execute(TOOL_QUERY)
                async for row in cur:
                    tool = MzTool(
                        name=row["name"],
                        database=row["database"],
                        schema=row["schema"],
                        object_name=row["object_name"],
                        cluster=row["cluster"],
                        title=row["title"],
                        description=row["description"],
                        input_schema=row["input_schema"],
                        output_schema=row["output_schema"],
                        output_columns=row["output_columns"],
                    )
                    new_tools[tool.name] = tool
        # Swap in the fresh catalog.
        async with self._lock.writer_lock:
            self.tools = new_tools

    async def list_tools(self) -> list[Tool]:
        """
        Return the catalog of available tools.
        """
        async with self._lock.reader_lock:
            return [tool.as_tool() for tool in self.tools.values()]

    async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        pool = self.pool
        async with self._lock.reader_lock:
            tool = self.tools.get(name)
            if not tool:
                raise MissingTool(f"Tool not found: {name}")
        async with pool.connection() as conn:
            await conn.set_autocommit(True)
            async with conn.cursor() as cur:
                await cur.execute(
                    sql.SQL("SET cluster TO {};").format(sql.Identifier(tool.cluster))
                )
                # With no declared output columns the tool is an existence
                # check; otherwise select exactly the declared columns.
                projection = (
                    sql.SQL("count(*) > 0 AS exists")
                    if not tool.output_columns
                    else sql.SQL(",").join(
                        sql.Identifier(col) for col in tool.output_columns
                    )
                )
                # Guard against an empty argument set, which would otherwise
                # render an invalid empty WHERE clause.
                predicate = (
                    sql.SQL(" AND ").join(
                        sql.SQL("{} = {}").format(sql.Identifier(k), sql.Placeholder())
                        for k in arguments
                    )
                    if arguments
                    else sql.SQL("TRUE")
                )
                await cur.execute(
                    sql.SQL("SELECT {} FROM {} WHERE {};").format(
                        projection,
                        sql.Identifier(tool.database, tool.schema, tool.object_name),
                        predicate,
                    ),
                    list(arguments.values()),
                )
                rows = await cur.fetchall()
                columns = [desc.name for desc in cur.description]
                raw = [dict(zip(columns, row)) for row in rows]
                return serialize({"rows": raw})
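
    # For illustration (names hypothetical), a call like
    #
    #   await client.call_tool("lookup_order", {"id": "o-123"})
    #
    # composes SQL along the lines of:
    #
    #   SET cluster TO "serving";
    #   SELECT "status", "total"
    #   FROM "materialize"."public"."orders_by_id"
    #   WHERE "id" = %s;
    #
    # with argument values passed as bind parameters, so user input never
    # reaches the query text itself.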


def serialize(obj):
    """Serialize Decimal/date/bytes/UUID values into JSON-safe primitives.

    json.dumps calls json_serial for any non-standard type, and json.loads
    turns the result back into Python dicts/lists of primitives. Structured
    output requires the tool to return dict[str, Any], but the JSON encoder
    used by the mcp library does not support all standard Postgres types.
    """
    return json.loads(json.dumps(obj, default=json_serial))
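
# Example round trip (values hypothetical):
#
#   serialize({"rows": [{"total": decimal.Decimal("9.99"), "id": UUID(int=1)}]})
#   # -> {"rows": [{"total": 9.99,
#   #               "id": "00000000-0000-0000-0000-000000000001"}]}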


def json_serial(obj):
    """JSON serializer for objects not serializable by the default json code."""
    if isinstance(obj, datetime | date | time):
        return obj.isoformat()
    elif isinstance(obj, timedelta):
        return obj.total_seconds()
    elif isinstance(obj, bytes):
        return base64.b64encode(obj).decode("ascii")
    elif isinstance(obj, decimal.Decimal):
        return float(obj)
    elif isinstance(obj, UUID):
        return str(obj)
    else:
        raise TypeError(f"Type {type(obj)} is not serializable. This is a bug.")