Implementing an Engine#
To create your own engine, all you have to do is subclass BaseEngine:
- class kani.engines.base.BaseEngine[source]
Base class for all LM engines.
To add support for a new LM, make a subclass of this and implement the abstract methods below.
- max_context_size: int
The maximum context size supported by this engine’s LM.
- abstract message_len(message: ChatMessage) → int [source]
Return the length, in tokens, of the given chat message.
- abstract async predict(messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams) → BaseCompletion
Given the current context of messages and available functions, get the next predicted chat message from the LM.
- Parameters:
messages – The messages in the current chat context. sum(message_len(m) for m in messages) is guaranteed to be less than max_context_size.
functions – The functions the LM is allowed to call.
hyperparams – Any additional parameters to pass to the engine.
- token_reserve: int = 0
Optional: The number of tokens to reserve for internal engine mechanisms (e.g. if an engine has to set up the model’s reply with a delimiting token).
Default: 0
- function_token_reserve(functions: list[AIFunction]) → int [source]
Optional: How many tokens are required to build a prompt to expose the given functions to the model.
Default: If this is not implemented and the user passes in functions, log a warning that the engine does not support function calling.
- async stream(messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams)
Optional: Stream a completion from the engine, token-by-token.
This method’s signature is the same as BaseEngine.predict().
This method should yield strings as an asynchronous iterable.
Optionally, this method may also yield a BaseCompletion. If it does, it MUST be the last item yielded by this method.
If an engine does not implement streaming, this method will yield the entire text of the completion in a single chunk by default.
- Parameters:
messages – The messages in the current chat context. sum(message_len(m) for m in messages) is guaranteed to be less than max_context_size.
functions – The functions the LM is allowed to call.
hyperparams – Any additional parameters to pass to the engine.
- async close()[source]
Optional: Clean up any resources the engine might need.
A new engine must implement at least the two abstract methods and set the abstract attribute:
- BaseEngine.message_len() takes a single ChatMessage and returns the length of that message, in tokens.
- BaseEngine.predict() takes a list of ChatMessage and AIFunction and returns a new BaseCompletion.
- BaseEngine.max_context_size specifies the model’s token context size.
With just these three implementations, an engine will be fully functional!
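Concretely, a minimal engine might look like the sketch below. To keep the example self-contained and runnable without kani installed, ChatMessage and Completion are simplified stand-ins for kani's real classes, and token counting uses whitespace splitting (a real engine would use the model's tokenizer):

```python
from dataclasses import dataclass


# Simplified stand-ins for kani's ChatMessage and BaseCompletion classes;
# in real code you would import these from kani and subclass BaseEngine.
@dataclass
class ChatMessage:
    role: str
    content: str


@dataclass
class Completion:
    message: ChatMessage


class WhitespaceEchoEngine:
    """A toy engine that 'predicts' by echoing the last message back."""

    # the model's maximum token context window
    max_context_size = 1024

    def message_len(self, message: ChatMessage) -> int:
        # toy tokenizer: count whitespace-delimited words
        return len(message.content.split())

    async def predict(self, messages, functions=None, **hyperparams):
        # a real engine would call the underlying LM here
        last = messages[-1]
        reply = ChatMessage(role="assistant", content=f"You said: {last.content}")
        return Completion(message=reply)
```

With these three pieces in place, kani's context window management can call message_len to budget the prompt and predict to generate each round's reply.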
kani comes with a couple additional bases and utilities to help you build engines for models on HuggingFace or with an available HTTP API.
Optional Methods#
Engines also come with a set of optional methods and attributes you can override to customize their behaviour further. For example, engines often have to add a custom model-specific prompt in order to expose functions to the underlying model, and kani needs to know about the extra tokens added by this prompt!
- BaseEngine.token_reserve: if your engine needs to reserve tokens (e.g. for a one-time prompt template).
- BaseEngine.function_token_reserve(): specify how many tokens are needed to expose a set of functions to the model.
- BaseEngine.close(): if your engine needs to clean up resources during shutdown.
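One way to implement a token reserve, sketched below under stated assumptions: render the same function header that the prompt builder would emit, then count its tokens. Functions are represented here as plain dicts (name plus a JSON-schema-like parameters dict) to keep the sketch self-contained; kani's real AIFunction objects carry equivalent information, and a real engine would count tokens with its tokenizer rather than whitespace splitting:

```python
import json


class TokenReserveMixin:
    # tokens the engine always consumes, e.g. reply-delimiter tokens
    token_reserve = 2

    def function_token_reserve(self, functions) -> int:
        # Render the same function header the prompt builder would emit,
        # then count its tokens (toy tokenizer: whitespace words).
        # `functions` is a list of dicts standing in for kani's AIFunction.
        header = "\n".join(
            f"{f['name']}: {json.dumps(f['parameters'])}" for f in functions
        )
        return len(header.split())
```

Deriving the reserve from the actual rendered prompt, rather than hardcoding a guess, keeps the count correct as the prompt format evolves.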
Adding Function Calling#
If you’re writing an engine for a model with function calling, there are a couple additional steps you need to take.
Generally, to use function calling, you need to do the following:
- Tell the model what functions it has available to it
- Optional: tell the model what format to output to request calling a function (if the model is not already fine-tuned to do so)
- Parse the model’s requests to call functions from its text generations
To tell the model what functions it has available, you’ll need to include them in the prompt. To do so, implement two methods: BaseEngine.predict() and BaseEngine.function_token_reserve().
BaseEngine.predict()
takes in a list of available AIFunction
s as an argument, which you should use to
build such a prompt. BaseEngine.function_token_reserve()
tells kani how many tokens that prompt takes, so the
context window management can ensure it never sends too many tokens.
You’ll also need to add previous function calls into the prompt (e.g. in the few-shot function calling example).
When you’re building the prompt, you’ll need to iterate over ChatMessage.tool_calls
if it exists, and add
your model’s appropriate function calling prompt.
Parsing the model’s requests to call a function also happens in BaseEngine.predict(). After generating the
model’s completion (usually a string, or a list of token IDs that decodes into a string), separate the model’s
conversational content from the structured function call:
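As one possible sketch, assuming the model wraps a call in a hypothetical `<function_call>` tag with a JSON body (substitute the format your model actually emits):

```python
import json
import re

# Assumed wire format: the model marks a call as
# <function_call>{"name": ..., "arguments": {...}}</function_call>
CALL_RE = re.compile(r"<function_call>(.*?)</function_call>", re.DOTALL)


def split_completion(text: str):
    """Split generated text into (conversational content, parsed call or None)."""
    match = CALL_RE.search(text)
    if match is None:
        # no function call: the whole generation is conversational content
        return text.strip() or None, None
    # everything outside the tag is conversational content (may be empty)
    content = (text[: match.start()] + text[match.end():]).strip() or None
    return content, json.loads(match.group(1))
```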
Finally, return a Completion
with the .message
attribute set to a ChatMessage
with the
appropriate ChatMessage.content
and ChatMessage.tool_calls
.
Note
See Internal Representation for more information about ToolCalls vs FunctionCalls.