Skip to content

Commit 60846e4

Browse files
committed
feat(tools): introduce tool abstraction and decorator
1 parent 92fb938 commit 60846e4

4 files changed

Lines changed: 238 additions & 0 deletions

File tree

examples/tools/basic_tool.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Basic example of creating and using a QuantMind tool."""
2+
3+
import asyncio
4+
5+
import aiohttp
6+
7+
from quantmind.tools import tool
8+
9+
10+
@tool
11+
async def get_btc_usdt_price() -> float:
12+
"""Fetch the latest BTC/USDT price from Binance API."""
13+
url = "https://api.binance.com/api/v3/ticker/price?symbol=BTCUSDT"
14+
15+
async with aiohttp.ClientSession() as session:
16+
async with session.get(url) as response:
17+
data = await response.json()
18+
return float(data["price"])
19+
20+
21+
async def main():
22+
"""Main function to run the example."""
23+
price = await get_btc_usdt_price.run()
24+
print(f"BTC/USDT price: {price}")
25+
26+
27+
if __name__ == "__main__":
28+
asyncio.run(main())

quantmind/tools/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Tool abstractions for QuantMind.
2+
3+
Exports the standardized tool interface and convenience decorator.
4+
"""
5+
6+
from .base import BaseTool, FunctionTool, tool
7+
8+
__all__ = ["BaseTool", "FunctionTool", "tool"]

quantmind/tools/base.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from abc import ABC, abstractmethod
5+
from functools import partial
6+
from typing import Any, Callable, Dict, Type
7+
8+
from pydantic import BaseModel, Field, create_model
9+
10+
11+
class BaseTool(ABC):
12+
"""Abstract base class for all QuantMind tools.
13+
14+
A tool self-describes its capability via a name, description, and a Pydantic
15+
input schema, and exposes an async `run` that validates inputs before execution.
16+
"""
17+
18+
@property
19+
@abstractmethod
20+
def name(self) -> str:
21+
"""Unique name used by an LLM to invoke this tool."""
22+
23+
@property
24+
@abstractmethod
25+
def description(self) -> str:
26+
"""Human-readable description of the tool for LLM selection."""
27+
28+
@property
29+
@abstractmethod
30+
def args_schema(self) -> Type[BaseModel]:
31+
"""Pydantic model describing required/optional input arguments."""
32+
33+
@abstractmethod
34+
async def _arun(self, **kwargs: Any) -> Any:
35+
"""Core async execution logic for the tool (implemented by subclasses)."""
36+
37+
async def run(self, **kwargs: Any) -> Any:
38+
"""Validate inputs against schema, then execute the tool asynchronously."""
39+
validated = self.args_schema(**kwargs)
40+
return await self._arun(**validated.model_dump())
41+
42+
def to_openai_schema(self) -> Dict[str, Any]:
43+
"""Return schema compatible with OpenAI function calling tools."""
44+
return {
45+
"type": "function",
46+
"function": {
47+
"name": self.name,
48+
"description": self.description,
49+
"parameters": self.args_schema.model_json_schema(),
50+
},
51+
}
52+
53+
54+
class FunctionTool(BaseTool):
55+
"""Wrap a Python callable as a QuantMind tool.
56+
57+
The callable may be sync or async. Sync functions are executed in a thread
58+
pool to avoid blocking the event loop.
59+
"""
60+
61+
def __init__(
62+
self,
63+
fn: Callable[..., Any],
64+
name: str,
65+
description: str,
66+
args_schema: Type[BaseModel],
67+
) -> None:
68+
self._fn = fn
69+
self._name = name
70+
self._description = description
71+
self._args_schema = args_schema
72+
73+
@property
74+
def name(self) -> str: # type: ignore[override]
75+
return self._name
76+
77+
@property
78+
def description(self) -> str: # type: ignore[override]
79+
return self._description
80+
81+
@property
82+
def args_schema(self) -> Type[BaseModel]: # type: ignore[override]
83+
return self._args_schema
84+
85+
async def _arun(self, **kwargs: Any) -> Any: # type: ignore[override]
86+
if inspect.iscoroutinefunction(self._fn):
87+
return await self._fn(**kwargs) # type: ignore[misc]
88+
# For sync functions, run in a thread pool
89+
import asyncio
90+
91+
loop = asyncio.get_running_loop()
92+
return await loop.run_in_executor(None, partial(self._fn, **kwargs))
93+
94+
95+
def _build_args_schema_from_signature(
96+
fn: Callable[..., Any],
97+
) -> Type[BaseModel]:
98+
"""Create a Pydantic model from a function's signature.
99+
100+
Parameters without annotations default to `Any`. All parameters are required
101+
unless a default value exists on the function.
102+
"""
103+
sig = inspect.signature(fn)
104+
fields: Dict[str, tuple[Any, Any]] = {}
105+
106+
for param in sig.parameters.values():
107+
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
108+
# Skip variadic params for schema simplicity
109+
continue
110+
111+
annotation = (
112+
param.annotation if param.annotation is not inspect._empty else Any
113+
)
114+
115+
# Required if no default, else use default
116+
if param.default is inspect._empty:
117+
default = Field(..., description=f"Parameter for {param.name}")
118+
else:
119+
default = Field(
120+
default=param.default, description=f"Parameter for {param.name}"
121+
)
122+
123+
fields[param.name] = (annotation, default)
124+
125+
model_name = f"{fn.__name__.capitalize()}Inputs"
126+
return create_model(model_name, **fields) # type: ignore[return-value]
127+
128+
129+
def tool(fn: Callable[..., Any]) -> BaseTool:
130+
"""Decorator that converts a function into a QuantMind Tool.
131+
132+
The function's docstring becomes the description; its signature and type
133+
annotations define the input schema. Returns a `FunctionTool` instance.
134+
"""
135+
docstring = inspect.getdoc(fn)
136+
if not docstring:
137+
raise ValueError(
138+
"Tool function must have a docstring for its description."
139+
)
140+
141+
description = docstring.strip()
142+
name = fn.__name__
143+
args_schema = _build_args_schema_from_signature(fn)
144+
return FunctionTool(
145+
fn=fn, name=name, description=description, args_schema=args_schema
146+
)

tests/tools/test_tools.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import asyncio
2+
import unittest
3+
4+
from pydantic import ValidationError
5+
6+
from quantmind.tools import BaseTool, tool
7+
8+
9+
class TestTools(unittest.TestCase):
10+
"""Test the tools module."""
11+
12+
def test_tool_requires_docstring(self):
13+
"""Test that a tool requires a docstring."""
14+
15+
def no_doc(a: int):
16+
return a
17+
18+
with self.assertRaises(ValueError):
19+
tool(no_doc)
20+
21+
def test_sync_function_tool_run_and_schema(self):
22+
"""Test that a sync function tool runs and validates the schema."""
23+
24+
@tool
25+
def add(a: int, b: int) -> int:
26+
"""Adds two integers."""
27+
return a + b
28+
29+
# to_openai_schema shape
30+
schema = add.to_openai_schema()
31+
self.assertEqual(schema["type"], "function")
32+
self.assertEqual(schema["function"]["name"], "add")
33+
self.assertIn("parameters", schema["function"])
34+
35+
# args_schema validation
36+
with self.assertRaises(ValidationError):
37+
# missing required field
38+
asyncio.run(add.run(a=1))
39+
40+
result = asyncio.run(add.run(a=2, b=3))
41+
self.assertEqual(result, 5)
42+
43+
def test_async_function_tool_run(self):
44+
"""Test that an async function tool runs."""
45+
46+
@tool
47+
async def mul(a: int, b: int) -> int:
48+
"""Multiplies two integers asynchronously."""
49+
return a * b
50+
51+
result = asyncio.run(mul.run(a=4, b=5))
52+
self.assertEqual(result, 20)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()

0 commit comments

Comments
 (0)