Implementing an Engine

Important

Looking to use a model available through HuggingFace Transformers? Before implementing your own model class, try using HuggingEngine! If your model has a Chat Template available, Kani will automatically use the correct prompt format.

from kani.engines.huggingface import HuggingEngine
engine = HuggingEngine(model_id="your-org/your-model-id")

To create your own engine, all you have to do is subclass BaseEngine:

class kani.engines.base.BaseEngine

Base class for all LM engines.

To add support for a new LM, make a subclass of this and implement the abstract methods below.

max_context_size: int

The maximum context size supported by this engine’s LM.

abstract prompt_len(
messages: list[ChatMessage],
functions: list[AIFunction] | None = None,
**kwargs,
) → int

Returns the number of tokens used by the given prompt (i.e., list of messages and functions), or a best estimate if the exact count is unavailable.

This method MAY be asynchronous. Use Kani.prompt_token_len() for a higher-level interface that handles asynchrony.

Parameters:
  • messages – The messages in the prompt.

  • functions – The functions included in the prompt.

  • kwargs – Any additional parameters to pass to the underlying token counting implementation (engine-specific).
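
Because prompt_len() may be either a plain method or a coroutine, a caller has to handle both cases. The following is a minimal sketch of that pattern, not kani's actual implementation of Kani.prompt_token_len(). The two engine classes and the resolve_prompt_len helper are hypothetical names used only for illustration, and messages are plain strings so the example runs on its own.

```python
import asyncio
import inspect


class SyncEngine:
    """Hypothetical engine that counts tokens locally."""

    def prompt_len(self, messages, functions=None, **kwargs):
        # crude estimate: one token per whitespace-separated word
        return sum(len(m.split()) for m in messages)


class AsyncEngine:
    """Hypothetical engine that would ask a remote API for the count."""

    async def prompt_len(self, messages, functions=None, **kwargs):
        await asyncio.sleep(0)  # placeholder for a network round trip
        return sum(len(m.split()) for m in messages)


async def resolve_prompt_len(engine, messages, functions=None, **kwargs):
    # call the method, then await the result only if it is awaitable
    result = engine.prompt_len(messages, functions, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result


messages = ["hello there", "how are you today"]
sync_count = asyncio.run(resolve_prompt_len(SyncEngine(), messages))
async_count = asyncio.run(resolve_prompt_len(AsyncEngine(), messages))
```

Both calls return the same count, so code built on such a helper does not need to know which kind of engine it is talking to.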

abstract async predict(
messages: list[ChatMessage],
functions: list[AIFunction] | None = None,
**hyperparams,
) → BaseCompletion

Given the current context of messages and available functions, get the next predicted chat message from the LM.

Parameters:
  • messages – The messages in the current chat context. prompt_len(messages, functions) is guaranteed to be less than max_context_size.

  • functions – The functions the LM is allowed to call.

  • hyperparams – Any additional parameters to pass to the engine.

async stream(
messages: list[ChatMessage],
functions: list[AIFunction] | None = None,
**hyperparams,
) → AsyncIterable[str | BaseCompletion]

Optional: Stream a completion from the engine, token-by-token.

This method’s signature is the same as BaseEngine.predict().

This method should yield strings as an asynchronous iterable.

Optionally, this method may also yield a BaseCompletion. If it does, it MUST be the last item yielded by this method.

If an engine does not implement streaming, this method will yield the entire text of the completion in a single chunk by default.

Parameters:
  • messages – The messages in the current chat context. prompt_len(messages, functions) is guaranteed to be less than max_context_size.

  • functions – The functions the LM is allowed to call.

  • hyperparams – Any additional parameters to pass to the engine.
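
The streaming contract described above (yield strings, then optionally one final completion object as the last item) can be sketched as follows. This is an illustrative example under stated assumptions: FakeCompletion is a hypothetical stand-in for a BaseCompletion subclass, StreamingEngine does not subclass the real BaseEngine, and the chunks are hard-coded instead of coming from a model.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FakeCompletion:
    """Hypothetical stand-in for a BaseCompletion subclass."""

    text: str


class StreamingEngine:
    async def stream(self, messages, functions=None, **hyperparams):
        # a real engine would read these chunks from its model as they arrive
        chunks = ["Hello", ", ", "world", "!"]
        for chunk in chunks:
            yield chunk
        # the completion object, if yielded at all, must be the last item
        yield FakeCompletion(text="".join(chunks))


async def collect(engine):
    """Consume the stream, separating text chunks from the final completion."""
    texts, final = [], None
    async for item in engine.stream([]):
        if isinstance(item, str):
            texts.append(item)
        else:
            final = item
    return texts, final


texts, final = asyncio.run(collect(StreamingEngine()))
```

Yielding the completion last lets a consumer display text as it arrives and still receive the fully assembled result once the stream ends.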

async close()

Optional: Clean up any resources the engine might need.

disable_function_calling_kwargs = {'include_functions': False}

Keyword arguments to pass in the Kani._full_round loop when function calling should be disabled. This is mostly useful for API models, where functions should still be defined in the prompt but the model must not call them.

message_len(message: ChatMessage) → int

Returns the estimated number of tokens used by a single given message.

Note

The token count returned by this may not exactly reflect the actual token count (e.g., due to prompt formatting or not having access to the tokenizer). It should, however, be a safe overestimate to use as an upper bound.

Deprecated since version 1.7.0: Use BaseEngine.prompt_len() instead.

token_reserve: int = 0

Optional: The number of tokens to reserve for internal engine mechanisms (e.g. if an engine has to set up the model’s reply with a delimiting token). Default: 0

Deprecated since version 1.7.0: Use BaseEngine.prompt_len() instead.

function_token_reserve(functions: list[AIFunction]) → int

Optional: How many tokens are required to build a prompt to expose the given functions to the model.

Default: If this is not implemented and the user passes in functions, log a warning that the engine does not support function calling.

Deprecated since version 1.7.0: Use BaseEngine.prompt_len() instead.

A new engine must implement at least the two abstract methods and set the abstract attribute:

  • BaseEngine.max_context_size

  • BaseEngine.prompt_len()

  • BaseEngine.predict()

With just these three implementations, an engine will be fully functional!
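
As a rough illustration, a minimal engine might look like the sketch below. To keep it runnable without kani installed, ChatMessage and Completion here are simplified stand-ins defined in the example itself, and EchoEngine is a hypothetical class that does not subclass the real BaseEngine. A real engine would import those names from kani and call its model inside predict(). The per-message overhead of 4 tokens is likewise an assumption.

```python
import asyncio
from dataclasses import dataclass


# Simplified stand-ins for kani's ChatMessage and Completion (assumption:
# the real classes carry more fields than shown here).
@dataclass
class ChatMessage:
    role: str
    content: str


@dataclass
class Completion:
    message: ChatMessage


class EchoEngine:  # in real use this would subclass BaseEngine
    # the abstract attribute: context size of the pretend model
    max_context_size = 2048

    def prompt_len(self, messages, functions=None, **kwargs):
        # crude estimate: one token per word, plus an assumed 4 tokens of
        # formatting overhead per message
        return sum(len(m.content.split()) + 4 for m in messages)

    async def predict(self, messages, functions=None, **hyperparams):
        # a real engine would query its model; this one echoes the last message
        reply = ChatMessage(
            role="assistant",
            content=f"You said: {messages[-1].content}",
        )
        return Completion(message=reply)


engine = EchoEngine()
history = [ChatMessage(role="user", content="hello there")]
completion = asyncio.run(engine.predict(history))
```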

kani also comes with a couple of additional base classes and utilities to help you build engines for models hosted on Hugging Face or served through an HTTP API.

Optional Methods

Engines also come with a set of optional methods and attributes that you can override to customize their behaviour further.

Adding Function Calling

Important

Already have a way to build function calling prompts but just need a way to parse the outputs? Check out the list of Model-Specific Parsers.

This is commonly the case for Hugging Face models which implement the function calling prompt in their chat template.

If you’re writing an engine for a model with function calling, there are a couple of additional steps you need to take.

Generally, to use function calling, you need to do the following:

  1. Tell the model what functions it has available to it
    1. Optional - tell the model what format to output to request calling a function (if the model is not already fine-tuned to do so)

  2. Parse the model’s requests to call functions from its text generations

To tell the model what functions it has available, you’ll need to describe them in its prompt. This only requires editing two methods: BaseEngine.predict() and BaseEngine.prompt_len().

BaseEngine.predict() takes in a list of available AIFunctions as an argument, which you should use to build such a prompt. BaseEngine.prompt_len() tells kani how many tokens that prompt takes, so the context window management can ensure it never sends too many tokens.
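
One way to keep the two methods consistent is to build the function prompt in a single helper and have both predict() and prompt_len() use it. The sketch below assumes a hypothetical FunctionSpec stand-in with name, desc, and json_schema fields, a made-up <call>...</call> output format, and a word-count token estimate. None of these are prescribed by kani, and messages are plain strings for brevity.

```python
import json
from dataclasses import dataclass, field


@dataclass
class FunctionSpec:
    """Hypothetical stand-in for an AIFunction."""

    name: str
    desc: str
    json_schema: dict = field(default_factory=dict)


def build_function_prompt(functions):
    """Render the available functions as text to prepend to the prompt."""
    if not functions:
        return ""
    lines = [
        "You can call a function by replying with:",
        '<call>{"name": ..., "arguments": {...}}</call>',
        "Available functions:",
    ]
    for f in functions:
        lines.append(f"- {f.name}: {f.desc}")
        lines.append(f"  parameters: {json.dumps(f.json_schema, sort_keys=True)}")
    return "\n".join(lines)


def estimate_tokens(text):
    # crude estimate: one token per whitespace-separated word
    return len(text.split())


def prompt_len(messages, functions=None):
    # count the function prompt with the same helper predict() would use
    function_tokens = estimate_tokens(build_function_prompt(functions))
    return function_tokens + sum(estimate_tokens(m) for m in messages)


weather = FunctionSpec(
    name="get_weather",
    desc="Get the current weather for a city.",
    json_schema={"type": "object", "properties": {"city": {"type": "string"}}},
)
without_functions = prompt_len(["What is the weather in Tokyo?"])
with_functions = prompt_len(["What is the weather in Tokyo?"], [weather])
```

Because the count comes from the same text that would be sent to the model, the estimate cannot drift away from the real prompt when the format changes.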

You’ll also need to add previous function calls into the prompt (e.g. in the few-shot function calling example). When you’re building the prompt, iterate over ChatMessage.tool_calls if it is set, and render each call in your model’s function calling format.
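
A sketch of rendering earlier tool calls back into the prompt is shown below. The Message, ToolCall, and FunctionCall classes are simplified stand-ins defined for this example. The sketch assumes that each call stores its arguments as a JSON string, and the <call>...</call> format is hypothetical, so substitute whatever format your model expects.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class FunctionCall:
    name: str
    arguments: str  # assumed to be a JSON-encoded string


@dataclass
class ToolCall:
    function: FunctionCall


@dataclass
class Message:
    role: str
    content: Optional[str] = None
    tool_calls: Optional[list] = None


def render_message(msg):
    """Turn one message, including any tool calls, into prompt text."""
    parts = [msg.content] if msg.content else []
    # tool_calls may be None, so fall back to an empty list
    for tc in msg.tool_calls or []:
        payload = {
            "name": tc.function.name,
            "arguments": json.loads(tc.function.arguments),
        }
        parts.append(f"<call>{json.dumps(payload)}</call>")
    return f"{msg.role}: " + " ".join(parts)


plain = render_message(Message(role="user", content="Weather in Tokyo?"))
with_call = render_message(
    Message(
        role="assistant",
        content="Let me check.",
        tool_calls=[ToolCall(FunctionCall("get_weather", '{"city": "Tokyo"}'))],
    )
)
```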

Parsing the model’s requests to call a function also happens in BaseEngine.predict(). After generating the model’s completion (usually a string, or a list of token IDs that decodes into a string), separate the model’s conversational content from the structured function call:

[Figure: function-calling-parsing.png]

Finally, return a Completion with the .message attribute set to a ChatMessage with the appropriate ChatMessage.content and ChatMessage.tool_calls.
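
The parsing step can be sketched as below, assuming a hypothetical model that wraps each function call in <call>...</call> tags containing JSON. The tag format and the parse_completion helper are illustrative only. In a real engine, the returned pairs would be converted into the tool call objects stored on ChatMessage.tool_calls.

```python
import json
import re

# hypothetical output format: <call>{"name": ..., "arguments": {...}}</call>
CALL_PATTERN = re.compile(r"<call>(.*?)</call>", re.DOTALL)


def parse_completion(text):
    """Split a raw completion into (content, [(name, arguments_json), ...])."""
    calls = []

    def _extract(match):
        try:
            data = json.loads(match.group(1))
            name = data["name"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # leave malformed calls in the conversational content
            return match.group(0)
        calls.append((name, json.dumps(data.get("arguments", {}))))
        return ""  # remove well-formed calls from the content

    content = CALL_PATTERN.sub(_extract, text).strip()
    # a completion that consists only of calls has no conversational content
    return (content or None), calls


raw = 'Let me check. <call>{"name": "get_weather", "arguments": {"city": "Tokyo"}}</call>'
content, calls = parse_completion(raw)
only_call_content, only_call = parse_completion('<call>{"name": "ping"}</call>')
```

Keeping malformed calls in the content, instead of silently dropping them, makes parsing failures visible in the conversation. Whether that is the right behaviour depends on the model.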

Note

See Internal Representation for more information about ToolCalls vs FunctionCalls.