Source code for kani.engines.openai.engine

import functools
import inspect
import itertools
import logging
import warnings
from typing import AsyncIterable, Literal

from kani import _optional
from kani.ai_function import AIFunction
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage, ChatRole
from . import function_calling, mm_tokens
from .model_constants import API_BY_PREFIX, CONTEXT_SIZES_BY_PREFIX
from .translation import OPENAI_PIPELINE, ChatCompletion, kani_cm_to_openai_cm, openai_tc_to_kani_tc
from .translation_responses import kani_cm_to_openai_responses_inputs, openai_responses_response_to_kani_completion
from .utils import DottableDict
from ..base import BaseCompletion, BaseEngine, Completion
from ..mixins import TokenCached

try:
    import tiktoken
    from openai import AsyncOpenAI as OpenAIClient
    from openai.types.chat import ChatCompletionMessageParam
    from openai.types.responses import ResponseInputItemParam, ResponseInputParam
    from openai.types.shared_params import FunctionDefinition
except ImportError as e:
    raise MissingModelDependencies(
        'The OpenAIEngine requires extra dependencies. Please install kani with "pip install kani[openai]".'
    ) from None

log = logging.getLogger(__name__)


[docs] class OpenAIEngine(TokenCached, BaseEngine): """ Engine for using the OpenAI API. This engine supports all chat-based models and fine-tunes. **Multimodal support**: images, audio. **Message Extras** * ``"openai_completion"``: The ChatCompletion (raw response) returned by the OpenAI servers, as a dictionary. Non-streaming responses only. * ``"openai_usage"``: The usage data (raw response) returned by the OpenAI servers, as a dictionary. """ disable_function_calling_kwargs = {"tool_choice": "none"} def __init__( self, api_key: str = None, model="gpt-4.1-nano", max_context_size: int = None, *, api_type: Literal["chat_completions", "responses"] = None, organization: str = None, retry: int = 5, api_base: str = "https://api.openai.com/v1", headers: dict = None, client: OpenAIClient = None, tokenizer=None, **hyperparams, ): """ :param api_key: Your OpenAI API key. By default, the API key will be read from the `OPENAI_API_KEY` environment variable. :param model: The id of the model to use (e.g. "gpt-4o-mini", "ft:gpt-3.5-turbo:my-org:custom_suffix:id"). :param max_context_size: The maximum amount of tokens allowed in the chat prompt. If None, uses the given model's full context size. :param api_type: Whether to use the Chat Completions API (default for most models) or Responses API (default for "deep-reasoning" style models). If unset, the best API type for the given model will be chosen. :param organization: The OpenAI organization to use in requests. By default, the org ID would be read from the `OPENAI_ORG_ID` environment variable (defaults to the API key's default org if not set). :param retry: How many times the engine should retry failed HTTP calls with exponential backoff (default 5). :param api_base: The base URL of the OpenAI API to use. :param headers: A dict of HTTP headers to include with each request. :param client: An instance of `openai.AsyncOpenAI <https://github.com/openai/openai-python>`_ (for reusing the same client in multiple engines). You must specify exactly one of ``(api_key, client)``. If this is passed the ``organization``, ``retry``, ``api_base``, and ``headers`` params will be ignored. :param tokenizer: The tokenizer to use for token estimation - for OpenAI models this will be loaded automatically. A class with a ``.encode(text: str)`` method that returns a list (usually of token ids). :param hyperparams: The arguments to pass to the ``create_chat_completion`` call with each request. See https://platform.openai.com/docs/api-reference/chat/create for a full list of params. """ if api_key and client: raise ValueError("You must supply no more than one of (api_key, client).") if max_context_size is None: matched_prefix, max_context_size = next( (prefix, size) for prefix, size in CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix) ) if not matched_prefix: warnings.warn( "The context length for this model was not found, defaulting to 2048 tokens. Please specify" " `max_context_size` if this is incorrect.", stacklevel=2, ) if api_type is None: matched_prefix, api_type = next((prefix, a) for prefix, a in API_BY_PREFIX if model.startswith(prefix)) warnings.warn( f"The OpenAI API type for this model was not set, defaulting to {api_type!r} for {model!r}. " 'Please specify `api_type="chat_completions"` or `api_type="responses"` if this is incorrect.', stacklevel=2, ) super().__init__() self.client = client or OpenAIClient( api_key=api_key, organization=organization, max_retries=retry, base_url=api_base, default_headers=headers ) self.model = model self.max_context_size = max_context_size self.hyperparams = hyperparams self.openai_api_type = api_type self._tokenizer = tokenizer # tiktoken caches a tokenizer globally in module, so we can unconditionally load it # ==== token counting ==== @property def tokenizer(self): if self._tokenizer is not None: return self._tokenizer # we load this lazily now since responses API doesn't need it, and to allow overrides more easily try: self._tokenizer = tiktoken.encoding_for_model(self.model) except KeyError: warnings.warn( f"Could not find a tokenizer for the {self.model} model. You may need to update tiktoken. Using" " o200k_base tokenizer as default." ) self._tokenizer = tiktoken.get_encoding("o200k_base") return self._tokenizer
[docs] def message_len(self, message: ChatMessage) -> int: if (cached_len := self.get_cached_message_len(message)) is not None: return cached_len mlen = 7 # main content if _optional.has_multimodal_core: for part in message.parts: if isinstance(part, _optional.multimodal_core.AudioPart): mlen += mm_tokens.tokens_from_audio_duration(part.duration, self.model) elif isinstance(part, _optional.multimodal_core.ImagePart): mlen += mm_tokens.tokens_from_image_size(part.size, self.model) else: mlen += len(self.tokenizer.encode(str(part))) else: if message.text: mlen += len(self.tokenizer.encode(message.text)) # additional keys if message.name: mlen += len(self.tokenizer.encode(message.name)) if message.tool_calls: for tc in message.tool_calls: mlen += len(self.tokenizer.encode(tc.function.name)) mlen += len(self.tokenizer.encode(tc.function.arguments)) # HACK: using gpt-4o and parallel function calling, the API randomly adds tokens based on the length of the # TOOL message (see tokencounting.ipynb)??? # this seems to be ~ 6 + (token len / 20) tokens per message (though it randomly varies), but this formula # is <10 tokens of an overestimate in most cases if self.model.startswith("gpt-4o") and message.role == ChatRole.FUNCTION: mlen += 6 + (mlen // 20) self.set_cached_message_len(message, mlen) return mlen
[docs] def function_token_reserve(self, functions: list[AIFunction]) -> int: if not functions: return 0 # wrap an inner impl to use lru_cache with tuple return self._function_token_reserve_impl(tuple(functions))
@functools.lru_cache(maxsize=256) def _function_token_reserve_impl(self, functions): prompt = function_calling.prompt(self.translate_functions(functions)) return len(self.tokenizer.encode(prompt)) @functools.cached_property def _count_tokens_arg_names(self): """A list of valid kwarg names that can be passed to self.client.responses.input_tokens.count""" try: inspected_params = set(inspect.signature(self.client.responses.input_tokens.count).parameters) return inspected_params except Exception as e: log.warning( "Could not introspect responses.input_tokens.count for parameter names, returning default:", exc_info=e ) # default return { "conversation", "input", "instructions", "model", "parallel_tool_calls", "previous_response_id", "reasoning", "text", "tool_choice", "tools", "truncation", "extra_headers", "extra_query", "extra_body", "timeout", } # ==== hackable stuff for requests ==== # --- kani -> oai translation ---
[docs] def translate_functions(self, functions: list[AIFunction]) -> list[dict]: r"""Translate a list of Kani :class:`.AIFunction`\ s to a list of OpenAI tool definitions.""" if self.openai_api_type == "chat_completions": return [ dict( type="function", function=FunctionDefinition(name=f.name, description=f.desc, parameters=f.json_schema), ) for f in functions ] elif self.openai_api_type == "responses": return [dict(type="function", name=f.name, description=f.desc, parameters=f.json_schema) for f in functions] raise ValueError(f"Unknown OpenAI API type: {self.openai_api_type!r}")
[docs] def translate_messages(self, messages: list[ChatMessage]) -> list[ChatCompletionMessageParam] | ResponseInputParam: r"""Translate a list of Kani :class:`.ChatMessage`\ s to a list of OpenAI messages.""" # we don't use a .apply() step here for hackability, so the pipeline just binds tool calls and cleans up # any invalid prefixes inter = OPENAI_PIPELINE(messages) if self.openai_api_type == "chat_completions": return [self.translate_kani_message_to_openai(m) for m in inter] elif self.openai_api_type == "responses": return list( itertools.chain.from_iterable(self.translate_kani_message_to_openai_responses(m) for m in inter) ) raise ValueError(f"Unknown OpenAI API type: {self.openai_api_type!r}")
def _prepare_request( self, messages, functions, *, intent: str = "chat_completions.create", **kwargs ) -> tuple[dict, list, dict | None]: """ Prepare the API request to the OpenAI API. Returns a tuple (kwargs, messages, tools) to be passed to the OpenAIClient's chat.completions.create() or responses.create() method. :param messages: The Kani ChatMessages to translate into OpenAI-format messages. :param functions: The Kani AIFunctions to translate into OpenAI-format tools. :param intent: one of ("chat_completions.create", "chat_completions.stream", "responses.create", "responses.stream", "responses.input_tokens.count") -- the underlying OpenAI SDK call the returned keyword arguments will be passed to. :param kwargs: The request-specific kwargs passed to the request, from either the engine initialization or the chat_round call. """ if functions: tool_specs = self.translate_functions(functions) else: tool_specs = None # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound translated_messages = self.translate_messages(messages) # responses API params if self.openai_api_type == "responses": # ensure include reasoning is set kwargs.setdefault("include", []) if "reasoning.encrypted_content" not in kwargs["include"]: kwargs["include"].append("reasoning.encrypted_content") return kwargs, translated_messages, tool_specs # --- oai -> kani translation --- def _translate_openai_chat_completion(self, completion): """Translate an OpenAI completion to a Kani completion. Only called for non-streaming requests by default.""" if self.openai_api_type == "chat_completions": return ChatCompletion(openai_completion=completion) elif self.openai_api_type == "responses": return openai_responses_response_to_kani_completion(completion) raise ValueError(f"Unknown OpenAI API type: {self.openai_api_type!r}") # ==== chat completions ====
[docs] @staticmethod def translate_kani_message_to_openai(message: ChatMessage) -> ChatCompletionMessageParam: """Translate a single Kani :class:`.ChatMessage` to a single OpenAI message.""" return kani_cm_to_openai_cm(message)
async def _predict_chat_completions( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> ChatCompletion: local_kwargs, translated_messages, tool_specs = self._prepare_request( messages, functions, intent="chat_completions.create", **(self.hyperparams | hyperparams) ) # make API call completion = await self.client.chat.completions.create( model=self.model, messages=translated_messages, tools=tool_specs, **local_kwargs, ) # translate into Kani spec and return kani_cmpl = self._translate_openai_chat_completion(completion) self.set_cached_prompt_len(messages, functions, completion.usage.prompt_tokens) self.set_cached_prompt_len( messages + [kani_cmpl.message], functions, completion.usage.prompt_tokens + kani_cmpl.completion_tokens ) self.set_cached_message_len(kani_cmpl.message, kani_cmpl.completion_tokens) return kani_cmpl async def _stream_chat_completions( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> AsyncIterable[str | BaseCompletion]: local_kwargs, translated_messages, tool_specs = self._prepare_request( messages, functions, intent="chat_completions.stream", **(self.hyperparams | hyperparams) ) # make API call stream = await self.client.chat.completions.create( model=self.model, messages=translated_messages, tools=tool_specs, stream=True, stream_options={"include_usage": True}, **local_kwargs, ) # save requested tool calls and content as streamed content_chunks = [] tool_call_partials = {} # index -> tool call usage = None # iterate over the stream and yield/save async for chunk in stream: # save usage if present if chunk.usage is not None: usage = chunk.usage if not chunk.choices: continue # process content delta delta = chunk.choices[0].delta # yield content if delta.content is not None: content_chunks.append(delta.content) yield delta.content # tool calls are partials, save a mapping to the latest state and we'll translate them later once complete if delta.tool_calls: for tc_delta in delta.tool_calls: if tc_delta.index not in tool_call_partials: tool_call_partials[tc_delta.index] = tc_delta else: partial = tool_call_partials[tc_delta.index] if tc_delta.function.name is not None: partial.function.name += tc_delta.function.name if tc_delta.function.arguments is not None: partial.function.arguments += tc_delta.function.arguments # construct the final completion with streamed tool calls content = None if not content_chunks else "".join(content_chunks) tool_calls = [openai_tc_to_kani_tc(tc) for tc in sorted(tool_call_partials.values(), key=lambda c: c.index)] msg = ChatMessage(role=ChatRole.ASSISTANT, content=content, tool_calls=tool_calls) # token counting if usage: self.set_cached_prompt_len(messages, functions, usage.prompt_tokens) self.set_cached_prompt_len(messages + [msg], functions, usage.prompt_tokens + usage.completion_tokens) self.set_cached_message_len(msg, usage.completion_tokens) prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens msg.extra["openai_usage"] = DottableDict(usage.model_dump(mode="json")) else: prompt_tokens = completion_tokens = None yield Completion(message=msg, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) # ==== responses ====
[docs] @staticmethod def translate_kani_message_to_openai_responses(message: ChatMessage) -> list[ResponseInputItemParam]: """Translate a single Kani :class:`.ChatMessage` to its corresponding OpenAI responses input items.""" return kani_cm_to_openai_responses_inputs(message)
async def _predict_responses( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> ChatCompletion: local_kwargs, translated_messages, tool_specs = self._prepare_request( messages, functions, intent="responses.create", **(self.hyperparams | hyperparams) ) # make API call response = await self.client.responses.create( model=self.model, input=translated_messages, tools=tool_specs, **local_kwargs ) # translate into Kani spec and return kani_cmpl = self._translate_openai_chat_completion(response) self.set_cached_prompt_len(messages, functions, response.usage.input_tokens) self.set_cached_prompt_len(messages + [kani_cmpl.message], functions, response.usage.total_tokens) self.set_cached_message_len(kani_cmpl.message, kani_cmpl.completion_tokens) return kani_cmpl async def _stream_responses( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> AsyncIterable[str | BaseCompletion]: local_kwargs, translated_messages, tool_specs = self._prepare_request( messages, functions, intent="responses.stream", **(self.hyperparams | hyperparams) ) # the sdk handles the text streaming for us, so we can just use that async with self.client.responses.stream( model=self.model, input=translated_messages, tools=tool_specs or [], **local_kwargs ) as streamer: async for event in streamer: # we only want to emit the content of response.output_text.delta if event.type == "response.output_text.delta": yield event.delta # translate into Kani spec and return response = await streamer.get_final_response() kani_cmpl = self._translate_openai_chat_completion(response) self.set_cached_prompt_len(messages, functions, response.usage.input_tokens) self.set_cached_prompt_len(messages + [kani_cmpl.message], functions, response.usage.total_tokens) self.set_cached_message_len(kani_cmpl.message, kani_cmpl.completion_tokens) yield kani_cmpl # ==== main kani impl ====
[docs] async def prompt_len(self, messages, functions=None, **kwargs) -> int: # we have a token counting endpoint for responses api, so prepare and count if self.openai_api_type == "responses": local_kwargs, translated_messages, tool_specs = self._prepare_request( messages, functions, intent="responses.input_tokens.count", **(self.hyperparams | kwargs) ) valid_count_token_kwargs = {k: v for k, v in local_kwargs.items() if k in self._count_tokens_arg_names} resp = await self.client.responses.input_tokens.count( model=self.model, input=translated_messages, tools=tool_specs, **valid_count_token_kwargs ) return resp.input_tokens # for chat completions api, we have to count ourselves # optimization: since chat-based appends messages 1 at a time, see if the messages - 1 is in prompt cache if messages and (cached := self.get_cached_prompt_len(messages[:-1], functions, **kwargs)) is not None: return cached + self.message_len(messages[-1]) # OpenAI does not use an API for token counting, so we'll use the old-style message-wise counting return sum(self.message_len(m) for m in messages) + self.function_token_reserve(functions)
[docs] async def predict( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> ChatCompletion: if self.openai_api_type == "chat_completions": return await self._predict_chat_completions(messages, functions, **hyperparams) elif self.openai_api_type == "responses": return await self._predict_responses(messages, functions, **hyperparams) raise ValueError(f"Unknown OpenAI API type: {self.openai_api_type!r}")
[docs] async def stream( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> AsyncIterable[str | BaseCompletion]: if self.openai_api_type == "chat_completions": async for item in self._stream_chat_completions(messages, functions, **hyperparams): yield item elif self.openai_api_type == "responses": async for item in self._stream_responses(messages, functions, **hyperparams): yield item else: raise ValueError(f"Unknown OpenAI API type: {self.openai_api_type!r}")
[docs] async def close(self): await self.client.close()
def __repr__(self): return ( f"{type(self).__name__}(model={self.model}, max_context_size={self.max_context_size}," f" hyperparams={self.hyperparams})" )