# Source code for kani.prompts.types

from collections.abc import Collection
from dataclasses import dataclass
from typing import Callable, TypeVar

from kani.ai_function import AIFunction
from kani.models import ChatMessage, ChatRole, MessagePart, ToolCall

PipelineMsgT = ChatMessage
"""The type of messages in the pipeline."""

MessageContentT = str | list[MessagePart | str] | None
"""The type of ``ChatMessage.content``: a plain string, a list of message parts and/or strings, or ``None``."""

RoleFilterT = ChatRole | Collection[ChatRole]
"""A single role, or a collection of roles, to apply a step to."""

PredicateFilterT = Callable[[PipelineMsgT], bool]
"""A callable that takes a message and returns whether or not to apply a step to it."""

FunctionCallStrT = Callable[[ToolCall], str | None]
"""A callable that formats a tool call as a str (it may also return ``None``)."""

ApplyResultT = TypeVar("ApplyResultT")
"""Type variable for the value returned by an apply callable (see ``ApplyCallableT``)."""


@dataclass
class ApplyContext:
    """Context about where a message lives in the pipeline for an arbitrary Apply operation.

    This is the type of the second argument of two-argument apply callables (see ``ApplyCallableT``), giving the
    callable access to the message's position in, and the rest of, the chat prompt.
    """

    msg: PipelineMsgT
    """The message being operated on."""

    is_last: bool
    """Whether the message being operated on is the last message (of all types) in the chat prompt."""

    idx: int
    """The index of the message in the chat prompt."""

    messages: list[PipelineMsgT]
    """The list of all messages in the chat prompt."""

    functions: list[AIFunction]
    """The list of functions available in the chat prompt."""

    @property
    def is_last_of_type(self) -> bool:
        """Whether this message is the last one of its role in the chat prompt.

        Scans :attr:`messages` from the end and stops at the first message whose role matches :attr:`msg`'s role,
        instead of building the full list of same-role messages. The comparison is by identity (``is``), so a
        distinct-but-equal message later in the prompt makes this ``False``.

        Returns ``False`` (rather than raising) when no message in :attr:`messages` has this role, e.g. when
        :attr:`msg` is not itself part of :attr:`messages`.
        """
        role = self.msg.role
        # walk backwards: the first same-role message found is the last one in the prompt
        last_of_role = next((m for m in reversed(self.messages) if m.role == role), None)
        return last_of_role is self.msg
ApplyCallableT = (
    Callable[[PipelineMsgT], ApplyResultT]
    | Callable[[PipelineMsgT, ApplyContext], ApplyResultT]
)
"""A function taking 1-2 args: the message, and optionally an :class:`ApplyContext` describing where it lives."""

MacroApplyResultT = TypeVar("MacroApplyResultT")
"""Type variable for each element of the list returned by a macro apply callable."""

MacroApplyCallableT = Callable[[list[PipelineMsgT], list[AIFunction]], list[MacroApplyResultT]]
"""A function taking 2 args (msgs, funcs) -- a list of messages and a list of functions -- and returning a list."""