# Source code for kani.engines.openai.translation
"""Helpers to translate kani chat objects into OpenAI params."""
import base64
from kani import _optional
from kani.ai_function import AIFunction
from kani.engines.base import BaseCompletion
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage, ChatRole, FunctionCall, MessagePart, ToolCall
from kani.prompts.pipeline import PromptPipeline
from kani.utils.warnings import deprecated
from .utils import DottableDict
try:
from openai.types.chat import (
ChatCompletion as OpenAIChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessage,
ChatCompletionMessageFunctionToolCallParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.shared_params import FunctionDefinition
except ImportError as e:
raise MissingModelDependencies(
'The OpenAIEngine requires extra dependencies. Please install kani with "pip install kani[openai]".'
) from None
# ==== kani -> openai ====
# decomp
def kani_cm_to_openai_cm(msg: ChatMessage) -> ChatCompletionMessageParam:
    """Convert a kani ChatMessage into the matching OpenAI message param.

    * FUNCTION messages carrying a ``tool_call_id`` become ``tool`` messages.
    * FUNCTION messages without one, SYSTEM messages and USER messages map to their own param type.
    * Any other role is sent as an assistant message, with translated tool calls attached when present.
    """
    role = msg.role

    if role == ChatRole.FUNCTION:
        # a function result bound to a tool call uses the "tool" message format
        if msg.tool_call_id is not None:
            return ChatCompletionToolMessageParam(role="tool", content=msg.text, tool_call_id=msg.tool_call_id)
        return ChatCompletionFunctionMessageParam(**_msg_kwargs(msg))

    if role == ChatRole.SYSTEM:
        return ChatCompletionSystemMessageParam(**_msg_kwargs(msg))

    if role == ChatRole.USER:
        return ChatCompletionUserMessageParam(**_msg_kwargs(msg))

    # fallthrough: assistant
    extra = {"tool_calls": [kani_tc_to_openai_tc(tc) for tc in msg.tool_calls]} if msg.tool_calls else {}
    return ChatCompletionAssistantMessageParam(**_msg_kwargs(msg), **extra)
def _msg_kwargs(msg: ChatMessage) -> dict:
    """Build the ``role``/``content`` (and optional ``name``) kwargs for an OpenAI message param.

    A USER message whose content is a list has that list passed through ``_parts_to_oai``;
    every other message uses ``msg.text`` as its content. ``name`` is only included when set.
    """
    is_user_part_list = (
        isinstance(msg, ChatMessage) and msg.role == ChatRole.USER and isinstance(msg.content, list)
    )
    if is_user_part_list:
        body = _parts_to_oai(msg.content)
    else:
        body = msg.text

    kwargs = {"role": msg.role.value, "content": body}
    if msg.name is not None:
        kwargs["name"] = msg.name
    return kwargs
def kani_tc_to_openai_tc(tc: ToolCall) -> ChatCompletionMessageFunctionToolCallParam:
    """Build the OpenAI tool-call dict (``id``, ``type="function"``, ``function``) for a kani ToolCall."""
    call = tc.function
    return {
        "id": tc.id,
        "type": "function",
        "function": {"name": call.name, "arguments": call.arguments},
    }
# --- multimodal ---
if _optional.has_multimodal_core:

    def _parts_to_oai(parts: list[MessagePart | str]) -> list[dict]:
        """Translate a list of kani message parts into OpenAI content components.

        Audio parts become base64-encoded WAV ``input_audio`` entries, image parts become
        ``image_url`` entries, and anything else is stringified into a ``text`` entry.
        """

        def _convert(part) -> dict:
            if isinstance(part, _optional.multimodal_core.AudioPart):
                encoded = base64.b64encode(part.as_wav_bytes()).decode()
                return {"type": "input_audio", "input_audio": {"data": encoded, "format": "wav"}}
            if isinstance(part, _optional.multimodal_core.ImagePart):
                return {"type": "image_url", "image_url": {"url": part.as_b64_uri()}}
            return {"type": "text", "text": str(part)}

        return [_convert(part) for part in parts]

else:

    def _parts_to_oai(parts: list[MessagePart | str]) -> str:
        """Without multimodal-core, concatenate the string form of every part."""
        return "".join(str(part) for part in parts)
# --- main ---
# Pipeline that translate_messages() runs over a message list before per-message translation.
# The predicate given to the last step accepts any message whose role is not FUNCTION
# (presumably so a prompt never opens with a function result; confirm against PromptPipeline.ensure_start).
OPENAI_PIPELINE = (
    PromptPipeline()
    .ensure_bound_function_calls()
    .ensure_start(predicate=lambda msg: msg.role != ChatRole.FUNCTION)
)
@deprecated("Use OpenAIEngine.translate_functions() instead.")
def translate_functions(functions: list[AIFunction]) -> list[dict]:
    """Describe each AIFunction as an OpenAI ``function``-type tool dict.

    Deprecated: the engine carries its own implementation (for hackability); this module-level
    function is kept only for backwards compatibility.
    """
    tools = []
    for func in functions:
        definition = FunctionDefinition(name=func.name, description=func.desc, parameters=func.json_schema)
        tools.append({"type": "function", "function": definition})
    return tools
@deprecated("Use OpenAIEngine.translate_messages() instead.")
def translate_messages(messages: list[ChatMessage]) -> list[ChatCompletionMessageParam]:
    """Run the messages through ``OPENAI_PIPELINE`` and translate each result with the engine's translator.

    Deprecated: the engine carries its own implementation (for hackability); this module-level
    function is kept only for backwards compatibility.
    """
    # imported at call time rather than module scope (presumably to avoid a circular import; confirm)
    from kani.engines.openai import OpenAIEngine

    prepared = OPENAI_PIPELINE(messages)
    return list(map(OpenAIEngine.translate_kani_message_to_openai, prepared))
# ==== openai -> kani ====
def openai_cm_to_kani_cm(msg: ChatCompletionMessage) -> ChatMessage:
    """Translate an OpenAI ChatCompletionMessage into a kani ChatMessage.

    The OpenAI ``tool`` role maps to ``ChatRole.FUNCTION``; every other role is looked up in
    ``ChatRole`` by value. Tool calls are converted one-for-one with ``openai_tc_to_kani_tc``.
    A legacy singular ``function_call`` is converted to a kani FunctionCall and wrapped in a
    single ToolCall. A message with neither yields ``tool_calls=None``.
    """
    # translate tool role to function role
    role = ChatRole.FUNCTION if msg.role == "tool" else ChatRole(msg.role)

    if msg.tool_calls:
        tool_calls = [openai_tc_to_kani_tc(tc) for tc in msg.tool_calls]
    elif msg.function_call:
        # translate FunctionCall to singular ToolCall; convert the OpenAI object to a kani
        # FunctionCall first (as the tool_calls path does) instead of passing the raw OpenAI model
        tool_calls = [ToolCall.from_function_call(openai_fc_to_kani_fc(msg.function_call))]
    else:
        tool_calls = None

    return ChatMessage(role=role, content=msg.content, tool_calls=tool_calls)
def openai_tc_to_kani_tc(tc) -> ToolCall:
    """Translate an OpenAI tool call into a kani ToolCall, keeping its id and type."""
    function_call = openai_fc_to_kani_fc(tc.function)
    return ToolCall(id=tc.id, type=tc.type, function=function_call)
def openai_fc_to_kani_fc(fc) -> FunctionCall:
    """Translate an OpenAI function call into a kani FunctionCall by copying its name and arguments."""
    fields = {"name": fc.name, "arguments": fc.arguments}
    return FunctionCall(**fields)
# [docs]
class ChatCompletion(BaseCompletion):
    """A wrapper around the OpenAI ChatCompletion to make it compatible with the Kani interface.

    The first choice's message is translated to a kani ChatMessage, and the raw completion
    (plus its usage, when the response includes one) is stored in that message's ``extra``.
    """

    def __init__(self, openai_completion: OpenAIChatCompletion):
        self.openai_completion = openai_completion
        """The underlying OpenAI ChatCompletion."""
        self._message = openai_cm_to_kani_cm(openai_completion.choices[0].message)
        self._message.extra["openai_completion"] = DottableDict(openai_completion.model_dump(mode="json"))
        # `usage` may be None on the response model; only record it when it is present
        # instead of raising AttributeError while building the completion
        if openai_completion.usage is not None:
            self._message.extra["openai_usage"] = DottableDict(openai_completion.usage.model_dump(mode="json"))

    @property
    def message(self):
        """The kani ChatMessage translated from the completion's first choice."""
        return self._message

    @property
    def prompt_tokens(self):
        """Prompt token count reported by the API, or None when the response has no usage data."""
        usage = self.openai_completion.usage
        if usage is None:
            return None
        return usage.prompt_tokens

    @property
    def completion_tokens(self):
        """Completion token count reported by the API plus 5, or None when the response has no usage data."""
        usage = self.openai_completion.usage
        if usage is None:
            return None
        # for some reason, the OpenAI API doesn't return the tokens used by ChatML
        # so we add on the length of "<|im_start|>assistant" and "<|im_end|>" here
        return usage.completion_tokens + 5