import functools
import inspect
import itertools
import json
import logging
import os
import warnings
from typing import AsyncIterable
from kani import _optional
from kani.ai_function import AIFunction
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall
from kani.prompts.pipeline import PromptPipeline
from kani.utils.warnings import deprecated, warn_in_userspace
from . import mm_tokens, model_constants
from .parts import AnthropicThinkingPart, AnthropicUnknownPart
from ..base import BaseCompletion, BaseEngine, Completion
from ..mixins import TokenCached
# the Anthropic SDK is an optional extra; fail with an actionable install hint if it is missing
try:
    from anthropic import AsyncAnthropic
    from anthropic.types import Message
except ImportError:
    # `from None` hides the original ImportError traceback in favor of the install hint
    raise MissingModelDependencies(
        'The AnthropicEngine requires extra dependencies. Please install kani with "pip install kani[anthropic]".'
    ) from None

log = logging.getLogger(__name__)
# ==== pipe ====
def content_transform(msg: ChatMessage) -> list[dict]:
    """
    Translate one kani ChatMessage into a list of Anthropic Messages API content blocks.

    Used as the ``content_transform`` of ``CLAUDE_PIPELINE``'s ``conversation_dict`` step.

    - ``ImagePart`` -> ``{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": ...}}``
      (images are always re-encoded as PNG).
    - ``BinaryFilePart`` with mime ``application/pdf`` ->
      ``{"type": "document", "source": {"type": "base64", "media_type": "application/pdf", "data": ...}}``.
    - ``AnthropicThinkingPart`` -> ``{"type": "thinking", "thinking": ..., "signature": ...}``.
    - ``AnthropicUnknownPart`` -> its raw ``data`` dict, passed back unchanged.
    - Any other part of a non-FUNCTION message -> ``{"type": "text", "text": str(part)}``; empty text is skipped
      because the Anthropic API rejects empty text content blocks.
    - FUNCTION messages additionally get a trailing
      ``{"type": "tool_result", "tool_use_id": ..., "content": msg.text}`` block (with ``"is_error": True`` for
      tool call errors); their text parts are sent only through that block.
    - ASSISTANT messages with tool calls get one trailing
      ``{"type": "tool_use", "id": ..., "name": ..., "input": {...}}`` block per call.

    :param msg: The message to translate.
    :returns: The list of content block dicts for the message.
    """
    content = []
    is_function = msg.role == ChatRole.FUNCTION

    for part in msg.parts:
        # --- images ---
        if _optional.has_multimodal_core and isinstance(part, _optional.multimodal_core.ImagePart):
            # encoding as PNG guarantees the declared media_type matches the data
            content.append(
                {
                    "type": "image",
                    "source": {"type": "base64", "media_type": "image/png", "data": part.as_b64(format="png")},
                }
            )
        # --- PDFs ---
        elif (
            _optional.has_multimodal_core
            and isinstance(part, _optional.multimodal_core.BinaryFilePart)
            and part.mime == "application/pdf"
        ):
            content.append(
                {"type": "document", "source": {"type": "base64", "media_type": part.mime, "data": part.as_b64()}}
            )
        # --- thinking ---
        elif isinstance(part, AnthropicThinkingPart):
            # unnecessary thinking blocks are filtered out by the API, so we just pass them all back
            content.append({"type": "thinking", "thinking": part.content, "signature": part.signature})
        # --- unknown parts (e.g. web search results, computer use, other server tools) ---
        elif isinstance(part, AnthropicUnknownPart):
            content.append(part.data)
        # --- plain text --- (FUNCTION message text is sent as the tool_result content below instead)
        elif not is_function:
            text = str(part)
            # the Anthropic API rejects empty text content blocks, so don't send them
            if text:
                content.append({"type": "text", "text": text})

    # FUNCTION messages are sent with role "user" and a tool_result block
    if is_function:
        result = {"type": "tool_result", "tool_use_id": msg.tool_call_id, "content": msg.text}
        if msg.is_tool_call_error:
            result["is_error"] = True
        content.append(result)

    # ASSISTANT messages with tool calls list each call as a tool_use block after their other content
    if msg.role == ChatRole.ASSISTANT and msg.tool_calls:
        content.extend(
            {"type": "tool_use", "id": tc.id, "name": tc.function.name, "input": tc.function.kwargs}
            for tc in msg.tool_calls
        )
    return content
# Prompt pipeline turning kani ChatMessages into Anthropic Messages API dicts.
# Assumes the leading SYSTEM messages were already plucked out into the `system` param before calling (as
# AnthropicEngine._prepare_request does); any SYSTEM messages that remain are sent as USER messages.
CLAUDE_PIPELINE = (
    PromptPipeline()
    # leftover system messages become user turns
    .translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER)
    # collapse runs of same-role messages (user text joined by newlines, assistant text by spaces)
    .merge_consecutive(role=ChatRole.USER, sep="\n")
    .merge_consecutive(role=ChatRole.ASSISTANT, sep=" ")
    # NOTE(review): presumably checks/repairs the pairing between tool calls and their results -- confirm
    .ensure_bound_function_calls()
    # emit role/content dicts; FUNCTION messages get role "user" and content from content_transform
    .conversation_dict(function_role="user", content_transform=content_transform)
)
class AnthropicEngine(TokenCached, BaseEngine):
    """
    Engine for using the Anthropic API.

    This engine supports all Claude models. See https://docs.anthropic.com/claude/docs/getting-access-to-claude for
    information on accessing the Claude API.

    See https://docs.anthropic.com/en/docs/about-claude/models/overview for a list of available models.

    **Multimodal support**: images.

    **Additional capabilities**: PDF document processing. Use :class:`kani.ext.multimodal_core.BinaryFilePart`.

    **Message Extras**: ``"anthropic_message"``: The Message (raw response) returned by the Anthropic servers.
    """

    # request kwargs that tell the API not to call any tools (tool_choice of type "none")
    disable_function_calling_kwargs = {"tool_choice": {"type": "none"}}

    def __init__(
        self,
        api_key: str = None,
        model: str = "claude-sonnet-4-0",
        max_tokens: int = 2048,
        max_context_size: int = None,
        *,
        retry: int = 2,
        api_base: str = None,
        headers: dict = None,
        client: AsyncAnthropic = None,
        **hyperparams,
    ):
        """
        :param api_key: Your Anthropic API key. By default, the API key will be read from the `ANTHROPIC_API_KEY`
            environment variable.
        :param model: The id of the model to use (e.g. "claude-opus-4-0"). See
            https://docs.anthropic.com/en/docs/about-claude/models/overview for a list of models.
        :param max_tokens: The maximum number of tokens to sample at each generation (defaults to 2048).
            Generally, you should set this to the same number as your Kani's ``desired_response_tokens``.
        :param max_context_size: The maximum amount of tokens allowed in the chat prompt. If None, uses the given
            model's full context size.
        :param retry: How many times the engine should retry failed HTTP calls with exponential backoff (default 2).
        :param api_base: The base URL of the Anthropic API to use.
        :param headers: A dict of HTTP headers to include with each request.
        :param client: An instance of ``anthropic.AsyncAnthropic`` (for reusing the same client in multiple engines).
            You must specify exactly one of (api_key, client). If this is passed the ``retry``, ``api_base``,
            and ``headers`` params will be ignored.
        :param hyperparams: Any additional parameters to pass to the underlying API call (see
            https://docs.claude.com/en/api/messages).
        :raises ValueError: If both ``api_key`` and ``client`` are given, or if neither is given and the
            ``ANTHROPIC_API_KEY`` environment variable is not set.
        """
        if api_key and client:
            raise ValueError("You must supply no more than one of (api_key, client).")
        if api_key is None and client is None:
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if api_key is None:
                raise ValueError(
                    "You must supply an `api_key`, `client`, or set the `ANTHROPIC_API_KEY` environment variable to use"
                    " the AnthropicEngine."
                )
        if max_context_size is None:
            # NOTE(review): relies on CONTEXT_SIZES_BY_PREFIX ending with a catch-all "" prefix (matched by every
            # model id); without one, next() raises StopIteration for unknown models -- confirm in model_constants
            matched_prefix, max_context_size = next(
                (prefix, size) for prefix, size in model_constants.CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix)
            )
            if not matched_prefix:
                warnings.warn(
                    f"The context length for this model was not found, defaulting to {max_context_size} tokens. Please"
                    " specify `max_context_size` if this is incorrect."
                )

        super().__init__()

        self.client = client or AsyncAnthropic(
            api_key=api_key, max_retries=retry, base_url=api_base, default_headers=headers
        )
        self.model = model
        self.max_tokens = max_tokens
        self.max_context_size = max_context_size
        self.hyperparams = hyperparams

    # ==== hackable stuff for requests ====
    @property
    def _messages_api(self):
        """Return the messages API resource object. Useful to override to use the beta API instead."""
        return self.client.messages

    @staticmethod
    def _prepare_request(messages, functions, *, intent: str = "create") -> tuple[dict, list]:
        """
        Prepare the API request to the Anthropic API. Returns a tuple (kwargs, messages) to be passed to the
        AnthropicClient's messages.create() method.

        :param messages: The Kani ChatMessages to translate into Anthropic-format messages.
        :param functions: The Kani AIFunctions to translate into Anthropic-format tools.
        :param intent: one of ("create", "stream", or "count_tokens") -- the underlying Anthropic SDK call the returned
            keyword arguments will be passed to. Unused here; available to overriding subclasses.
        """
        kwargs = {}

        # --- messages ---
        # pluck the leading run of SYSTEM messages into the top-level `system` param
        first_non_system_idx = next((i for i, m in enumerate(messages) if m.role != ChatRole.SYSTEM), None)
        # 0 -> there are no leading system messages; None -> *every* message is a system message, in which case they
        # are left in place and CLAUDE_PIPELINE translates them to USER messages
        if first_non_system_idx:
            kwargs["system"] = "\n\n".join(m.text for m in messages[:first_non_system_idx])
            messages = messages[first_non_system_idx:]

        # enforce ordering and function call bindings, and translate to the dict spec
        claude_fmt_messages = CLAUDE_PIPELINE(messages)

        # merge consecutive "user" dicts (FUNCTION results are translated to "user") into one multi-part message
        prompt_msgs = []
        for role, group_msgs in itertools.groupby(claude_fmt_messages, key=lambda m: m["role"]):
            group_msgs = list(group_msgs)
            if role == "user" and len(group_msgs) > 1:
                # normalize str content into {type: text, text: ...} blocks so all parts can be concatenated
                prompt_msg_content = []
                for msg in group_msgs:
                    if isinstance(msg["content"], str):
                        prompt_msg_content.append({"type": "text", "text": msg["content"]})
                    else:
                        prompt_msg_content.extend(msg["content"])
                prompt_msgs.append({"role": "user", "content": prompt_msg_content})
            else:
                prompt_msgs.extend(group_msgs)

        # --- tools ---
        if functions:
            kwargs["tools"] = [
                {"name": f.name, "description": f.desc, "input_schema": f.json_schema} for f in functions
            ]

        # lazy %-formatting: the prompt may contain large base64 payloads, so only stringify it if debug is enabled
        log.debug("Claude message format: %s", prompt_msgs)
        return kwargs, prompt_msgs

    def _translate_anthropic_message(self, message: Message) -> Completion:
        """
        Translate an Anthropic message to a Kani completion.

        Text blocks become str parts, thinking blocks become AnthropicThinkingParts, tool_use blocks become
        ToolCalls, and any other block type is kept as an AnthropicUnknownPart (with a warning). The raw message is
        stored in the ``"anthropic_message"`` extra and the output token count is cached as the message length.
        """
        tool_calls = []
        parts = []
        for part in message.content:
            if part.type == "text":
                parts.append(part.text)
            elif part.type == "tool_use":
                fc = FunctionCall(name=part.name, arguments=json.dumps(part.input))
                tc = ToolCall(id=part.id, type="function", function=fc)
                tool_calls.append(tc)
            elif part.type == "thinking":
                parts.append(AnthropicThinkingPart(content=part.thinking, signature=part.signature))
            else:
                parts.append(AnthropicUnknownPart(type=part.type, data=part.model_dump()))
                warnings.warn(
                    f"The engine returned an unknown part: {part.type}. This has been saved as an AnthropicUnknownPart,"
                    " but will not stringify to a natural language prompt for other language models."
                )

        # only unwrap a lone *string* part into plain str content; a lone non-str part (e.g. only a thinking block
        # alongside tool calls) must stay wrapped in a list, since message content is a str or a list of parts
        if len(parts) == 1 and isinstance(parts[0], str):
            content = parts[0]
        else:
            content = parts
        kani_msg = ChatMessage.assistant(content, tool_calls=tool_calls or None)

        # also cache the message token len
        self.set_cached_message_len(kani_msg, message.usage.output_tokens)

        # set the extra
        kani_msg.extra["anthropic_message"] = message

        return Completion(
            message=kani_msg,
            prompt_tokens=message.usage.input_tokens,
            completion_tokens=message.usage.output_tokens,
        )

    @functools.cached_property
    def _count_tokens_arg_names(self):
        """A set of valid kwarg names that can be passed to self._messages_api.count_tokens"""
        try:
            return set(inspect.signature(self._messages_api.count_tokens).parameters)
        except Exception as e:
            log.warning("Could not introspect count_tokens for parameter names, returning default:", exc_info=e)
        # default
        return {
            "messages",
            "model",
            "system",
            "thinking",
            "tool_choice",
            "tools",
            "extra_headers",
            "extra_query",
            "extra_body",
            "timeout",
        }

    # ==== kani impls ====
    async def prompt_len(self, messages, functions=None, **kwargs) -> int:
        # exact prompt length via the Anthropic token counting API, cached per (messages, functions, kwargs)
        if (cached_len := self.get_cached_prompt_len(messages, functions, **kwargs)) is not None:
            return cached_len

        predict_kwargs, prompt_msgs = self._prepare_request(messages, functions, intent="count_tokens")
        # only include valid kwargs from inspecting self._messages_api.count_tokens
        valid_count_token_kwargs = {
            k: v for k, v in (predict_kwargs | self.hyperparams | kwargs).items() if k in self._count_tokens_arg_names
        }
        result = await self._messages_api.count_tokens(
            model=self.model,
            messages=prompt_msgs,
            **valid_count_token_kwargs,
        )
        self.set_cached_prompt_len(messages, functions, length=result.input_tokens, **kwargs)
        return result.input_tokens

    async def predict(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
    ) -> Completion:
        # single non-streaming completion via messages.create()
        kwargs, prompt_msgs = self._prepare_request(messages, functions, intent="create")
        assert len(prompt_msgs) > 0

        # --- completion ---
        message = await self._messages_api.create(
            model=self.model,
            max_tokens=self.max_tokens,
            messages=prompt_msgs,
            # to prevent toe-stepping, hyperparams > self.hyperparams > kwargs
            **(kwargs | self.hyperparams | hyperparams),
        )

        # translate to kani
        return self._translate_anthropic_message(message)

    async def stream(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
    ) -> AsyncIterable[str | BaseCompletion]:
        # yields text deltas as they arrive, then the full translated Completion last
        kwargs, prompt_msgs = self._prepare_request(messages, functions, intent="stream")
        assert len(prompt_msgs) > 0

        async with self._messages_api.stream(
            model=self.model,
            max_tokens=self.max_tokens,
            messages=prompt_msgs,
            # to prevent toe-stepping, hyperparams > self.hyperparams > kwargs
            **(kwargs | self.hyperparams | hyperparams),
        ) as stream:
            async for text in stream.text_stream:
                yield text

            message = await stream.get_final_message()
            yield self._translate_anthropic_message(message)

    async def close(self):
        # NOTE(review): this also closes a client shared with other engines via the `client` param
        await self.client.close()

    # ==== deprecated ====
    # because we have to estimate tokens wildly and the ctx is so long we'll just reserve a bunch
    token_reserve = 500

    @deprecated("Use prompt_len instead")
    def message_len(self, message: ChatMessage) -> int:
        # local heuristic estimate (~3.2 chars/token, plus image token estimates); prefer prompt_len
        if (cached_len := self.get_cached_message_len(message)) is not None:
            return cached_len

        chars = len(message.role.value)
        tokens = 0
        if _optional.has_multimodal_core:
            for part in message.parts:
                if isinstance(part, _optional.multimodal_core.ImagePart):
                    tokens += mm_tokens.tokens_from_image_size(part.size)
                else:
                    chars += len(str(part))
        else:
            chars += len(message.text)

        # tools
        if message.tool_calls:
            for tc in message.tool_calls:
                chars += len(tc.function.name) + len(tc.function.arguments)

        # token counting - claude 3+ does not release tokenizer so we have to do heuristics and cache
        # Anthropic documents 3.4 bytes per token, so we do a conservative 3.2 char/tok
        warn_in_userspace(
            f"This Claude model ({self.model}) does not have a public tokenizer, so local token counting will only"
            " return an estimate (3.2 characters/token). Use `await engine.prompt_len(messages, functions)` instead"
            " for an exact count."
        )
        return int(chars / 3.2) + tokens

    @deprecated("Use prompt_len instead")
    def function_token_reserve(self, functions: list[AIFunction]) -> int:
        # local heuristic estimate of the tokens used by tool definitions; prefer prompt_len
        if not functions:
            return 0
        warn_in_userspace(
            f"This Claude model ({self.model}) does not have a public tokenizer, so local token counting will only"
            " return an estimate (3.2 characters/token). Use `await engine.prompt_len(messages, functions)` instead"
            " for an exact count."
        )
        # wrap an inner impl to use lru_cache with frozensets
        return self._function_token_reserve_impl(frozenset(functions))

    # a staticmethod cache is keyed only on the functions, so it does not keep engine instances alive (unlike an
    # lru_cache on an instance method, which holds a reference to every `self` it has seen)
    @staticmethod
    @functools.lru_cache(maxsize=256)
    def _function_token_reserve_impl(functions):
        """Estimate the tokens used by a frozenset of AIFunctions at ~3.2 characters/token."""
        n = sum(len(f.name) + len(f.desc) + len(json.dumps(f.json_schema)) for f in functions)
        return int(n / 3.2)