import asyncio
import datetime
import functools
import io
import json
import logging
import os
import warnings
from typing import AsyncIterable
from kani import _optional
from kani.ai_function import AIFunction
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall
from kani.parts import ReasoningPart
from kani.prompts.pipeline import PromptPipeline
from kani.utils.warnings import deprecated, warn_in_userspace
from . import mm_tokens, model_constants
from ..base import BaseCompletion, BaseEngine, Completion
from ..mixins import TokenCached
try:
from google import genai
from google.genai import types as genai_types
except ImportError as e:
raise MissingModelDependencies(
'The GoogleAIEngine requires extra dependencies. Please install kani with "pip install kani[google]".'
) from None
log = logging.getLogger(__name__)
# ==== pipe ====
# assumes system messages are plucked before calling
GOOGLE_PIPELINE: PromptPipeline[list[genai_types.Content]] = (
PromptPipeline().translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER).ensure_bound_function_calls()
)
ROLE_TRANSFORMS = {
ChatRole.ASSISTANT: "model",
ChatRole.FUNCTION: "tool",
}
RAW_RESPONSE_EXTRA_KEY = "google_response"
[docs]
class GoogleAIEngine(TokenCached, BaseEngine):
"""
Engine for using the Google AI Studio API (aka Gemini Developer API, Google AI API) and
Google Vertex AI API (aka Google Cloud API).
This engine supports all Google AI models.
See https://ai.google.dev/gemini-api/docs/models for a list of available models.
**Multimodal support**: images, audio, video.
**Message Extras**: ``"google_response"``: The
`raw response <https://ai.google.dev/api/generate-content#generatecontentresponse>`_ returned by the Google AI API.
"""
def __init__(
self,
api_key: str = None,
model: str = "gemini-2.5-flash",
max_context_size: int = None,
*,
# client settings
retry: int = 2,
api_base: str = None,
headers: dict = None,
client: genai.Client = None,
# kani settings
multimodal_upload_bytes_threshold: int = 512_000,
**hyperparams,
):
"""
:param api_key: Your Gemini Developer API key. By default, the API key will be read from the `GEMINI_API_KEY`
environment variable.
:param model: The id of the model to use (e.g. "gemini-2.5-flash"). See
https://ai.google.dev/gemini-api/docs/models for a list of models.
:param max_tokens: The maximum number of tokens to sample at each generation (defaults to 512).
Generally, you should set this to the same number as your Kani's ``desired_response_tokens``.
:param max_context_size: The maximum amount of tokens allowed in the chat prompt. If None, uses the given
model's full context size.
:param retry: How many times the engine should retry failed HTTP calls with exponential backoff (default 2).
:param api_base: The base URL of the Google AI API to use. If not specified, the default URL for the specified
API (AI Studio/Vertex) will be used.
:param headers: A dict of HTTP headers to include with each request.
:param client: An instance of ``genai.Client`` (for reusing the same client in multiple engines).
You must specify exactly one of (api_key, client).
:param multimodal_upload_bytes_threshold: If a multimodal object (audio, image, video) is larger than this
number of bytes, upload it as a file instead of passing it inline in a request. Default 512kB.
:param hyperparams: Any additional parameters to pass to the underlying API call (see
https://googleapis.github.io/python-genai/genai.html#genai.types.GenerateContentConfig).
"""
if api_key and client:
raise ValueError("You must supply no more than one of (api_key, client).")
if api_key is None and client is None:
api_key = os.getenv("GEMINI_API_KEY")
if api_key is None:
raise ValueError(
"You must supply an `api_key`, `client`, or set the `GEMINI_API_KEY` environment variable to use"
" the GoogleAIEngine."
)
if max_context_size is None:
matched_prefix, max_context_size = next(
(prefix, size) for prefix, size in model_constants.CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix)
)
if not matched_prefix:
warnings.warn(
f"The context length for this model was not found, defaulting to {max_context_size} tokens. Please"
" specify `max_context_size` if this is incorrect."
)
super().__init__()
self.client = client or genai.Client(
http_options=genai_types.HttpOptions(
retry_options=genai_types.HttpRetryOptions(attempts=retry),
base_url=api_base,
headers=headers,
)
)
self.model = model
self.max_context_size = max_context_size
self.hyperparams = hyperparams
# multimodal file caching
self.multimodal_upload_bytes_threshold = multimodal_upload_bytes_threshold
self._multimodal_file_cache: dict[bytes, genai.types.File] = {} # multimodal part sha256 -> google file
# ==== requests ====
async def _prepare_request(
self, messages, functions, hyperparams, intent: str = "generate_content"
) -> tuple[genai_types.GenerateContentConfigDict, list[genai_types.Content]]:
"""
Prepare the API request to the Google AI API. Returns a tuple (GenerateContentConfigDict, Content[]) to be
passed to the genai Client's ``generate_content()`` method.
:param messages: The Kani ChatMessages to translate into Google-format messages.
:param functions: The Kani AIFunctions to translate into Google-format tools.
:param intent: one of ("generate_content", "generate_content_stream", or "count_tokens") -- the underlying
Google AI SDK call the returned keyword arguments will be passed to.
"""
kwargs = {}
# --- messages ---
# pluck system messages
last_system_idx = next((i for i, m in enumerate(messages) if m.role != ChatRole.SYSTEM), None)
if last_system_idx:
kwargs["system_instruction"] = "\n\n".join(m.text for m in messages[:last_system_idx])
messages = messages[last_system_idx:]
# enforce ordering and function call bindings
translated_messages = GOOGLE_PIPELINE(messages)
# translate to content list
translated_messages = [await self._translate_message(m) for m in translated_messages]
# --- tools ---
if functions:
kwargs.setdefault("tools", [])
kwargs["tools"].append(
genai_types.Tool(
function_declarations=[
genai_types.FunctionDeclaration(
name=f.name,
description=f.desc,
parameters=genai_types.Schema.from_json_schema(
json_schema=genai_types.JSONSchema.model_validate(f.json_schema)
),
)
for f in functions
]
)
)
# --- kwargs ---
if intent != "count_tokens":
kwargs.update(self.hyperparams)
kwargs.update(hyperparams)
log.debug(f"translated prompt: {translated_messages}")
return kwargs, translated_messages
async def _translate_message(self, msg: ChatMessage) -> genai_types.Content:
"""
Translate one message into a Content object;
automatically upload and save references to large multimodal objects using Files API (max 2GB per file,
20GB total, autodeletes after 48h).
"""
# if we already have the translated content from google, just use that
# this is useful for thought signatures, mainly
if RAW_RESPONSE_EXTRA_KEY in msg.extra:
return msg.extra[RAW_RESPONSE_EXTRA_KEY].candidates[0].content
role = ROLE_TRANSFORMS.get(msg.role, msg.role.value)
content = []
# FUNCTION
if msg.role == ChatRole.FUNCTION:
# tool call error
if msg.is_tool_call_error:
content.append(
genai_types.Part(
function_response=genai_types.FunctionResponse(
# we do name=None when model tries to call a function that doesn't exist
# but Gemini doesn't allow name=None for inputs, so we just set it to "error"
id=msg.tool_call_id,
name=msg.name or "error",
response={"error": msg.text},
)
)
)
else:
content.append(
genai_types.Part(
function_response=genai_types.FunctionResponse(
id=msg.tool_call_id, name=msg.name, response={"result": msg.text}
)
)
)
# ASSISTANT, USER messages
else:
for part in msg.parts:
# --- multimodal ---
if _optional.has_multimodal_core and isinstance(
part,
(
_optional.multimodal_core.ImagePart,
_optional.multimodal_core.AudioPart,
_optional.multimodal_core.BinaryFilePart,
),
):
content.append(await self._translate_multimodal_part(part))
# reasoning
elif isinstance(part, ReasoningPart):
content.append(genai_types.Part(text=part.content, thought=True))
# default
else:
content.append(genai_types.Part(text=str(part)))
# ASSISTANT messages with tool calls
if msg.role == ChatRole.ASSISTANT and msg.tool_calls:
for tc in msg.tool_calls:
content.append(
genai_types.Part(
function_call=genai_types.FunctionCall(id=tc.id, name=tc.function.name, args=tc.function.kwargs)
)
)
return genai_types.Content(role=role, parts=content)
async def _translate_multimodal_part(self, part) -> genai_types.Part:
"""
Translate a multimodal kani part to a google part, uploading it to the Files API if it's large.
Caches uploaded files for re-use based on sha256.
"""
# if we have uploaded this file to the files API before, add the file part
sha256 = part.sha256()
if sha256 in self._multimodal_file_cache:
google_file = self._multimodal_file_cache[sha256]
# check if the upload is still valid
now = datetime.datetime.now(tz=datetime.timezone.utc)
if now < google_file.expiration_time:
log.debug(f"Using cached google file part: {google_file}")
return genai_types.Part.from_uri(file_uri=google_file.uri, mime_type=google_file.mime_type)
log.debug(f"Google file part is expired, falling through to re-upload")
# otherwise read the file
# image
if isinstance(part, _optional.multimodal_core.ImagePart):
media_type = "image/png"
data = part.as_bytes(format="png")
# audio
elif isinstance(part, _optional.multimodal_core.AudioPart):
media_type = "audio/wav"
data = part.as_wav_bytes()
# video/arbitrary binary
elif isinstance(part, _optional.multimodal_core.BinaryFilePart):
media_type = part.mime
data = part.as_bytes()
else:
raise ValueError(
f"Invalid multimodal message part: {part!r}. This should never happen. Please open a"
" bug report with reproduction steps."
)
# if the file data is more than the threshold, upload it and use the file part
if len(data) >= self.multimodal_upload_bytes_threshold:
log.debug(f"Uploading multimodal file to Files API (len={len(data)})")
google_file = await self.client.aio.files.upload(
file=io.BytesIO(data), config=genai_types.UploadFileConfig(mime_type=media_type)
)
log.debug(google_file)
if google_file.state not in (genai_types.FileState.ACTIVE, genai_types.FileState.PROCESSING):
raise RuntimeError(f"Invalid google file state, file: {google_file}")
self._multimodal_file_cache[sha256] = google_file
google_part = genai_types.Part.from_uri(file_uri=google_file.uri, mime_type=google_file.mime_type)
if google_file.state == genai_types.FileState.ACTIVE:
return google_part
# wait until the file is done processing
log.debug(f"Uploaded google file part is not ACTIVE, waiting for ACTIVE")
for idx in range(50):
# poll every 5s since google does not offer a blocking wait
await asyncio.sleep(5)
google_file = await self.client.aio.files.get(name=google_file.name)
log.debug(f"{(idx + 1) * 5} sec...\n{google_file}")
if google_file.state == genai_types.FileState.ACTIVE:
self._multimodal_file_cache[sha256] = google_file
return google_part
self._multimodal_file_cache.pop(sha256, None)
raise RuntimeError(
f"Google file state is not ACTIVE after long wait, something might be wrong!\n{google_file}"
)
# otherwise just include the bytes inline
return genai_types.Part.from_bytes(data=data, mime_type=media_type)
def _translate_google_response(self, resp: genai_types.GenerateContentResponse) -> Completion:
tool_calls = []
parts = []
resp_content = resp.candidates[0].content
resp_parts = resp_content.parts if resp_content is not None else None
if resp_parts is None:
resp_parts = []
warn_in_userspace(
"The engine did not return any content. Consider increasing `max_output_tokens` if set.\nResponse:"
f" {resp}",
stacklevel=3,
)
for part in resp_parts:
if part.thought:
parts.append(ReasoningPart(content=part.text))
elif part.text:
parts.append(part.text)
elif part.function_call:
fc = FunctionCall(name=part.function_call.name, arguments=json.dumps(part.function_call.args))
tc = ToolCall.from_function_call(fc, call_id_=part.function_call.id)
tool_calls.append(tc)
else:
warnings.warn(
f"The engine returned an unknown part: {part}. This will not be returned in the ChatMessage. To"
f" access this part, use `message.extra[{RAW_RESPONSE_EXTRA_KEY!r}].candidates[0].content.parts`."
)
if len(parts) == 1 and isinstance(parts[0], str):
content = parts[0]
elif not parts:
content = None
else:
content = parts
kani_msg = ChatMessage.assistant(content, tool_calls=tool_calls or None)
# also cache the message token len
self.set_cached_message_len(kani_msg, resp.usage_metadata.candidates_token_count)
# set the extra
kani_msg.extra[RAW_RESPONSE_EXTRA_KEY] = resp
return Completion(
message=kani_msg,
prompt_tokens=resp.usage_metadata.prompt_token_count,
completion_tokens=resp.usage_metadata.candidates_token_count,
)
# ==== kani impl ====
[docs]
async def prompt_len(self, messages, functions=None, **kwargs) -> int:
if (cached_len := self.get_cached_prompt_len(messages, functions, **kwargs)) is not None:
return cached_len
request_config, prompt_msgs = await self._prepare_request(messages, functions, kwargs, intent="count_tokens")
# HACK: we have to run estimation for system instructions or tools if we're not using Vertex
# since the AI Studio token counting endpoint is broken >:c
# https://github.com/googleapis/python-genai/issues/432
token_counting_machine_broke_count = 0
if not self.client.vertexai:
# Google documents 4 bytes per token, so we do a conservative 3.8 char/tok
chars = 0
if "system_instruction" in request_config:
chars += len(request_config.pop("system_instruction"))
if "tools" in request_config:
chars += len(json.dumps([t.model_dump(mode="json") for t in request_config.pop("tools")]))
token_counting_machine_broke_count = int(chars / 3.8)
result = await self.client.aio.models.count_tokens(
model=self.model, contents=prompt_msgs, config=request_config
)
count = result.total_tokens + token_counting_machine_broke_count
self.set_cached_prompt_len(messages, functions, length=count, **kwargs)
return count
[docs]
async def predict(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> Completion:
request_config, prompt_msgs = await self._prepare_request(
messages, functions, hyperparams, intent="generate_content"
)
# --- completion ---
assert len(prompt_msgs) > 0
message = await self.client.aio.models.generate_content(
model=self.model,
contents=prompt_msgs,
config=request_config,
)
# translate to kani
return self._translate_google_response(message)
[docs]
async def stream(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> AsyncIterable[str | BaseCompletion]:
# do the stream
request_config, prompt_msgs = await self._prepare_request(
messages, functions, hyperparams, intent="generate_content_stream"
)
assert len(prompt_msgs) > 0
last_chunk = None
content_parts = []
async for chunk in await self.client.aio.models.generate_content_stream(
model=self.model,
contents=prompt_msgs,
config=request_config,
):
parts = chunk.candidates[0].content.parts
if parts is None:
continue
for part in parts:
if part.text and not part.thought:
yield part.text
last_chunk = chunk
content_parts.extend(parts)
if last_chunk:
last_chunk.candidates[0].content.parts = content_parts
yield self._translate_google_response(last_chunk)
# ==== deprecated ====
# because we have to estimate tokens wildly and the ctx is so long we'll just reserve a bunch
token_reserve = 500
[docs]
@deprecated("Use prompt_len instead")
def message_len(self, message: ChatMessage) -> int:
if (cached_len := self.get_cached_message_len(message)) is not None:
return cached_len
# TODO with async token counting use the token counting API
chars = len(message.role.value)
tokens = 0
if _optional.has_multimodal_core:
for part in message.parts:
if isinstance(part, _optional.multimodal_core.ImagePart):
tokens += mm_tokens.tokens_from_image_size(part.size, self.model)
elif isinstance(part, _optional.multimodal_core.AudioPart):
tokens += mm_tokens.tokens_from_audio_duration(part.duration, self.model)
elif isinstance(part, _optional.multimodal_core.VideoPart):
tokens += mm_tokens.tokens_from_video_duration(part.duration, self.model)
else:
chars += len(str(part))
else:
chars += len(message.text)
# tools
if message.tool_calls:
for tc in message.tool_calls:
chars += len(tc.function.name) + len(tc.function.arguments)
# Google documents 4 bytes per token, so we do a conservative 3.8 char/tok
return int(chars / 3.8) + tokens
[docs]
@deprecated("Use prompt_len instead")
def function_token_reserve(self, functions: list[AIFunction]) -> int:
if not functions:
return 0
# wrap an inner impl to use lru_cache with frozensets
return self._function_token_reserve_impl(frozenset(functions))
@functools.lru_cache(maxsize=256)
def _function_token_reserve_impl(self, functions):
# panik, also assume len/4?
n = sum(len(f.name) + len(f.desc) + len(json.dumps(f.json_schema)) for f in functions)
return int(n / 3.8)