Source code for kani.utils.saveload
import dataclasses
import hashlib
import zipfile
from kani.models import BaseModel, ChatMessage
from kani.utils.typing import PathLike
# ==== main ====
class SavedKani(BaseModel):
    """Pydantic model holding the kani attributes that are written to / read from disk.

    Instances are built by :func:`save` from a kani instance and returned by :func:`load`.
    """

    # integer version marker stored with the data; defaults to 1
    # NOTE(review): presumably bumped when the saved layout changes - confirm
    version: int = 1

    # copied from ``inst.always_included_messages`` when saving
    always_included_messages: list[ChatMessage]

    # copied from ``inst.chat_history`` when saving
    chat_history: list[ChatMessage]
def save(fp: PathLike, inst, *, save_format: str, **kwargs):
    """Write the saved state of *inst* to *fp*.

    :param fp: The path to write to.
    :param inst: The object to save; its ``always_included_messages`` and ``chat_history`` attributes are stored.
    :param save_format: ``"kani"`` to write a zip archive containing an ``index.json`` manifest (used for
        multimodal attachments), or ``"json"`` to write a single legacy JSON file.
    :param kwargs: Additional arguments forwarded to ``SavedKani.model_dump_json``.
    :raises ValueError: If *save_format* is neither ``"kani"`` nor ``"json"``.
    """
    # gather the attributes to persist into a Pydantic model
    snapshot = SavedKani(
        always_included_messages=inst.always_included_messages,
        chat_history=inst.chat_history,
    )

    # legacy format: one plain JSON document
    if save_format == "json":
        with open(fp, "w", encoding="utf-8") as json_file:
            json_file.write(snapshot.model_dump_json(fallback=repr, **kwargs))
        return

    if save_format != "kani":
        raise ValueError("save_format must be either 'kani' or 'json'.")

    # kani format: deflate-compressed zip archive with a manifest file
    with zipfile.ZipFile(fp, mode="w", compression=zipfile.ZIP_DEFLATED, compresslevel=6) as archive:
        zip_ctx = KaniZipSaveContext(zf=archive)
        # serializers that find SAVELOAD_CONTEXT_KEY in the context are expected to write
        # their blobs into the archive while the manifest JSON is being produced
        manifest = snapshot.model_dump_json(
            context={SAVELOAD_CONTEXT_KEY: zip_ctx},
            fallback=repr,
            **kwargs,
        )
        # the manifest itself is written last, as index.json
        with archive.open("index.json", mode="w") as index_file:
            index_file.write(manifest.encode("utf-8"))
def load(fp: PathLike, **kwargs) -> SavedKani:
    """Read a saved state from *fp*, detecting the file format automatically.

    If *fp* is a zip archive, its ``index.json`` member is validated with a :class:`KaniZipSaveContext`
    made available under ``SAVELOAD_CONTEXT_KEY``; otherwise the file is read as a plain JSON document.

    :param fp: The path to read from.
    :param kwargs: Additional arguments forwarded to ``SavedKani.model_validate_json``.
    :returns: The validated :class:`SavedKani` model.
    """
    # anything that is not a zip archive is treated as plain JSON
    if not zipfile.is_zipfile(fp):
        with open(fp, encoding="utf-8") as json_file:
            raw = json_file.read()
        return SavedKani.model_validate_json(raw, **kwargs)

    with zipfile.ZipFile(fp, mode="r") as archive:
        raw = archive.read("index.json").decode(encoding="utf-8")
        zip_ctx = KaniZipSaveContext(zf=archive)
        # validate while the archive is still open so validators can read from it via the context
        return SavedKani.model_validate_json(raw, context={SAVELOAD_CONTEXT_KEY: zip_ctx}, **kwargs)
# ==== zip mode ====
# To use multi-file saving, a model that is being saved (usually a MessagePart) should define a wrap-mode model
# serializer and look for this key in the ``SerializationInfo.context`` mapping it is given.
# save() and load() store a KaniZipSaveContext under this key when working with a zip archive.
SAVELOAD_CONTEXT_KEY = "kani.saveload.context"
[docs]
@dataclasses.dataclass
class KaniZipSaveContext:
    """Wrapper around an open zip archive used to store and retrieve binary blobs by path."""

    # the open archive that blobs are written to / read from
    zf: zipfile.ZipFile

    def save_bytes(self, data: bytes, suffix: str = "") -> str:
        """
        Save the given bytes to the zip file and return its path.

        Filename is automatically determined by SHA256 hash.
        If *suffix* is given, the filename will end with the given suffix.

        If a member with the resulting path is already present in the archive, nothing is written and the
        existing path is returned, so saving identical data repeatedly stores it only once.

        :param data: The bytes to store.
        :param suffix: Text appended to the generated filename (e.g. a file extension).
        :returns: The path of the blob inside the archive.
        """
        digest = hashlib.sha256(data).hexdigest()
        # shard blobs into subdirectories keyed by the first two hex characters of the digest
        fp = f"blobs/{digest[:2]}/{digest}{suffix}"
        try:
            # the path is derived from the content hash, so an existing member already holds this data;
            # writing it again would only add a duplicate archive entry
            self.zf.getinfo(fp)
        except KeyError:
            with self.zf.open(fp, mode="w") as f:
                f.write(data)
        return fp

    def load_bytes(self, fp: str) -> bytes:
        """
        Read the bytes from the given path in the archive.

        :param fp: The path of the member inside the archive, as returned by :meth:`save_bytes`.
        :returns: The full contents of that member.
        """
        with self.zf.open(fp, mode="r") as f:
            return f.read()
[docs]
def get_ctx(info) -> KaniZipSaveContext | None:
    """Return the KaniZipSaveContext carried by a SerializationInfo/ValidationInfo object, if any.

    :param info: An object exposing a ``context`` attribute.
    :returns: The context stored under ``SAVELOAD_CONTEXT_KEY``, or ``None`` when ``info.context`` is empty
        or does not contain that key.
    """
    context = info.context
    # no context at all, or one that was not set up for zip saving/loading
    if not context or SAVELOAD_CONTEXT_KEY not in context:
        return None
    zip_ctx = context[SAVELOAD_CONTEXT_KEY]
    assert isinstance(zip_ctx, KaniZipSaveContext)
    return zip_ctx