Implementing an Engine

To create your own engine, all you have to do is subclass BaseEngine:

class kani.engines.base.BaseEngine

Base class for all LM engines.

To add support for a new LM, make a subclass of this and implement the abstract methods below.

max_context_size: int

The maximum context size supported by this engine’s LM.

abstract message_len(message: ChatMessage) → int

Return the length, in tokens, of the given chat message.
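
For example, a tokenizer-based engine might count a message’s tokens like this (a minimal sketch: self.tokenizer is a hypothetical attribute set in your engine’s __init__, and the fixed overhead is model-specific):

    # inside your BaseEngine subclass:
    def message_len(self, message: ChatMessage) -> int:
        # tokens in the message's text content (ChatMessage.text may be None)
        content_len = len(self.tokenizer.encode(message.text or ""))
        # plus a few tokens of overhead for role headers/delimiters
        # (5 is a placeholder; measure your model's actual chat template)
        return content_len + 5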

abstract async predict(
    messages: list[ChatMessage],
    functions: list[AIFunction] | None = None,
    **hyperparams,
) → BaseCompletion

Given the current context of messages and available functions, get the next predicted chat message from the LM.

Parameters:
  • messages – The messages in the current chat context. sum(message_len(m) for m in messages) is guaranteed to be less than max_context_size.

  • functions – The functions the LM is allowed to call.

  • hyperparams – Any additional parameters to pass to the engine.

token_reserve: int = 0

Optional: The number of tokens to reserve for internal engine mechanisms (e.g. if an engine has to set up the model’s reply with a delimiting token).

Default: 0

function_token_reserve(functions: list[AIFunction]) → int

Optional: How many tokens are required to build a prompt to expose the given functions to the model.

Default: If this is not implemented and the user passes in functions, log a warning that the engine does not support function calling.

async stream(
    messages: list[ChatMessage],
    functions: list[AIFunction] | None = None,
    **hyperparams,
) → AsyncIterable[str | BaseCompletion]

Optional: Stream a completion from the engine, token-by-token.

This method’s signature is the same as BaseEngine.predict().

This method should yield strings as an asynchronous iterable.

Optionally, this method may also yield a BaseCompletion. If it does, it MUST be the last item yielded by this method.

If an engine does not implement streaming, this method will yield the entire text of the completion in a single chunk by default.

Parameters:
  • messages – The messages in the current chat context. sum(message_len(m) for m in messages) is guaranteed to be less than max_context_size.

  • functions – The functions the LM is allowed to call.

  • hyperparams – Any additional parameters to pass to the engine.
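
For example, a streaming override might look like this (a sketch: self.client.stream_generate() and self.build_prompt() are hypothetical stand-ins for your model’s streaming API and prompt builder):

    # inside your BaseEngine subclass (ChatMessage from kani, Completion from kani.engines.base):
    async def stream(self, messages, functions=None, **hyperparams):
        content = ""
        # yield each text delta to the caller as it arrives
        async for delta in self.client.stream_generate(
            self.build_prompt(messages, functions), **hyperparams
        ):
            content += delta
            yield delta
        # optionally, yield a BaseCompletion as the very last item
        yield Completion(message=ChatMessage.assistant(content))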

async close()

Optional: Clean up any resources the engine might need.
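
For example, an engine backed by an HTTP client might release its session here (a sketch assuming a hypothetical self.session holding an aiohttp ClientSession):

    # inside your BaseEngine subclass:
    async def close(self):
        # release the underlying HTTP session when the engine is no longer needed
        await self.session.close()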

A new engine must implement at least the two abstract methods, message_len() and predict(), and set the abstract attribute max_context_size.

With just these three implementations, an engine will be fully functional!
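
For instance, a minimal engine might look something like this (a sketch, not a real model: the model and tokenizer objects are hypothetical stand-ins for whatever your LM provides, and the prompt format is model-specific):

    from kani import AIFunction, ChatMessage
    from kani.engines.base import BaseEngine, Completion

    class MyEngine(BaseEngine):
        max_context_size = 4096  # the abstract attribute

        def __init__(self, model, tokenizer):
            self.model = model  # hypothetical model object with an async generate() method
            self.tokenizer = tokenizer  # hypothetical tokenizer with an encode() method

        def message_len(self, message: ChatMessage) -> int:
            # message text tokens, plus some overhead for role delimiters
            return len(self.tokenizer.encode(message.text or "")) + 5

        async def predict(self, messages, functions=None, **hyperparams):
            # build a single prompt string from the chat history
            prompt = "\n".join(f"{m.role.value}: {m.text}" for m in messages)
            text = await self.model.generate(prompt, **hyperparams)
            return Completion(message=ChatMessage.assistant(text))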

kani comes with a couple of additional base classes and utilities to help you build engines for models hosted on HuggingFace or exposed over an HTTP API.

Optional Methods

Engines also come with a set of optional methods and attributes that you can override to customize their behaviour further. For example, engines often have to add a custom model-specific prompt in order to expose functions to the underlying model, and kani needs to know about the extra tokens added by this prompt!

Adding Function Calling

If you’re writing an engine for a model with function calling, there are a couple of additional steps you need to take.

Generally, to use function calling, you need to do the following:

  1. Tell the model what functions are available to it
     a. (Optional) Tell the model what format to output in order to request a function call (if the model is not already fine-tuned to do so)

  2. Parse the model’s requests to call functions from its text generations

To tell the model what functions it has available, you’ll need to build that information into the model’s prompt somehow. To do so, you’ll implement two methods: BaseEngine.predict() and BaseEngine.function_token_reserve().

BaseEngine.predict() takes in a list of available AIFunctions as an argument, which you should use to build such a prompt. BaseEngine.function_token_reserve() tells kani how many tokens that prompt takes, so the context window management can ensure it never sends too many tokens.
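
For example, if you expose functions by prepending a description of each to the prompt, the reserve is just the token length of that prompt (a sketch: build_functions_prompt() is a hypothetical helper that renders your model’s function-calling format, shared with predict()):

    # inside your BaseEngine subclass:
    def function_token_reserve(self, functions: list[AIFunction]) -> int:
        if not functions:
            return 0
        # build exactly the same functions prompt that predict() will prepend, and count it
        prompt = self.build_functions_prompt(functions)
        return len(self.tokenizer.encode(prompt))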

You’ll also need to add previous function calls into the prompt (e.g. in the few-shot function calling example). When you’re building the prompt, you’ll need to iterate over ChatMessage.tool_calls if it exists, and add your model’s appropriate function calling prompt.
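
Inside your prompt builder, that might look something like this (a sketch: the <function_call> tag format is hypothetical, standing in for whatever serialization your model was trained on; FunctionCall.arguments is the JSON-encoded argument string):

    for message in messages:
        if message.tool_calls:
            # re-serialize each past call in the same format the model uses to request calls
            for tc in message.tool_calls:
                prompt += f'<function_call>{{"name": "{tc.function.name}", "arguments": {tc.function.arguments}}}</function_call>\n'
        elif message.text:
            prompt += f"{message.role.value}: {message.text}\n"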

Parsing the model’s requests to call a function also happens in BaseEngine.predict(). After generating the model’s completion (usually a string, or a list of token IDs that decodes into a string), separate the model’s conversational content from the structured function call:

[Figure: parsing a model’s completion into conversational content and a structured function call.]

Finally, return a Completion with the .message attribute set to a ChatMessage with the appropriate ChatMessage.content and ChatMessage.tool_calls.
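
Putting it together, the parsing step might look like this (a sketch: the <function_call> tag format and regex are hypothetical, matching the prompt-building sketch above):

    import json
    import re

    from kani import ChatMessage, ToolCall
    from kani.engines.base import Completion

    def parse_completion(text: str) -> Completion:
        # look for the (hypothetical) structured function call tag in the generation
        match = re.search(r"<function_call>(.*?)</function_call>", text, re.DOTALL)
        if match is None:
            # no function call: the entire generation is conversational content
            return Completion(message=ChatMessage.assistant(text))
        data = json.loads(match.group(1))
        # the conversational content is whatever surrounds the structured call (None if empty)
        content = (text[:match.start()] + text[match.end():]).strip() or None
        tool_call = ToolCall.from_function(data["name"], **data["arguments"])
        return Completion(message=ChatMessage.assistant(content, tool_calls=[tool_call]))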

Note

See Internal Representation for more information about ToolCalls vs FunctionCalls.