Message Parts¶
In some advanced use cases, ChatMessage.content may be a tuple of MessagePart or str rather than
a string. ChatMessage exposes ChatMessage.text (always a string or None) and ChatMessage.parts (always
a list of message parts), which we recommend using instead of ChatMessage.content. These properties are
dynamically generated based on the underlying content, and it is safe to mix messages with different content types in a
single Kani.
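As a rough illustration of how these two views relate to the underlying content, here is a minimal sketch using a stand-in class. SketchMessage is hypothetical, not kani's ChatMessage, and the derivation rules it uses (joining the string form of each part for text, wrapping a plain string in a one-element list for parts) are assumptions made for illustration only.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class SketchMessage:
    """Hypothetical stand-in for a chat message; not kani's ChatMessage."""

    content: Union[str, tuple, None]

    @property
    def text(self) -> Optional[str]:
        # None stays None, a plain string is returned as-is,
        # and a tuple of parts is joined using each part's string form
        if self.content is None:
            return None
        if isinstance(self.content, str):
            return self.content
        return "".join(str(part) for part in self.content)

    @property
    def parts(self) -> list:
        # always a list, whatever the underlying content type is
        if self.content is None:
            return []
        if isinstance(self.content, str):
            return [self.content]
        return list(self.content)


plain = SketchMessage("Hello!")
mixed = SketchMessage(("Look at this: ", 42))
empty = SketchMessage(None)
```

Code that only reads text and parts never has to branch on the content type, which is why the two kinds of message can share one chat history.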
Generally, message part classes are defined by the engine, and consumed by the developer. Message parts can be used in any role’s message - for example, you might use a message part in an assistant message to separate out a chain of thought from the reply shown to the user, or in a user message to supply an image to a multimodal model.
Let’s say you wanted your model to perform a chain-of-thought before replying to the user, but didn’t want that chain of thought to be visible in its output. In the example below, we’ll define a hidden ThoughtPart that contains this data, and have our engine use this when building the prompt. This may be a little bit contrived, but hopefully it demonstrates how to use the MessagePart interface.
Tip
MessageParts are commonly used for multimodal inputs: for example, both the OpenAIEngine and
AnthropicEngine support image inputs using kani-multimodal-core’s ImagePart.
Warning
Don’t confuse Message Parts with Message Extras (Metadata). Message Parts are used for engine-agnostic inputs that require richer representation than a string, whereas Message Extras are used for engine-specific metadata about a certain message (e.g., an internal ID or detailed engine-specific usage data).
Defining a MessagePart¶
To define a MessagePart, you should create a new class that inherits from MessagePart.
Under the hood, a MessagePart is a Pydantic model, so to add attributes to your part, you can define them as
type-annotated class variables. For example, to add a data attribute with a string type, you can define
data: str.
Next, you’ll want to define the __str__() method, which is how kani defines the behaviour when your message part
is cast to a string. This happens in a couple places:
- The message part is displayed to the user through a string-casting method (e.g. Kani.full_round_str())
- The user is using the chat_in_terminal() development utility
- The message was provided to an engine that does not support the class (e.g. an image part in a text-only model).
The default behaviour is to transform the part to a Python-esque representation of its data (e.g.
<ThoughtPart data="...">) and log a warning. You can override this method to specify the canonical string
representation of your message part.
from kani import MessagePart


class ThoughtPart(MessagePart):
    # hold some string data that is the chain-of-thought
    data: str

    # when stringified, this part should be hidden from the user/any engine that does not explicitly support it
    def __str__(self):
        return ""
When you define a MessagePart, kani will automatically register serialization and deserialization schemes for your
class, so that all MessageParts are compatible with Kani.save() and Kani.load().
Warning
If you change the attributes or location of the class definition, old data may fail to be loaded into a new class.
Using MessageParts in Engines¶
Once we define a message part, we need to tell the engine how to use it. Since message parts are designed to carry rich content while maintaining cross-engine compatibility, engines should check for the classes they know and cast unknown classes to a string. This lets parts control how they will be serialized in engines that don’t natively support them.
A common, but not necessary, pattern is to write an engine that wraps other engines. This wrapper engine acts as a translation layer for specific message parts, eventually compiling a single string (or list of strings) for a base engine (e.g. LlamaEngine) to consume.
Here’s the implementation for our chain-of-thought example. Note how we translate the parts by building a new message
with a string content composed of parts in translate_message, then use that method in our implementations of
prompt_len and predict.
from kani import AIFunction, ChatMessage, MessagePart
from kani.engines import WrapperEngine, Completion


class ChainOfThoughtEngine(WrapperEngine):
    @staticmethod
    def translate_message(message: ChatMessage) -> ChatMessage:
        """Translate an input message into a simple string-only message to pass to the underlying engine."""
        content = ""
        # iterate over parts: if it's a thought, place its data in the content; otherwise prefix it with "Answer: "
        for part in message.parts:
            if isinstance(part, ThoughtPart):
                content += f"{part.data}\n"
            else:
                content += f"Answer: {str(part)}"
        # return the translated message
        return message.copy_with(content=content.strip())

    # === BaseEngine interface ===
    async def predict(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
    ) -> Completion:
        # translate the messages
        translated_messages = [self.translate_message(m) for m in messages]
        # generate a completion using the underlying engine given those messages
        result = await self.engine.predict(translated_messages, functions, **hyperparams)
        # parse the string-completion back into parts - with some checks to make sure the model actually output the
        # right format (text may be None, e.g. when the message has no content, so fall back to an empty string)
        text = result.message.text or ""
        if "Answer:" in text:
            thought, answer = text.split("Answer:", 1)
            message_with_parts = result.message.copy_with(parts=[ThoughtPart(data=thought.strip()), answer.strip()])
        else:
            message_with_parts = result.message
        # we'll print the result so you can see the parts, though a real engine wouldn't want to
        print("Model response with parts:", message_with_parts.parts)
        # and return the modified answer
        return Completion(
            message=message_with_parts,
            prompt_tokens=result.prompt_tokens,
            completion_tokens=result.completion_tokens,
        )
Now, we can use this engine by passing an underlying engine to it, prompt our model to follow our chain-of-thought format, and see how it works!
from kani import Kani, chat_in_terminal
from kani.engines.openai import OpenAIEngine

cot_engine = ChainOfThoughtEngine(OpenAIEngine())
ai = Kani(
    cot_engine,
    system_prompt=(
        "When answering the user, think step by step. Output your thoughts first, "
        'then output the final answer on its own line, in the format "Answer: {answer}".'
    ),
)
chat_in_terminal(ai)
# USER: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls.
# How many tennis balls does he have now?
#
# Model response with parts:
# [ThoughtPart(data='Roger already has 5 tennis balls. He buys 2 more cans of tennis balls, and each can has 3 tennis balls. \n\nTo find out how many tennis balls he has now, we need to multiply the number of cans with the number of tennis balls in each can. Since he bought 2 cans, we multiply 2 by 3:\n\n2 cans * 3 tennis balls per can = 6 tennis balls from the cans\n\nNext, we add the number of tennis balls he already had:\n\n5 tennis balls + 6 tennis balls = 11 tennis balls'),
# 'Roger now has 11 tennis balls.']
#
# AI: Roger now has 11 tennis balls.
See how the engine splits up the model’s response into two parts: the thought and the displayed answer. The
chat_in_terminal() development utility shows only the answer, but in a larger application you’d
get the full ChatMessage, and you could iterate over the parts to display the thought with your own
application logic.
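To make that last point concrete, here is a minimal sketch of application-side rendering. SketchThought is a hypothetical stand-in for the ThoughtPart defined earlier, and render_parts is an illustrative helper rather than part of kani; in a real application you would iterate over ChatMessage.parts instead of a hand-built list.

```python
from dataclasses import dataclass


@dataclass
class SketchThought:
    """Hypothetical stand-in for the ThoughtPart defined earlier."""

    data: str

    def __str__(self) -> str:
        # hidden wherever the part is simply cast to a string
        return ""


def render_parts(parts: list) -> str:
    """Render thoughts on their own labelled lines and everything else as plain text."""
    lines = []
    for part in parts:
        if isinstance(part, SketchThought):
            lines.append(f"[thinking] {part.data}")
        else:
            lines.append(str(part))
    return "\n".join(lines)


parts = [SketchThought(data="5 + 2 * 3 = 11"), "Roger now has 11 tennis balls."]
rendered = render_parts(parts)
plain_text = "".join(str(p) for p in parts)
```

Here rendered keeps the thought on its own labelled line, while plain_text (the naive string cast) contains only the answer.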
If you switch to a different base engine and inject the chat history, the base engine won’t see the thoughts, since ThoughtPart.__str__() returns the empty string. This means that you can build complex engine behaviour that won’t interfere with other engines.
Now you can use any message part you can think of - and you can create user messages with parts too.
Using MessageParts in Tools¶
What if an AI function returns multimodal content? By default, Kani applies certain transformations to the return value of an AIFunction in order:
1. if it is a ChatMessage, do not modify it
2. if it is a list of MessagePart, do not modify it
3. if it is a JSON-serializable Python dict, serialize it to JSON
4. if it is a Pydantic model, serialize it to JSON
5. otherwise, cast it to a string
Thus, in order to return multimodal content, the AIFunction must return either a FUNCTION-role chat message or a list of message parts. For example, the following code snippet defines a function which returns an image as a single-element list of parts:
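from kani import Kani, ai_function
from kani.ext.multimodal_core import ImagePart  # provided by the kani-multimodal-core package


class MyKani(Kani):
    @ai_function()
    async def get_image(self, url: str):
        """Download the image at a certain URL, and view it."""
        return [await ImagePart.from_url(url)]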
However, most API-based LLMs currently do not support multimodal returns from tools.
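The ordered return-value rules listed above can be sketched as a single dispatch function. This is an illustration only, not kani's implementation: SketchPart and SketchToolMessage are hypothetical stand-ins, and Pydantic models are detected by duck typing (the presence of model_dump_json) so that the example needs nothing beyond the standard library.

```python
import json
from dataclasses import dataclass


class SketchPart:
    """Hypothetical stand-in for a message part."""


@dataclass
class SketchToolMessage:
    """Hypothetical stand-in for a chat message."""

    content: object


def normalize_tool_result(value):
    """Apply the transformations in the order listed above (illustrative only)."""
    # chat messages pass through untouched
    if isinstance(value, SketchToolMessage):
        return value
    # so do lists made up entirely of message parts
    if isinstance(value, list) and value and all(isinstance(p, SketchPart) for p in value):
        return value
    # JSON-serializable dicts become JSON
    if isinstance(value, dict):
        try:
            return json.dumps(value)
        except (TypeError, ValueError):
            pass  # not serializable: fall through to the string cast
    # Pydantic-style models become JSON (duck-typed here to avoid the dependency)
    if hasattr(value, "model_dump_json"):
        return value.model_dump_json()
    # everything else is cast to a string
    return str(value)


part_list = [SketchPart()]
message = SketchToolMessage(content="done")
```

The order matters: a list of parts must be recognised before the final string cast, otherwise an image would be flattened into its string representation.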
Saving & Loading MessageParts¶
By default, kani will register a serializer and deserializer for each MessagePart you define, which will recursively save every public attribute in your MessagePart in JSON format. If all of your data is JSON-serializable, great! You don’t need to do anything extra.
Certain MessageParts, however, may store non-JSON-serializable data, like a large binary file for a multimodal model. In order to correctly save the state of these MessageParts, you must implement a custom serializer and validator.
You might, for example, save raw binary in Base64 format. When saving large binary objects, you should check for the
presence of kani.utils.saveload.SAVELOAD_CONTEXT_KEY in the serialization/validation context. If it is present,
you may use the KaniZipSaveContext to save large blobs to a separate location in the save archive.
This has the benefit of keeping the main chat log human-readable and the file size small.
See Saving & Loading Kani State for more information.
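The Base64 approach can be tried out with the standard library alone. The sketch below is independent of kani: encode_blob and decode_blob are hypothetical helper names, and the dictionary layout is an assumption chosen for illustration rather than kani's save format.

```python
import base64
import zlib


def encode_blob(raw: bytes) -> dict:
    """Compress raw bytes and encode them as a JSON-safe Base64 string."""
    compressed = zlib.compress(raw)
    return {"compression": "zlib", "data": base64.b64encode(compressed).decode("ascii")}


def decode_blob(saved: dict) -> bytes:
    """Reverse encode_blob, tolerating entries that were saved without compression."""
    decoded = base64.b64decode(saved["data"])
    if saved.get("compression") == "zlib":
        return zlib.decompress(decoded)
    return decoded


original = bytes(range(256)) * 4
saved = encode_blob(original)
restored = decode_blob(saved)
```

Recording the compression scheme next to the data lets the loader keep reading older, uncompressed entries.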
Example¶
An example of a custom serializer and validator for a MessagePart containing large binary data:
import base64
import mimetypes
import zlib

from pydantic import model_serializer, model_validator

from kani.utils import saveload


# the two methods below belong in the body of your MessagePart subclass
@model_serializer()
def _serialize_binary_file_part(self, info) -> dict[str, str]:
    """
    When we serialize, save the data as:
    - B64 of compressed data when not in zipfile mode
    - a file when in zipfile mode
    """
    if ctx := saveload.get_ctx(info):
        suffix = mimetypes.guess_extension(self.mime) or ""
        fp = ctx.save_bytes(self.as_bytes(), suffix=suffix)
        return {"_archive_path": fp, "mime": self.mime, **self._get_typekey_dict()}
    else:
        # note: the payload is produced by zlib.compress(); "gzip" is only the label stored alongside it
        compressed_b64 = base64.b64encode(zlib.compress(self.as_bytes())).decode()
        return {"mime": self.mime, "compression": "gzip", "data": compressed_b64, **self._get_typekey_dict()}


@model_validator(mode="wrap")
@classmethod
def _validate_binary_file_part(cls, v, nxt, info):
    """If the value is the URI we saved, try loading it that way."""
    assert isinstance(v, dict)
    if "_archive_path" in v:
        # saved in zipfile mode: read the bytes back out of the archive
        ctx = saveload.get_ctx(info)
        data = ctx.load_bytes(v["_archive_path"])
        return cls.from_bytes(data, mime=v["mime"])
    elif "data" in v:
        # saved inline: the data is Base64, optionally compressed
        if v.get("compression") == "gzip":
            decompressed = zlib.decompress(base64.b64decode(v["data"]))
            return cls.from_bytes(mime=v["mime"], data=decompressed)
        return cls.from_b64(mime=v["mime"], data=v["data"])
    # anything else: defer to the default validation
    return nxt(v)