test_mcp.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License in the LICENSE file at the
  6. # root of this repository, or online at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import pytest
  17. import pytest_asyncio
  18. from mcp import Tool
  19. from mcp.types import ToolAnnotations
  20. from psycopg_pool import AsyncConnectionPool
  21. from mcp_materialize.mz_client import MzClient
  @pytest_asyncio.fixture(scope="function")
  async def materialize_pool():
      """Yield an open connection pool whose ``materialize.tools`` schema is empty.

      The DSN comes from the ``MZ_DSN`` environment variable and falls back to a
      local Materialize instance. Before the pool is handed to the test, the
      ``materialize.tools`` schema is dropped with ``CASCADE`` and recreated, so
      every test starts without any views created by earlier tests.
      """
      # Keep the DSN under its own name; ``conn`` below names a live connection,
      # and reusing one variable for both obscures intent and trips type checkers.
      dsn = os.getenv("MZ_DSN", "postgres://materialize@localhost:6875/materialize")
      # open=False: the pool is opened by entering the ``async with`` rather than
      # by the constructor, and is closed when the block exits.
      async with AsyncConnectionPool(
          conninfo=dsn, min_size=1, max_size=10, open=False
      ) as pool:
          async with pool.connection() as conn:
              # Autocommit so each DDL statement commits on its own.
              await conn.set_autocommit(True)
              async with conn.cursor() as cur:
                  await cur.execute("DROP SCHEMA IF EXISTS materialize.tools CASCADE;")
                  await cur.execute("CREATE SCHEMA materialize.tools;")
          # Yield only after the setup connection has been returned to the pool.
          yield pool
  @pytest.mark.asyncio
  async def test_basic_tool(materialize_pool):
      """Only the view that is both indexed and commented is listed as a tool.

      Three views are created in the ``tools`` schema:

      * ``my_tool``: indexed on ``id`` and commented;
      * ``missing_comment``: indexed on ``id`` but without a comment;
      * ``missing_idx``: commented but without an index.

      The test expects exactly one listed tool, built from ``my_tool``, with the
      schemas and annotations spelled out below. Calling it with ``id = 1`` must
      return the single row of that view.
      """
      # The fixture just reset the schema, so nothing should be listed yet.
      async with MzClient(pool=materialize_pool) as client:
          assert len(await client.list_tools()) == 0

      setup_statements = (
          """CREATE OR REPLACE VIEW tools.my_tool AS
          SELECT 1 AS id, 'hello' AS result;""",
          "CREATE INDEX my_tool_id_idx ON tools.my_tool (id);",
          "COMMENT ON VIEW tools.my_tool IS 'Get result from id';",
          # Indexed, but never commented.
          """CREATE OR REPLACE VIEW tools.missing_comment AS
          SELECT 1 AS id, 'goodbye' AS result;""",
          "CREATE INDEX missing_comment_id_idx ON tools.missing_comment (id);",
          # Commented, but never indexed.
          """CREATE OR REPLACE VIEW tools.missing_idx AS
          SELECT 1 AS id, 'not it' AS result;""",
          "COMMENT ON VIEW tools.missing_idx IS 'Get result from id';",
      )
      async with materialize_pool.connection() as conn:
          await conn.set_autocommit(True)
          async with conn.cursor() as cur:
              for statement in setup_statements:
                  await cur.execute(statement)

      # The index column becomes the required input; the remaining column is
      # returned inside ``rows``.
      expected_tool = Tool(
          name="materialize_tools_my_tool_id_idx",
          description="Get result from id",
          inputSchema={
              "type": "object",
              "required": ["id"],
              "properties": {"id": {"type": "number"}},
          },
          outputSchema={
              "type": "object",
              "properties": {
                  "rows": {
                      "type": "array",
                      "items": {
                          "type": "object",
                          "required": ["result"],
                          "properties": {"result": {"type": "string"}},
                      },
                  }
              },
              "required": ["rows"],
          },
          annotations=ToolAnnotations(
              title="materialize::tools::my_tool(id)",
              readOnlyHint=True,
          ),
      )

      async with MzClient(pool=materialize_pool) as client:
          listed = await client.list_tools()
          assert len(listed) == 1
          assert listed[0] == expected_tool

          response = await client.call_tool(
              "materialize_tools_my_tool_id_idx", {"id": 1}
          )
          returned_rows = response["rows"]
          assert len(returned_rows) == 1
          assert returned_rows[0] == {"result": "hello"}
  @pytest.mark.asyncio
  async def test_exists_tool(materialize_pool):
      """A tool over a view whose only column is its index key reports existence.

      ``tools.my_tool`` has the single column ``id`` and is indexed on it. The
      test expects each call to return one row with an ``exists`` flag: ``True``
      for ``id = 1``, which the view contains, and ``False`` for ``id = 2``,
      which it does not.
      """
      async with materialize_pool.connection() as conn:
          await conn.set_autocommit(True)
          async with conn.cursor() as cur:
              for statement in (
                  "CREATE OR REPLACE VIEW tools.my_tool AS SELECT 1 AS id;",
                  "CREATE INDEX my_tool_id_idx ON tools.my_tool (id);",
                  "COMMENT ON VIEW tools.my_tool IS 'Check if id exists';",
              ):
                  await cur.execute(statement)

      async with MzClient(pool=materialize_pool) as client:
          # Present key first, then absent key (same order as before).
          for lookup_id, expected_flag in ((1, True), (2, False)):
              response = await client.call_tool(
                  "materialize_tools_my_tool_id_idx", {"id": lookup_id}
              )
              returned_rows = response["rows"]
              assert len(returned_rows) == 1
              assert returned_rows[0] == {"exists": expected_flag}
  @pytest.mark.asyncio
  async def test_type_handling_keys(materialize_pool):
      """Check the input JSON Schema generated for index key columns of many SQL types.

      ``tools.all_types`` has one column per SQL type under test and a
      ``DEFAULT INDEX``. The test expects every column to appear in the tool's
      ``inputSchema`` as a required property, each mapped to the JSON Schema
      type/format listed in ``expected_properties``.
      """
      async with materialize_pool.connection() as conn:
          await conn.set_autocommit(True)
          async with conn.cursor() as cur:
              # uint2_col / uint8_col were previously cast to uint4, so the
              # uint2 and uint8 types were never actually exercised.
              await cur.execute(
                  """
                  CREATE OR REPLACE VIEW tools.all_types AS
                  SELECT
                      1::smallint AS smallint_col,
                      2::integer AS integer_col,
                      3::bigint AS bigint_col,
                      2::uint2 AS uint2_col,
                      4::uint4 AS uint4_col,
                      8::uint8 AS uint8_col,
                      4.5::real AS real_col,
                      6.7::double precision AS double_col,
                      1.23::numeric AS numeric_col,
                      true AS boolean_col,
                      'a'::char AS char_col,
                      'abc'::varchar AS varchar_col,
                      'abc'::text AS text_col,
                      '2024-01-01'::date AS date_col,
                      '12:34:56'::time AS time_col,
                      '2024-01-01 12:34:56'::timestamp AS timestamp_col,
                      '2024-01-01 12:34:56+00'::timestamptz AS timestamptz_col,
                      decode('DEADBEEF', 'hex')::bytea AS bytea_col,
                      '{"a": 1, "b": [1, 2, 3]}'::jsonb AS jsonb_col,
                      '550e8400-e29b-41d4-a716-446655440000'::uuid AS uuid_col;
                  """
              )
              await cur.execute("CREATE DEFAULT INDEX all_types_idx ON tools.all_types;")
              await cur.execute("COMMENT ON VIEW tools.all_types IS 'All types';")

      expected_properties = {
          "bigint_col": {"type": "number"},
          "boolean_col": {"type": "boolean"},
          "bytea_col": {
              "type": "string",
              "contentEncoding": "base64",
              "contentMediaType": "application/octet-stream",
          },
          "char_col": {"type": "string"},
          "smallint_col": {"type": "number"},
          "double_col": {"type": "number"},
          "text_col": {"type": "string"},
          "integer_col": {"type": "number"},
          "uint2_col": {"type": "number"},
          "uint4_col": {"type": "number"},
          "uint8_col": {"type": "number"},
          "date_col": {"type": "string", "format": "date"},
          "time_col": {"type": "string", "format": "time"},
          "timestamp_col": {"type": "string", "format": "date-time"},
          "timestamptz_col": {"type": "string", "format": "date-time"},
          "jsonb_col": {"type": "object"},
          "numeric_col": {"type": "number"},
          "real_col": {"type": "number"},
          "varchar_col": {"type": "string"},
          "uuid_col": {"type": "string", "format": "uuid"},
      }

      async with MzClient(pool=materialize_pool) as client:
          tools = await client.list_tools()
          assert len(tools) == 1
          tool = tools[0]
          assert tool.name == "materialize_tools_all_types_idx"
          assert tool.description == "All types"
          # Every column is an index key, so every property must be required.
          # Deriving the list from the expected properties keeps the two in sync.
          assert sorted(tool.inputSchema["required"]) == sorted(expected_properties)
          assert tool.inputSchema["properties"] == expected_properties
  201. @pytest.mark.asyncio
  202. async def test_type_handling_values(materialize_pool):
  203. async with materialize_pool.connection() as conn:
  204. await conn.set_autocommit(True)
  205. async with conn.cursor() as cur:
  206. await cur.execute(
  207. """
  208. CREATE OR REPLACE VIEW tools.all_types AS
  209. SELECT
  210. 1 AS id,
  211. 1::smallint AS smallint_col,
  212. 2::integer AS integer_col,
  213. 3::bigint AS bigint_col,
  214. 2::uint4 AS uint2_col,
  215. 4::uint4 AS uint4_col,
  216. 8::uint4 AS uint8_col,
  217. 4.5::real AS real_col,
  218. 6.7::double precision AS double_col,
  219. 1.23::numeric AS numeric_col,
  220. true AS boolean_col,
  221. 'a'::char AS char_col,
  222. 'abc'::varchar AS varchar_col,
  223. 'abc'::text AS text_col,
  224. '2024-01-01'::date AS date_col,
  225. '12:34:56'::time AS time_col,
  226. '2024-01-01 12:34:56'::timestamp AS timestamp_col,
  227. '2024-01-01 12:34:56+00'::timestamptz AS timestamptz_col,
  228. decode('DEADBEEF', 'hex')::bytea AS bytea_col,
  229. '{"a": 1, "b": [1, 2, 3]}'::jsonb AS jsonb_col,
  230. '550e8400-e29b-41d4-a716-446655440000'::uuid AS uuid_col;
  231. """
  232. )
  233. await cur.execute("CREATE INDEX all_types_idx ON tools.all_types (id);")
  234. await cur.execute("COMMENT ON VIEW tools.all_types IS 'All types';")
  235. async with MzClient(pool=materialize_pool) as client:
  236. results = await client.call_tool("materialize_tools_all_types_idx", {"id": 1})
  237. assert len(results) == 1