# Source code for kani.engines.huggingface.base

import asyncio
import logging
import warnings
from collections import UserDict
from threading import Thread
from typing import AsyncIterable

from kani import _optional, model_specific
from kani.ai_function import AIFunction
from kani.engines.base import BaseCompletion, BaseEngine, Completion
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage
from kani.prompts.pipeline import PromptPipeline
from kani.utils.warnings import deprecated

try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoProcessor,
        AutoTokenizer,
        BatchEncoding,
        BatchFeature,
        PreTrainedTokenizerBase,
        ProcessorMixin,
        TextIteratorStreamer,
    )
except ImportError:
    raise MissingModelDependencies(
        'The HuggingEngine requires extra dependencies. Please install kani with "pip install kani[huggingface]". '
        "You will also need to install PyTorch manually."
    ) from None

# Capability flags: record which accelerator backends this PyTorch build reports as built in.
has_mps = torch.backends.mps.is_built()
has_cuda = torch.backends.cuda.is_built()

# `accelerate` is an optional dependency; remember whether it could be imported so the engine can decide
# whether to request an automatic device map when loading a model.
try:
    import accelerate
except ImportError:
    # only nag about the missing package when the CUDA flag is set
    if has_cuda:
        warnings.warn(
            "A PyTorch install with CUDA was detected on this system, but `accelerate` is not installed. Run `pip"
            " install accelerate` for automatic GPU mapping of Hugging Face models."
        )
    has_accelerate = False
else:
    has_accelerate = True

# module-level logger shared by everything in this file
log = logging.getLogger(__name__)


class HuggingEngine(BaseEngine):
    """Base engine for all HuggingFace text-generation models.

    This class implements the main decoding logic for any HuggingFace model based on a pretrained
    ``AutoModelForCausalLM``. As most models use model-specific chat templates, this base class accepts a
    :class:`.PromptPipeline` to translate kani ChatMessages into a model-specific string.

    .. versionadded:: 1.2.0
        By default, the ``HuggingEngine`` uses models' bundled chat template to build the prompt
        for chat-based models available on Hugging Face.
        See https://huggingface.co/docs/transformers/main/en/chat_templating for more information.

    **GPU Support**

    By default, the HuggingEngine loads the model on GPU if CUDA is detected on your system. To override the device
    the model is loaded on, pass ``device="cpu|cuda"`` to the constructor.

    **Multimodal support**: audio, images, video (depending on model).

    .. tip:: See :ref:`4b_quant` for information about loading a quantized model for lower memory usage.
    """

    def __init__(
        self,
        model_id: str,
        max_context_size: int | None = None,
        prompt_pipeline: PromptPipeline[str | torch.Tensor] = None,
        *,
        # hf args
        token=None,
        device: str | None = None,
        tokenizer_cls=None,
        tokenizer_kwargs: dict | None = None,
        model_cls=AutoModelForCausalLM,
        model_load_kwargs: dict | None = None,
        chat_template_reasoning_content_key: str | None = None,
        chat_template_kwargs: dict | None = None,
        # multimodal args
        mm_audio_sample_rate: int | None = None,
        mm_video_fps: float = 1,
        # kani args
        token_reserve: int = 0,
        **hyperparams,
    ):
        """
        :param model_id: The ID of the model to load from HuggingFace.
        :param max_context_size: The context size of the model. If not given, will be set from the model's config.
        :param prompt_pipeline: The pipeline to translate a list of kani ChatMessages into the model-specific chat
            format (see :class:`.PromptPipeline`). If not passed, uses the Hugging Face chat template if available.
        :param token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli.
        :param device: The hardware device to use. If not specified, uses CUDA or MPS if available; otherwise uses
            CPU.
        :param tokenizer_cls: Advanced use cases: The HF tokenizer class to use. Defaults to ``AutoProcessor`` (if no
            processing config is available or this raises an error, this will fall back to ``AutoTokenizer``).
        :param tokenizer_kwargs: Additional arguments to pass to ``AutoProcessor.from_pretrained()``. The dict is
            copied; it is never modified in place.
        :param model_cls: Advanced use cases: The HF model class to use. Defaults to ``AutoModelForCausalLM``.
        :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. The
            dict is copied; it is never modified in place.
        :param chat_template_reasoning_content_key: The key of each message dict that any reasoning content should be
            set at.
        :param chat_template_kwargs: The keyword arguments to pass to ``tokenizer.apply_chat_template`` if using a
            chat template prompt pipeline.
        :param mm_audio_sample_rate: The sample rate to remux audio inputs to. Check your model's documentation for
            the expected sample rate. By default, does not change the sample rate of the input file.
        :param mm_video_fps: The number of image frames to sample per second of video input.
        :param hyperparams: Additional arguments to supply the model during generation.
        :param token_reserve: DEPRECATED: The number of tokens to reserve for internal engine mechanisms (e.g. if
            there is a generation template after the last user message). If not passed, kani will attempt to infer
            this from a prompt pipeline.
        :raises ValueError: If ``max_context_size`` is not given and cannot be inferred from the model config.
        """
        # copy the kwargs dicts so the setdefault() calls below never leak into dicts owned by the caller
        tokenizer_kwargs = {} if tokenizer_kwargs is None else dict(tokenizer_kwargs)
        model_load_kwargs = {} if model_load_kwargs is None else dict(model_load_kwargs)
        if chat_template_kwargs is None:
            chat_template_kwargs = {}

        # `use_auth_token` is accepted as an alias for `token`; pop it so it is not forwarded to generate()
        tokenizer_kwargs.setdefault("token", hyperparams.get("use_auth_token", token))
        model_load_kwargs.setdefault("token", hyperparams.pop("use_auth_token", token))
        model_load_kwargs.setdefault("torch_dtype", "auto")
        if has_cuda and has_accelerate:
            model_load_kwargs.setdefault("device_map", "auto")

        self.model_id = model_id
        self.max_context_size = max_context_size

        # load the correct processor or tokenizer, EAFP
        if tokenizer_cls is None:
            try:
                _processor_or_tokenizer = AutoProcessor.from_pretrained(model_id, **tokenizer_kwargs)
            except Exception as e:
                # deliberate best-effort: any failure to build a processor falls back to a text-only tokenizer
                log.warning(
                    f"Could not load the AutoProcessor for {model_id}, falling back to AutoTokenizer. Multimodal"
                    " inputs will not be available.",
                    exc_info=e,
                )
                _processor_or_tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
        else:
            _processor_or_tokenizer = tokenizer_cls.from_pretrained(model_id, **tokenizer_kwargs)
        self._processor_or_tokenizer: ProcessorMixin | PreTrainedTokenizerBase = _processor_or_tokenizer

        self.model = model_cls.from_pretrained(model_id, **model_load_kwargs)
        self.hyperparams = hyperparams

        # multimodal args
        self.mm_audio_sample_rate = mm_audio_sample_rate
        self.mm_video_fps = mm_video_fps

        # load the pipeline
        if prompt_pipeline is None:
            # try and load a manual impl, or default to chat template if not available
            prompt_pipeline = model_specific.prompt_pipeline_for_hf_model(
                model_id,
                self._processor_or_tokenizer,
                chat_template_reasoning_content_key=chat_template_reasoning_content_key,
                chat_template_kwargs=chat_template_kwargs,
            )
        self.pipeline = prompt_pipeline

        # ensure model is on correct device
        if device is None:
            # `has_cuda` / `has_mps` only say the backend was compiled in; also check that a device is actually
            # usable so that a CUDA-enabled build on a machine without a GPU falls back instead of crashing
            if has_cuda and torch.cuda.is_available():
                device = "cuda"
            elif has_mps and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
            log.info(f"Inferred device for model weights: {device}. Set `device=...` if this is incorrect.")
        self.device = device

        # compare parsed devices so that e.g. "cuda:0" matches a model already placed on cuda:0; an index-less
        # target such as "cuda" matches any device of that type
        target_device = torch.device(device)
        current_device = self.model.device
        needs_move = current_device.type != target_device.type or (
            target_device.index is not None and current_device.index != target_device.index
        )
        if needs_move:
            self.model.to(device)

        # ensure model is in eval mode
        self.model.eval()

        # token counting stuff
        # try and infer max context size from the model config if not specified
        if self.max_context_size is None:
            self.max_context_size = getattr(
                self.model.config,
                "model_max_len",
                getattr(self.model.config, "max_position_embeddings", None),
            )
            log.debug(f"Inferred max context size: {self.max_context_size}")

            if self.max_context_size is None:
                raise ValueError(
                    "Could not infer the model's max context size from the config. Please pass the `max_context_size`"
                    " arg."
                )
            elif self.max_context_size > 1e20:
                warnings.warn(
                    f"The inferred max context size of this model is extremely large ({self.max_context_size}). This"
                    " may mean that the model has not configured their model_max_len correctly. Please pass the"
                    " `max_context_size` arg to use the correct model size."
                )

        # deprecated
        self._token_reserve = token_reserve

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """The underlying tokenizer, unwrapped from the processor if a processor was loaded."""
        # little util to make sure we get a Tokenizer object when we need it
        # just in case we load a processor (for multimodal models)
        if isinstance(self._processor_or_tokenizer, ProcessorMixin):
            # noinspection PyUnresolvedReferences
            return self._processor_or_tokenizer.tokenizer
        return self._processor_or_tokenizer

    def _collect_multimodal(self, messages: list[ChatMessage]):
        """Collect a list of audios, images, videos from the given input prompt.

        :returns: A 3-tuple ``(audios, images, videos)``. Each element is ``None`` when no part of that kind was
            found; all three are ``None`` when the multimodal extension is not installed.
        """
        if not _optional.has_multimodal_core:
            return None, None, None

        mm = _optional.multimodal_core
        audios = []
        images = []
        videos = []
        for msg in messages:
            for part in msg.parts:
                if isinstance(part, mm.AudioPart):
                    audios.append(part.as_ndarray(sr=self.mm_audio_sample_rate))
                elif isinstance(part, mm.ImagePart):
                    images.append(part.image)
                elif isinstance(part, mm.VideoPart):
                    videos.append(part.as_tensor(fps=self.mm_video_fps))
        # empty lists are normalised to None before being returned
        return audios or None, images or None, videos or None

    def build_prompt(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None
    ) -> str | torch.Tensor | BatchEncoding | BatchFeature:
        """
        Given the list of messages from kani, build either a single string representing the prompt for the model,
        or build the token tensor.

        The default behaviour is to call the supplied pipeline. A string result is tokenized (and, when a processor
        and the multimodal extension are available, combined with the messages' audio/image/video parts); any other
        result (e.g. a token tensor built by the pipeline) is returned unchanged.

        :raises NotImplementedError: If the engine has no prompt pipeline.
        """
        if self.pipeline is None:
            raise NotImplementedError(
                "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
            )
        text = self.pipeline(messages, functions)
        log.debug(f"BUILT PROMPT TEXT: {text}")

        # a pipeline may already emit token tensors; those cannot be passed to the tokenizer as `text`
        if not isinstance(text, str):
            return text

        processor_kwargs = {"text": text, "add_special_tokens": False, "return_tensors": "pt"}
        if isinstance(self._processor_or_tokenizer, ProcessorMixin):
            # we want to pass return_mm_token_type_ids=False to a Processor only
            processor_kwargs["return_mm_token_type_ids"] = False
            # if multimodal is installed and we have a processor, collect parts and run them through the processor
            if _optional.has_multimodal_core:
                audios, images, videos = self._collect_multimodal(messages)
                processor_kwargs.update(audio=audios, images=images, videos=videos)
        # otherwise run it through the tokenizer with just text
        return self._processor_or_tokenizer(**processor_kwargs)

    def _get_generate_args(self, prompt: str | torch.Tensor | BatchEncoding | BatchFeature, **hyperparams):
        """
        Internal method to build common params for the generate call and also do some chores before we generate

        :param prompt: The output of :meth:`build_prompt`.
        :param hyperparams: Per-call generation arguments; these override the engine-level hyperparams.
        :returns: A 3-tuple ``(input_kwargs, input_len, hyperparams)``: the model inputs moved to the engine's
            device, the number of prompt tokens, and the merged generation arguments.
        :raises TypeError: If ``prompt`` is not a str, Tensor, or dict-like.
        """
        # make sure the prompt is tokenized
        if isinstance(prompt, str):
            # prompt str to tokens
            input_kwargs = self._processor_or_tokenizer(text=prompt, add_special_tokens=False, return_tensors="pt")
        elif isinstance(prompt, torch.Tensor):
            # accept an unbatched 1-D tensor of token ids by adding the batch dimension
            if prompt.dim() == 1:
                prompt = prompt.unsqueeze(0)
            input_kwargs = BatchFeature({"input_ids": prompt})
        elif isinstance(prompt, (dict, UserDict)):
            # plain mappings have no .to(); wrap them so the device/dtype moves below work
            input_kwargs = prompt if hasattr(prompt, "to") else BatchFeature(dict(prompt))
        else:
            raise TypeError(
                "build_prompt should either return a str, Tensor, or dict (e.g., BatchEncoding, BatchFeature)."
            )
        input_len = input_kwargs["input_ids"].shape[1]

        # move the input tensor to the right device and make sure any multimodal features are in the right dtype
        # (if BatchFeature)
        input_kwargs.to(self.device)
        if isinstance(input_kwargs, BatchFeature):
            input_kwargs.to(self.model.dtype)

        # set up hyperparams for HF decode
        hyperparams = {**self.hyperparams, **hyperparams}
        if "max_new_tokens" not in hyperparams:
            hyperparams.setdefault("max_length", self.max_context_size)

        # check for a model-specific parser
        model_specific.warn_for_uninitialized_parser(self.model_id)
        return input_kwargs, input_len, hyperparams

    def _get_eos_tokens(self, *, return_ids=False, **hyperparams) -> list[str] | list[int]:
        """Get the list of tokens that should end a generation.

        The EOS id(s) are looked up, in order, from the ``eos_token_id`` hyperparameter, the tokenizer, the model's
        ``text_config``, and the model's ``generation_config``. A warning is emitted if none is found.

        :param return_ids: If True, return token IDs; otherwise return each token decoded to a string.
        """
        if "eos_token_id" in hyperparams:
            genconfig_eos_token_id = hyperparams["eos_token_id"]
        else:
            # compare against None rather than truthiness: token id 0 is a valid EOS id
            genconfig_eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
            if genconfig_eos_token_id is None:
                # multimodal models sometimes only have text_config and no generation_config (e.g. Qwen3.5)
                text_config = getattr(self.model, "text_config", None)
                genconfig_eos_token_id = getattr(text_config, "eos_token_id", None)
            if genconfig_eos_token_id is None:
                generation_config = getattr(self.model, "generation_config", None)
                genconfig_eos_token_id = getattr(generation_config, "eos_token_id", None)

        if isinstance(genconfig_eos_token_id, (list, tuple)):
            eos_token_ids = list(genconfig_eos_token_id)
        elif genconfig_eos_token_id is not None:
            eos_token_ids = [genconfig_eos_token_id]
        else:
            warnings.warn(
                f"No EOS token was found for the {self.model_id} model. Generation may continue forever. Please pass"
                " `eos_token_id=[...]` in the engine constructor."
            )
            eos_token_ids = []

        if return_ids:
            return eos_token_ids
        return [self._processor_or_tokenizer.decode(t) for t in eos_token_ids]

    @staticmethod
    def _strip_prompt_and_eos(sequence, input_len: int, eos_tok_ids):
        """Return the generated token ids of *sequence*: everything after the prompt, minus one trailing EOS."""
        generated = sequence[input_len:]
        # only inspect the last token if something was generated, so a prompt token is never mistaken for EOS
        if len(generated) > 0 and int(generated[-1]) in eos_tok_ids:
            generated = generated[:-1]
        return generated

    # ==== kani impl ====
    async def prompt_len(self, messages, functions=None, **kwargs) -> int:
        """Return the length, in tokens, of the prompt built from the given messages and functions."""
        prompt = self.build_prompt(messages, functions)
        _, input_len, _ = self._get_generate_args(prompt, **kwargs)
        return input_len

    async def predict(
        self,
        messages: list[ChatMessage],
        functions: list[AIFunction] | None = None,
        *,
        decode_kwargs: dict | None = None,
        **hyperparams,
    ) -> Completion:
        """
        Given the current context of messages and available functions, get the next predicted chat message from the
        LM.

        :param messages: The messages in the current chat context. ``prompt_len(messages, functions)`` is guaranteed
            to be less than max_context_size.
        :param functions: The functions the LM is allowed to call.
        :param decode_kwargs: Any arguments to pass to AutoTokenizer.decode().
        :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See
            https://huggingface.co/docs/transformers/main_classes/text_generation)
        """
        if decode_kwargs is None:
            decode_kwargs = {}

        prompt = self.build_prompt(messages, functions)
        input_kwargs, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams)
        eos_tok_ids = self._get_eos_tokens(return_ids=True, **hyperparams)

        # run it through the model
        with torch.no_grad():
            output = self.model.generate(**input_kwargs, **hyperparams)

        # decode to tokens
        # the completion shouldn't include the prompt or stop token
        generated = self._strip_prompt_and_eos(output[0], input_len, eos_tok_ids)
        content = self._processor_or_tokenizer.decode(generated, **decode_kwargs).strip()
        return Completion(
            ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=len(generated)
        )

    async def stream(
        self,
        messages: list[ChatMessage],
        functions: list[AIFunction] | None = None,
        *,
        streamer_timeout: float | None = None,
        decode_kwargs: dict | None = None,
        **hyperparams,
    ) -> AsyncIterable[str | BaseCompletion]:
        """
        Given the current context of messages and available functions, get the next predicted chat message from the
        LM.

        Yields each non-empty chunk of text as it is generated, then a final :class:`.Completion` with usage stats.

        :param messages: The messages in the current chat context. ``prompt_len(messages, functions)`` is guaranteed
            to be less than max_context_size.
        :param functions: The functions the LM is allowed to call.
        :param streamer_timeout: The maximum number of seconds to wait for the next token when streaming.
        :param decode_kwargs: Any arguments to pass to AutoTokenizer.decode().
        :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See
            https://huggingface.co/docs/transformers/main_classes/text_generation)
        """
        if decode_kwargs is None:
            decode_kwargs = {}

        prompt = self.build_prompt(messages, functions)
        input_kwargs, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams)
        eos_tok_ids = self._get_eos_tokens(return_ids=True, **hyperparams)
        # drop EOS tokens that decode to "": every string ends with "" and `token[:-0]` would erase the token
        eos_toks = [tok for tok in (self._processor_or_tokenizer.decode(t) for t in eos_tok_ids) if tok]
        streamer = TextIteratorStreamer(
            self._processor_or_tokenizer, skip_prompt=True, timeout=streamer_timeout, **decode_kwargs
        )

        # run it through the model in another thread so that we can get the tokens in this thread
        loop = asyncio.get_running_loop()
        output_toks_fut = loop.create_future()

        def thread_target():
            try:
                with torch.no_grad():
                    output = self.model.generate(**input_kwargs, streamer=streamer, **hyperparams)
                loop.call_soon_threadsafe(output_toks_fut.set_result, output)
            except Exception as e:
                loop.call_soon_threadsafe(output_toks_fut.set_exception, e)
                # signal end-of-stream so the consumer below stops waiting and surfaces the exception
                streamer.end()

        thread = Thread(target=thread_target)
        thread.start()

        # then wait for tokens from the task
        # the streamer's iterator blocks until the next token is ready, so pull from it in a worker thread to keep
        # the event loop responsive; a sentinel default is used because StopIteration cannot cross a Future
        token_iter = iter(streamer)
        end_of_stream = object()
        yielded_tokens = []
        while True:
            token = await asyncio.to_thread(next, token_iter, end_of_stream)
            if token is end_of_stream:
                break
            for eos_tok in eos_toks:
                if token.endswith(eos_tok):
                    token = token[: -len(eos_tok)]
                    break
            if not token:
                continue
            yield token
            yielded_tokens.append(token)

        # clean up the thread
        output_toks = await output_toks_fut
        thread.join()

        # yield a completion with usage stats
        # count tokens the same way predict() does: only discount the EOS token if one was actually generated
        generated = self._strip_prompt_and_eos(output_toks[0], input_len, eos_tok_ids)
        content = "".join(yielded_tokens)
        yield Completion(
            message=ChatMessage.assistant(content=content.strip()),
            prompt_tokens=input_len,
            completion_tokens=len(generated),
        )

    # ===== deprecated =====
    def _count_pipeline_tokens(self, *pipeline_args) -> int:
        """Run the pipeline in measurement mode on the given arguments and return the resulting token count."""
        prompt = self.pipeline.execute(*pipeline_args, for_measurement=True)
        if isinstance(prompt, torch.Tensor):
            return len(prompt[0])
        # prompt str to tokens
        tokenized = self.tokenizer.encode(prompt, add_special_tokens=False)
        return len(tokenized)

    @property
    @deprecated("Use prompt_len instead")
    def token_reserve(self):
        """DEPRECATED: The number of tokens reserved for internal engine mechanisms."""
        # infer the token reserve from the pipeline
        if self._token_reserve == 0 and self.pipeline:
            self._token_reserve = self._infer_token_reserve()
        return self._token_reserve

    def _infer_token_reserve(self):
        """If token_reserve is not set and we have a pipeline, infer it."""
        return self._count_pipeline_tokens([])

    @deprecated("Use prompt_len instead")
    def message_len(self, message: ChatMessage) -> int:
        """Return the length, in tokens, of the given chat message.

        The message is rendered with the engine's prompt pipeline in measurement mode and then tokenized.

        :raises NotImplementedError: If the engine has no prompt pipeline.
        """
        # default concrete base behaviour:
        if self.pipeline is None:
            raise NotImplementedError(
                "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
            )
        return self._count_pipeline_tokens([message])

    @deprecated("Use prompt_len instead")
    def function_token_reserve(self, functions: list[AIFunction]) -> int:
        """Return the number of tokens the given function definitions add to the prompt.

        :raises NotImplementedError: If functions are given and the engine has no prompt pipeline.
        """
        if not functions:
            return 0
        # default concrete base behaviour:
        if self.pipeline is None:
            raise NotImplementedError(
                "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
            )
        toklen = self._count_pipeline_tokens([], functions)

        # warn if there are functions but no tokens
        if toklen == 0:
            warnings.warn(
                "Functions were given to the model, but the function prompt returned 0 tokens! This model may not"
                " support function calling, or you may need to implement"
                f" `{type(self).__name__}.function_token_reserve()`."
            )
        return toklen