Source code for kani.model_specific.json

import json

from kani.models import FunctionCall, ToolCall
from .base import BaseToolCallParser


[docs] class NaiveJSONToolCallParser(BaseToolCallParser): """ If the model's output contains only valid JSON of form: .. code-block:: json { "name": "function_name", "parameters": { "key": "value..." } } then assume it is a function call. Otherwise, return the content unchanged. """ def __init__(self, *args, **kwargs): super().__init__(*args, tool_call_start_token=None, tool_call_end_token=None, **kwargs) def parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]: """Given the string completion of the model, return the content without tool calls and the parsed tool calls.""" try: data = json.loads(content.strip()) match data: case {"name": str(name), "parameters": dict(parameters)}: tc = ToolCall.from_function_call(FunctionCall.with_args(name, **parameters)) return "", [tc] except json.JSONDecodeError: return content, [] return content, [] async def stream(self, messages, functions=None, **hyperparams): # special case - if we see a { at start of message, defer until end of message to see if it's a function call # otherwise stream as normal seen_non_tool_call_token = False in_tool_call = False # consume from the inner iterator, yielding as normal until we see a tool call or a completion async for elem in super().stream(messages, functions, **hyperparams): if isinstance(elem, str): # if we see {, stop yielding and start buffering if elem.lstrip().startswith("{") and not seen_non_tool_call_token: in_tool_call = True # otherwise yield the string if elem and not in_tool_call: seen_non_tool_call_token = True yield elem else: # yield the inner completion yield elem