import logging
import re
import warnings
from functools import cached_property
from typing import AsyncIterable
from kani import model_specific
from kani.ai_function import AIFunction
from kani.engines.base import BaseCompletion, BaseEngine, Completion
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage
from kani.prompts.pipeline import PromptPipeline
from kani.utils.warnings import deprecated
try:
import torch
from llama_cpp import Llama
except ImportError:
raise MissingModelDependencies(
'The LlamaCppEngine requires extra dependencies. Please install kani with "pip install kani[cpp]". '
) from None
log = logging.getLogger(__name__)
[docs]
class LlamaCppEngine(BaseEngine):
"""
This class implements the main decoding logic for any GGUF model (not just LLaMA as the name might suggest).
**GPU Support**
llama.cpp supports multiple acceleration backends, which may require different flags to be set during installation.
To see the full list of backends, see their README at https://github.com/abetlen/llama-cpp-python.
To load some or all of the model layers on GPU, pass ``n_gpu_layers=...`` in the ``model_load_kwargs``. Use
``-1`` to specify all layers.
"""
def __init__(
    self,
    repo_id: str | None = None,
    filename: str | None = None,
    model_path: str | None = None,
    max_context_size: int = 0,
    prompt_pipeline: PromptPipeline[str | list[int]] = None,
    *,
    model_load_kwargs: dict = None,
    **hyperparams,
):
    """
    Load a GGUF model (from local disk or the Hugging Face hub) and store the prompt pipeline and default
    generation hyperparameters.

    :param repo_id: The ID of the model repo to load from Hugging Face.
        If this is set, ``filename`` must be set and ``model_path`` may not be set.
    :param filename: A filename or glob pattern to match the model file in the Hugging Face repo.
        If this is set, ``repo_id`` must be set and ``model_path`` may not be set.
        If it looks like one shard of a split GGUF (``*-0000N-of-0000M.gguf``), the remaining shards are
        requested as ``additional_files`` unless that key is already present in ``model_load_kwargs``.
    :param model_path: A path to the model files on local disk.
        If this is set, neither ``repo_id`` nor ``filename`` may be set.
    :param max_context_size: The context size of the model. Used as the default ``n_ctx`` when loading; if it
        is 0, the loaded model's ``n_ctx()`` is used as the engine's ``max_context_size``.
    :param prompt_pipeline: The pipeline to translate a list of kani ChatMessages into the model-specific chat
        format (see :class:`.PromptPipeline`).
    :param model_load_kwargs: Additional arguments to pass to ``Llama.from_pretrained()``.
        See `this link <https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.from_pretrained>`_
        for more info. The passed dict is copied and never modified.
    :param hyperparams: Additional arguments to supply the model during generation.
    :raises ValueError: If not exactly one of ``model_path`` or the ``(repo_id, filename)`` pair is given.
    """
    # work on a copy: the setdefault() calls below must not leak into the caller's dict
    model_load_kwargs = {} if model_load_kwargs is None else dict(model_load_kwargs)

    # exactly one of (model_path, (repo_id, filename)) should be passed
    if (model_path is None and (repo_id is None or filename is None)) or (
        model_path is not None and (repo_id is not None or filename is not None)
    ):
        raise ValueError(
            "Exactly one of (model_path, (repo_id, filename)) must be passed. Use `model_path` for locally"
            " downloaded models, and `repo_id, filename` to download models from the Hugging Face hub."
        )

    self.repo_id = repo_id
    self.filename = filename
    self.model_path = model_path
    self.pipeline = prompt_pipeline

    # for convenience, if the filename is *-00001-of-0000X.gguf, mark all the others as additional files if not set
    if filename is not None:
        if match := re.match(r"(.*?)-(\d+)-of-(\d+)\.gguf", filename):
            log.info("Sharded GGUF file given - ensuring that all GGUF shards are downloaded")
            # there is an issue in llama-cpp-python that makes the additional_files inherit the subfolder of the parent
            # https://github.com/abetlen/llama-cpp-python/issues/1938
            if "/" in match[1]:
                warnings.warn(
                    "llama-cpp-python can fail to find additional model files in subfolders. If you see a 404"
                    " error, try manually using huggingface-cli to download model files. See"
                    " https://github.com/abetlen/llama-cpp-python/issues/1938 for more information."
                )
            prefix, given_index, total = match[1], match[2], match[3]
            # Zero-pad each sibling index to the width used in the given filename instead of globbing with
            # ``*{n}``: a glob such as ``*1`` would also match shards 11, 21, ..., making the pattern ambiguous
            # as soon as a model has more than 10 shards.
            width = len(given_index)
            additional_files = [
                f"{prefix}-{str(n).zfill(width)}-of-{total}.gguf"
                for n in range(1, int(total) + 1)
                if n != int(given_index)
            ]
            log.info(f"additional_files={additional_files}")
            model_load_kwargs.setdefault("additional_files", additional_files)

    # an explicit n_ctx in model_load_kwargs takes precedence over max_context_size
    model_load_kwargs.setdefault("n_ctx", max_context_size)
    if model_path is not None:
        self.model = Llama(model_path=model_path, **model_load_kwargs)
    else:
        self.model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **model_load_kwargs)
    self.hyperparams = hyperparams

    # max_context_size == 0 means "not specified": fall back to what the loaded model reports
    self.max_context_size = max_context_size or self.model.n_ctx()
[docs]
def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> str | list[int]:
    """
    Turn the kani messages (and any available functions) into the model input: either one prompt string or a
    list of tokens.

    By default this simply calls the prompt pipeline that was supplied to the engine.

    :param messages: The messages to build the prompt from.
    :param functions: The functions to include in the prompt, if any.
    :raises NotImplementedError: If the engine has no prompt pipeline.
    """
    pipeline = self.pipeline
    if pipeline is not None:
        built = pipeline(messages, functions)
        log.debug(f"BUILT PROMPT: {built}")
        return built

    # no pipeline: tell the user how to get one
    raise NotImplementedError(
        "You must pass a prompt_pipeline to the LlamaCppEngine to use it as a non-abstract class. If your model"
        " uses a chat template (or is a quantization of a model with a chat template), you can use the"
        " following:\n"
        "from kani.model_specific import prompt_pipeline_for_hf_model\n"
        "pipeline = prompt_pipeline_for_hf_model(base_model_id)\n"
        "engine = LlamaCppEngine(..., prompt_pipeline=pipeline)"
    )
def _get_generate_args(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
    """
    Shared pre-generation step: build the prompt, make sure it is a token list, and merge hyperparameters.

    :param messages: The messages to build the prompt from.
    :param functions: The functions to include in the prompt, if any.
    :param hyperparams: Per-call hyperparameters; these override the engine-level ones.
    :returns: A tuple of ``(input tokens, number of input tokens, merged hyperparams)``.
    :raises TypeError: If ``build_prompt`` returns neither a ``str`` nor a ``list``.
    """
    built = self.build_prompt(messages, functions)
    if isinstance(built, list):
        # already tokenized by the pipeline
        tokens = built
    elif isinstance(built, str):
        tokens = self.model.tokenize(built.encode(), add_bos=False, special=True)
    else:
        raise TypeError("build_prompt should either return a str or a list[int].")
    token_count = len(tokens)

    # engine-level hyperparams first, per-call values layered on top
    merged = dict(self.hyperparams)
    merged.update(hyperparams)
    if "max_tokens" not in merged:
        # by default llama.cpp sets this to 16, which is too small
        merged["max_tokens"] = None

    # check for a model-specific parser
    if self.repo_id:
        model_specific.warn_for_uninitialized_parser(self.repo_id)
    return tokens, token_count, merged
# ==== kani impl ====
[docs]
async def prompt_len(self, messages, functions=None, **kwargs) -> int:
    """
    Return the number of tokens in the prompt built from the given messages and functions.

    :param messages: The messages to build the prompt from.
    :param functions: The functions to include in the prompt, if any.
    :param kwargs: Forwarded to the shared pre-generation step as hyperparameters.
    """
    _tokens, token_count, _merged = self._get_generate_args(messages, functions, **kwargs)
    return token_count
[docs]
async def predict(
    self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> Completion:
    """
    Get the next predicted chat message from the LM, given the current messages and available functions.

    :param messages: The messages in the current chat context. ``prompt_len(messages, functions)`` is
        guaranteed to be less than max_context_size.
    :param functions: The functions the LM is allowed to call.
    :param hyperparams: Any additional parameters to pass to ``Llama.create_completion()``. (See
        https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion)
    :returns: A completion holding the assistant message plus the prompt and completion token counts.
    """
    prompt_tokens, prompt_token_count, merged = self._get_generate_args(messages, functions, **hyperparams)
    result = self.model.create_completion(prompt_tokens, **merged)

    # only the first choice is used
    message = ChatMessage.assistant(result["choices"][0]["text"])
    usage = result["usage"]
    return Completion(
        message,
        prompt_tokens=prompt_token_count,
        completion_tokens=usage["completion_tokens"],
    )
[docs]
async def stream(
    self,
    messages: list[ChatMessage],
    functions: list[AIFunction] | None = None,
    **hyperparams,
) -> AsyncIterable[str | BaseCompletion]:
    """
    Stream the next predicted chat message from the LM, given the current messages and available functions.

    Yields each text chunk as it is produced, then a final completion containing the full message. The final
    completion carries no token counts.

    :param messages: The messages in the current chat context. ``prompt_len(messages, functions)`` is
        guaranteed to be less than max_context_size.
    :param functions: The functions the LM is allowed to call.
    :param hyperparams: Any additional parameters to pass to ``Llama.create_completion()``. (See
        https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion)
    """
    prompt_tokens, _prompt_token_count, merged = self._get_generate_args(messages, functions, **hyperparams)
    chunks = self.model.create_completion(prompt_tokens, stream=True, **merged)

    # forward each piece of text to the caller while keeping a copy for the final message
    pieces = []
    for chunk in chunks:
        piece = chunk["choices"][0]["text"]
        if piece is None:
            continue
        pieces.append(piece)
        yield piece

    # construct the final completion
    # https://github.com/abetlen/llama-cpp-python/issues/1498 blocks token counting impl
    content = "".join(pieces) if pieces else None
    yield Completion(message=ChatMessage.assistant(content))
[docs]
async def close(self):
    """Close the loaded llama.cpp model by calling its ``close()`` method."""
    model = self.model
    model.close()
# ==== deprecated ====
@cached_property
@deprecated("Use prompt_len instead")
def token_reserve(self):
    """
    Deprecated: use ``prompt_len`` instead.

    The token reserve inferred from the prompt pipeline, or 0 if the engine has no pipeline. The value is
    computed once per instance and then cached.
    """
    # infer the token reserve from the pipeline
    return self._infer_token_reserve() if self.pipeline else 0
def _infer_token_reserve(self):
    """
    If token_reserve is not set and we have a pipeline, infer it.

    Runs the pipeline on an empty message list in measurement mode and returns the number of tokens in the
    result, tokenizing it first unless the pipeline already returned a list.
    """
    measured = self.pipeline.execute([], for_measurement=True)
    if not isinstance(measured, list):
        # the pipeline produced text: count its tokens rather than its characters
        measured = self.model.tokenize(measured.encode(), add_bos=False, special=True)
    return len(measured)
[docs]
@deprecated("Use prompt_len instead")
def message_len(self, message: ChatMessage) -> int:
    """
    Deprecated: use ``prompt_len`` instead.

    Return the number of tokens the given message takes up, as measured by running the prompt pipeline on that
    single message in measurement mode.

    :param message: The message to measure.
    :raises NotImplementedError: If the engine has no prompt pipeline.
    """
    # default concrete base behaviour:
    pipeline = self.pipeline
    if pipeline is None:
        raise NotImplementedError(
            "You must pass a prompt_pipeline to the LlamaCppEngine to use it as a non-abstract class. If your model"
            " uses a chat template (or is a quantization of a model with a chat template), you can use the"
            " following:\n"
            "from kani.engines.huggingface import ChatTemplatePromptPipeline\n"
            "pipeline = ChatTemplatePromptPipeline.from_pretrained(base_model_id)\n"
            "engine = LlamaCppEngine(..., prompt_pipeline=pipeline)"
        )

    measured = pipeline.execute([message], for_measurement=True)
    if isinstance(measured, list):
        # already a token list
        return len(measured)
    if isinstance(measured, torch.Tensor):
        # the length of the tensor's first entry is taken as the token count
        return len(measured[0])
    # otherwise it is text that still has to be tokenized
    return len(self.model.tokenize(measured.encode(), add_bos=False, special=True))
[docs]
@deprecated("Use prompt_len instead")
def function_token_reserve(self, functions: list[AIFunction]) -> int:
    """
    Deprecated: use ``prompt_len`` instead.

    Return the number of tokens the given functions take up in the prompt, as measured by running the prompt
    pipeline on an empty message list plus the functions in measurement mode.

    A warning is emitted if functions were given but the measured function prompt is empty, regardless of
    whether the pipeline returned a token list, a tensor, or a string.

    :param functions: The functions to measure.
    :returns: The number of tokens; 0 if no functions are given.
    :raises NotImplementedError: If functions are given but the engine has no prompt pipeline.
    """
    if not functions:
        return 0
    # default concrete base behaviour:
    if self.pipeline is None:
        raise NotImplementedError(
            "You must pass a prompt_pipeline to the LlamaCppEngine to use it as a non-abstract class. If your model"
            " uses a chat template (or is a quantization of a model with a chat template), you can use the"
            " following:\n"
            "from kani.model_specific import prompt_pipeline_for_hf_model\n"
            "pipeline = prompt_pipeline_for_hf_model(base_model_id)\n"
            "engine = LlamaCppEngine(..., prompt_pipeline=pipeline)"
        )
    prompt = self.pipeline.execute([], functions, for_measurement=True)
    if isinstance(prompt, list):
        # do not return early here: an empty token list must still reach the zero-token warning below
        toklen = len(prompt)
    elif isinstance(prompt, torch.Tensor):
        toklen = len(prompt[0])
    else:
        # prompt str to tokens
        # special=True so special tokens in the prompt are parsed the same way as when the prompt is
        # tokenized for generation and in message_len; otherwise the count can differ from the real prompt
        tokenized = self.model.tokenize(prompt.encode(), add_bos=False, special=True)
        toklen = len(tokenized)

    # warn if there are functions but no tokens
    if toklen == 0:
        warnings.warn(
            "Functions were given to the model, but the function prompt returned 0 tokens! This model may not"
            " support function calling, or you may need to implement"
            f" `{type(self).__name__}.function_token_reserve()`."
        )
    return toklen