Source code for kani.engines.openai.translation

"""Helpers to translate kani chat objects into OpenAI params."""

import base64

from kani import _optional
from kani.ai_function import AIFunction
from kani.engines.base import BaseCompletion
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage, ChatRole, FunctionCall, MessagePart, ToolCall
from kani.prompts.pipeline import PromptPipeline
from kani.utils.warnings import deprecated
from .utils import DottableDict

try:
    from openai.types.chat import (
        ChatCompletion as OpenAIChatCompletion,
        ChatCompletionAssistantMessageParam,
        ChatCompletionFunctionMessageParam,
        ChatCompletionMessage,
        ChatCompletionMessageFunctionToolCallParam,
        ChatCompletionMessageParam,
        ChatCompletionMessageToolCall,
        ChatCompletionMessageToolCallParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionToolMessageParam,
        ChatCompletionToolParam,
        ChatCompletionUserMessageParam,
    )
    from openai.types.shared_params import FunctionDefinition
except ImportError as e:
    raise MissingModelDependencies(
        'The OpenAIEngine requires extra dependencies. Please install kani with "pip install kani[openai]".'
    ) from None


# ==== kani -> openai ====
# decomp
def kani_cm_to_openai_cm(msg: ChatMessage) -> ChatCompletionMessageParam:
    """Convert one kani ChatMessage into the corresponding OpenAI message param dict.

    - FUNCTION messages with a ``tool_call_id`` become ``role="tool"`` messages.
    - FUNCTION messages without one use the legacy ``role="function"`` format.
    - SYSTEM and USER messages map directly onto their OpenAI counterparts.
    - Every other role is treated as ASSISTANT; its tool calls (if any) are attached.
    """
    role = msg.role

    if role == ChatRole.FUNCTION:
        if msg.tool_call_id is None:
            # no id to answer: fall back to the legacy function-message format
            return ChatCompletionFunctionMessageParam(**_msg_kwargs(msg))
        return ChatCompletionToolMessageParam(role="tool", content=msg.text, tool_call_id=msg.tool_call_id)

    if role == ChatRole.SYSTEM:
        return ChatCompletionSystemMessageParam(**_msg_kwargs(msg))

    if role == ChatRole.USER:
        return ChatCompletionUserMessageParam(**_msg_kwargs(msg))

    # remaining case: assistant
    assistant_kwargs = _msg_kwargs(msg)
    if msg.tool_calls:
        assistant_kwargs["tool_calls"] = [kani_tc_to_openai_tc(call) for call in msg.tool_calls]
    return ChatCompletionAssistantMessageParam(**assistant_kwargs)


def _msg_kwargs(msg: ChatMessage) -> dict:
    """Build the shared ``role``/``content``(/``name``) kwargs for an OpenAI message param.

    Only USER messages whose content is a list of parts are expanded through ``_parts_to_oai``;
    all other messages are flattened to ``msg.text``. ``name`` is included only when set.
    """
    if isinstance(msg, ChatMessage) and msg.role == ChatRole.USER and isinstance(msg.content, list):
        content = _parts_to_oai(msg.content)
    else:
        content = msg.text

    kwargs = {"role": msg.role.value, "content": content}
    if msg.name is not None:
        kwargs["name"] = msg.name
    return kwargs


def kani_tc_to_openai_tc(tc: ToolCall) -> ChatCompletionMessageFunctionToolCallParam:
    """Convert a kani ToolCall into the plain-dict function tool call shape sent to OpenAI."""
    fn = tc.function
    return {
        "id": tc.id,
        "type": "function",
        "function": {"name": fn.name, "arguments": fn.arguments},
    }


# --- multimodal ---
if _optional.has_multimodal_core:

    def _parts_to_oai(parts: list[MessagePart | str]) -> list[dict]:
        """Convert kani message parts into a list of OpenAI content parts.

        Audio parts are sent as base64-encoded WAV ``input_audio``, image parts as an ``image_url``
        holding the part's base64 data URI, and anything else as a ``text`` part via ``str(part)``.
        """

        def convert(part) -> dict:
            if isinstance(part, _optional.multimodal_core.AudioPart):
                encoded = base64.b64encode(part.as_wav_bytes()).decode()
                return {"type": "input_audio", "input_audio": {"data": encoded, "format": "wav"}}
            if isinstance(part, _optional.multimodal_core.ImagePart):
                return {"type": "image_url", "image_url": {"url": part.as_b64_uri()}}
            return {"type": "text", "text": str(part)}

        return [convert(part) for part in parts]

else:

    def _parts_to_oai(parts: list[MessagePart | str]) -> str:
        """Fallback without multimodal-core: concatenate the string form of every part."""
        return "".join(str(part) for part in parts)


# --- main ---
# Normalizes a kani chat history before it is translated for OpenAI: function results are bound to the
# calls that requested them, and the prompt is required to start with a non-FUNCTION message
# (see PromptPipeline.ensure_start for exactly how offending leading messages are handled).
OPENAI_PIPELINE = (
    PromptPipeline()
    .ensure_bound_function_calls()
    .ensure_start(predicate=lambda m: m.role != ChatRole.FUNCTION)
)


@deprecated("Use OpenAIEngine.translate_functions() instead.")
def translate_functions(functions: list[AIFunction]) -> list[dict]:
    """Deprecated: translate kani AIFunctions into OpenAI ``function`` tool definitions.

    Retained only for backwards compatibility; the engine method is the canonical, overridable version.
    """

    def to_tool(func: AIFunction) -> dict:
        definition = FunctionDefinition(name=func.name, description=func.desc, parameters=func.json_schema)
        return {"type": "function", "function": definition}

    return [to_tool(func) for func in functions]


@deprecated("Use OpenAIEngine.translate_messages() instead.")
def translate_messages(messages: list[ChatMessage]) -> list[ChatCompletionMessageParam]:
    """Deprecated: normalize messages with OPENAI_PIPELINE, then translate each one for OpenAI.

    Retained only for backwards compatibility; the engine method is the canonical, overridable version.
    """
    # imported here rather than at module level - presumably to avoid a circular import; confirm
    from kani.engines.openai import OpenAIEngine

    prepared = OPENAI_PIPELINE(messages)
    return list(map(OpenAIEngine.translate_kani_message_to_openai, prepared))


# ==== openai -> kani ====
def openai_cm_to_kani_cm(msg: ChatCompletionMessage) -> ChatMessage:
    """Convert an OpenAI ChatCompletionMessage into a kani ChatMessage.

    OpenAI's ``tool`` role maps onto kani's FUNCTION role. Tool calls are converted one-to-one; when
    there are none, a legacy ``function_call`` is wrapped as a single ToolCall.
    """
    role = ChatRole.FUNCTION if msg.role == "tool" else ChatRole(msg.role)

    tool_calls = None
    if msg.tool_calls:
        tool_calls = list(map(openai_tc_to_kani_tc, msg.tool_calls))
    elif msg.function_call:
        tool_calls = [ToolCall.from_function_call(msg.function_call)]

    return ChatMessage(role=role, content=msg.content, tool_calls=tool_calls)


def openai_tc_to_kani_tc(tc) -> ToolCall:
    """Convert an OpenAI tool call into a kani ToolCall, preserving its ``id`` and ``type``."""
    kani_function = openai_fc_to_kani_fc(tc.function)
    return ToolCall(id=tc.id, type=tc.type, function=kani_function)


def openai_fc_to_kani_fc(fc) -> FunctionCall:
    """Convert an OpenAI function call into a kani FunctionCall, copying ``name`` and ``arguments``."""
    name, arguments = fc.name, fc.arguments
    return FunctionCall(name=name, arguments=arguments)


class ChatCompletion(BaseCompletion):
    """A wrapper around the OpenAI ChatCompletion to make it compatible with the Kani interface."""

    def __init__(self, openai_completion: OpenAIChatCompletion):
        self.openai_completion = openai_completion
        """The underlying OpenAI ChatCompletion."""
        self._message = openai_cm_to_kani_cm(openai_completion.choices[0].message)
        # expose the raw response on the message for downstream consumers
        self._message.extra["openai_completion"] = DottableDict(openai_completion.model_dump(mode="json"))
        # ``usage`` is optional in the OpenAI SDK (and omitted by some compatible servers);
        # only record it when the response actually reported it
        usage = openai_completion.usage
        if usage is not None:
            self._message.extra["openai_usage"] = DottableDict(usage.model_dump(mode="json"))

    @property
    def message(self):
        """The first choice's message, translated into a kani ChatMessage."""
        return self._message

    @property
    def prompt_tokens(self):
        """Prompt tokens reported by the API, or None when the response has no usage info.

        NOTE(review): None presumably lets kani fall back to its own token estimate - confirm against BaseCompletion.
        """
        usage = self.openai_completion.usage
        if usage is None:
            return None
        return usage.prompt_tokens

    @property
    def completion_tokens(self):
        """Completion tokens reported by the API (plus ChatML overhead), or None when usage is missing."""
        usage = self.openai_completion.usage
        if usage is None:
            return None
        # for some reason, the OpenAI API doesn't return the tokens used by ChatML
        # so we add on the length of "<|im_start|>assistant" and "<|im_end|>" here
        return usage.completion_tokens + 5