Source code for kani.utils.cli

"""The CLI utilities allow you to play with a chat session directly from a terminal."""

import asyncio
import importlib
import logging
import os
import pkgutil
import textwrap
from concurrent.futures import Future
from threading import Thread
from typing import AsyncIterable, Callable, Iterable, NamedTuple, Sequence, overload

from kani import _optional, model_specific
from kani.engines import BaseEngine
from kani.kani import Kani
from kani.models import ChatRole
from kani.streaming import StreamManager
from kani.utils.message_formatters import assistant_message_contents_thinking, assistant_message_thinking


async def chat_in_terminal_async(
    kani: Kani,
    *,
    rounds: int = 0,
    stopword: str = None,
    echo: bool = False,
    ai_first: bool = False,
    width: int = None,
    show_function_args: bool = False,
    show_function_returns: bool = False,
    verbose: bool = False,
    stream: bool = True,
):
    """Async version of :func:`.chat_in_terminal`.
    Use in environments when there is already an asyncio loop running (e.g. Google Colab).

    If the environment variable ``KANI_DEBUG`` is set, debug logging is enabled for all loggers. Otherwise, if
    ``KANI_DEBUG_LOGGERS`` is set to a comma-separated list of logger names, only those loggers are set to DEBUG.

    :param kani: The kani to chat with.
    :param int rounds: The number of chat rounds to play (defaults to 0 for infinite).
    :param str stopword: Break out of the chat loop if the user sends this message.
    :param bool echo: Whether to echo the user's input to stdout after they send a message.
    :param bool ai_first: If true, the model generates a completion before the user is first prompted.
    :param int width: The maximum width of the printed outputs (default unlimited).
    :param bool show_function_args: Whether to print the arguments of each function call the model makes.
    :param bool show_function_returns: Whether to print the results of each function call.
    :param bool verbose: Equivalent to setting ``echo``, ``show_function_args``, and ``show_function_returns``.
    :param bool stream: Whether to print tokens as soon as they are generated by the model (default true).
    """
    if os.getenv("KANI_DEBUG") is not None:
        logging.basicConfig(level=logging.DEBUG)
    elif os.getenv("KANI_DEBUG_LOGGERS") is not None:
        logging.basicConfig(level=logging.INFO)
        for logger_name in os.getenv("KANI_DEBUG_LOGGERS").split(","):
            # tolerate "a, b" and trailing commas; an empty name would otherwise select the root logger
            logger_name = logger_name.strip()
            if logger_name:
                logging.getLogger(logger_name).setLevel(logging.DEBUG)
    if verbose:
        echo = show_function_args = show_function_returns = True

    try:
        round_num = 0
        while round_num < rounds or not rounds:
            round_num += 1

            # get user query
            # round_num is 1-based at this point, so only the very first round can be an AI-first round
            if not ai_first or round_num > 1:
                query = await ainput("USER: ")
                query = query_str = query.strip()

                # stopword: compare the raw text, since multimodal parsing below may replace `query` with parts
                if stopword and query_str == stopword:
                    break

                # multimodal handling
                if _optional.has_multimodal_core:
                    # find @path/to/file.png parts and replace them with FileImageParts
                    query = await _optional.multimodal_cli.parts_from_cli_query(query)

                # echo & multimodal echo
                if _optional.has_multimodal_core:
                    # IPython
                    if _optional.multimodal_cli._is_notebook:
                        _optional.multimodal_cli.display_media_ipython(query, show_text=echo)
                    else:
                        _optional.multimodal_cli.display_media(query, show_text=echo)
                elif echo:
                    print_width(query_str, width=width, prefix="USER: ")
            # AI-first round: no user message, go straight to the completion
            else:
                query = None

            # print completion(s)
            if stream:
                # NOTE: the loop variable must not be called `stream`, or it would clobber the boolean parameter
                async for msg_stream in kani.full_round_stream(query):
                    # assistant
                    if msg_stream.role == ChatRole.ASSISTANT:
                        await print_stream(msg_stream, width=width, prefix="AI: ")
                        msg = await msg_stream.message()
                        text = assistant_message_thinking(msg, show_args=show_function_args)
                        if text:
                            print_width(text, width=width, prefix="AI: ")
                    # function
                    elif msg_stream.role == ChatRole.FUNCTION and show_function_returns:
                        msg = await msg_stream.message()
                        print_width(msg.text, width=width, prefix="FUNC: ")
            # completions only
            else:
                async for msg in kani.full_round(query):
                    # assistant
                    if msg.role == ChatRole.ASSISTANT:
                        text = assistant_message_contents_thinking(msg, show_args=show_function_args)
                        print_width(text, width=width, prefix="AI: ")
                    # function
                    elif msg.role == ChatRole.FUNCTION and show_function_returns:
                        print_width(msg.text, width=width, prefix="FUNC: ")
    except (KeyboardInterrupt, asyncio.CancelledError):
        # we won't close the engine here since it's common enough that people close the session in colab
        # and if the process is closing then this will clean itself up anyway
        # await kani.engine.close()
        return
# Signature-only stub: the implementation of chat_in_terminal accepts ``**kwargs`` and forwards them to
# chat_in_terminal_async, so this overload spells out the accepted keyword arguments for type checkers and IDEs.
@overload
def chat_in_terminal(
    kani: Kani,
    *,
    rounds: int = 0,
    stopword: str = None,
    echo: bool = False,
    ai_first: bool = False,
    width: int = None,
    show_function_args: bool = False,
    show_function_returns: bool = False,
    verbose: bool = False,
    stream: bool = True,
): ...
def chat_in_terminal(kani: Kani, **kwargs):
    """Chat with a kani right in your terminal.

    Useful for playing with kani, quick prompt engineering, or demoing the library.

    This is the synchronous entry point: it runs :func:`chat_in_terminal_async` to completion with
    ``asyncio.run``. When an event loop is already running (e.g. in a notebook), it tries to patch the loop with
    ``nest_asyncio``; if that package is unavailable it prints a warning and returns without chatting.

    If the environment variable ``KANI_DEBUG`` is set, debug logging will be enabled. Alternatively,
    ``KANI_DEBUG_LOGGERS`` may hold a comma-separated list of logger names to enable debug logging for.

    If ``kani-multimodal-core`` is installed, you can send multimodal media to a compatible engine with a file path
    or URL after an ``@`` symbol (e.g. "Describe this image: @image.png"). Use quotes (e.g.
    ``@"path/to/my image.png"``) for paths with spaces in their names.

    .. warning::

        This function is only a development utility and should not be used in production.

    :param int rounds: The number of chat rounds to play (defaults to 0 for infinite).
    :param str stopword: Break out of the chat loop if the user sends this message.
    :param bool echo: Whether to echo the user's input to stdout after they send a message (e.g. to save in
        interactive notebook outputs; default false)
    :param bool ai_first: Whether the user should send the first message (default) or the model should generate a
        completion before prompting the user for a message.
    :param int width: The maximum width of the printed outputs (default unlimited).
    :param bool show_function_args: Whether to print the arguments the model is calling functions with for each call
        (default false).
    :param bool show_function_returns: Whether to print the results of each function call (default false).
    :param bool verbose: Equivalent to setting ``echo``, ``show_function_args``, and ``show_function_returns`` to
        True.
    :param bool stream: Whether or not to print tokens as soon as they are generated by the model (default true).
    """
    # get_running_loop() raises RuntimeError exactly when no loop is running in this thread
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_already_running = False
    else:
        loop_already_running = True

    if loop_already_running:
        try:
            # google colab comes with this pre-installed
            # patching the loop lets the plain asyncio.run call below work inside the running loop
            import nest_asyncio

            nest_asyncio.apply()
        except ImportError:
            print(
                "WARNING: It looks like you're in an environment with a running asyncio loop (e.g. Google Colab).\n"
                "You should use `await chat_in_terminal_async(...)` instead or install `nest-asyncio`."
            )
            return

    try:
        asyncio.run(chat_in_terminal_async(kani, **kwargs))
    except KeyboardInterrupt:
        # ^C simply ends the chat session
        return
# ===== format helpers =====
def format_width(msg: str, width: int = None, prefix: str = "") -> str:
    """
    Wrap *msg* so that no output line is longer than *width* characters.

    The first line starts with *prefix*; every following line (both wrapped continuations and lines that follow an
    explicit newline in *msg*) is indented by ``len(prefix)`` spaces. If *width* is falsy, no wrapping is done and
    ``prefix + msg`` is returned unchanged.

    .. code-block:: pycon

        >>> print(format_width("Hello world I am a potato", width=15, prefix="USER: "))
        USER: Hello
              world I
              am a
              potato
    """
    if not width:
        return prefix + msg

    hanging_indent = " " * len(prefix)
    # only the very first source line carries the prefix; all later ones get the hanging indent instead
    first_wrapper = textwrap.TextWrapper(width=width, initial_indent=prefix, subsequent_indent=hanging_indent)
    rest_wrapper = textwrap.TextWrapper(
        width=width, initial_indent=hanging_indent, subsequent_indent=hanging_indent
    )
    return "\n".join(
        (first_wrapper if index == 0 else rest_wrapper).fill(source_line)
        for index, source_line in enumerate(msg.splitlines())
    )
async def strip_stream(stream: StreamManager) -> AsyncIterable[str]:
    """
    Re-yield the tokens of *stream* with the overall leading and trailing whitespace removed.

    Whitespace at the end of a token is held back and only emitted in front of the next token that has
    non-whitespace content, so whitespace at the very end of the stream is never yielded.
    """
    pending_ws = ""
    seen_content = False
    async for chunk in stream:
        # keep left-stripping until the first chunk that contains something other than whitespace
        if not seen_content:
            chunk = chunk.lstrip()
            seen_content = bool(chunk)

        body = chunk.rstrip()
        tail_ws = chunk[len(body) :]

        # whitespace-only chunk: just grow the held-back whitespace
        if not body:
            pending_ws += tail_ws
            continue

        yield pending_ws + body
        pending_ws = tail_ws
async def format_stream(stream: StreamManager, width: int = None, prefix: str = "") -> AsyncIterable[str]:
    """
    Yield display-ready pieces of *stream* which, when concatenated, are wrapped to roughly *width* columns.

    The stream is first passed through :func:`strip_stream`. *prefix* is yielded once, immediately before the first
    token, so nothing is emitted for a stream that produces no content. When *width* is set, lines after the first
    are indented by the length of the prefix. Wrapping happens at token boundaries only: a line break is inserted
    before a token piece that would push the current line past *width*.
    """
    indent_width = len(prefix)
    # what to emit whenever a new output line starts (continuation indent only applies in wrapping mode)
    line_break = "\n" + (" " * indent_width if width else "")

    column = indent_width  # running length of the current output line, prefix/indent included
    started = False
    async for chunk in strip_stream(stream):
        # only print the prefix if the model actually yields anything
        if not started:
            started = True
            yield prefix

        # handle each newline-terminated piece of the token separately
        for piece in chunk.splitlines(keepends=True):
            piece_len = len(piece)

            # wrap before this piece if it would overflow the current line
            if width and column + piece_len > width:
                yield line_break
                column = indent_width

            # emit the text itself, without its line terminator
            yield piece.rstrip("\r\n")
            column += piece_len

            # an explicit newline in the token starts a fresh (indented) line
            if piece.endswith("\n"):
                yield line_break
                column = indent_width
async def ainput(string: str) -> str:
    """input(), but async.

    Reads a line with the builtin ``input`` on a daemon thread and awaits the result, so the event loop is not
    blocked while waiting for the user.

    :param string: The prompt passed to ``input``.
    :returns: The line entered by the user.
    :raises Exception: Any exception raised by ``input`` in the reader thread (e.g. ``EOFError``) is re-raised here.
    """
    # doing just .to_thread causes problems when we ^C, so we need to launch our own daemon thread to handle reading
    # input
    future = Future()
    # mark the future as running so that it can no longer be cancelled out from under the reader thread
    future.set_running_or_notify_cancel()

    def daemon():
        try:
            result = input(string)
        except Exception as e:
            future.set_exception(e)
        else:
            future.set_result(result)

    Thread(target=daemon, daemon=True).start()
    return await asyncio.wrap_future(future)


# ==== CLI engine defs ====
# each factory imports its engine lazily so that a provider's dependencies are only needed when it is selected
def chat_openai(model_id: str):
    """Create an OpenAI engine for the given model ID."""
    from kani.engines.openai import OpenAIEngine

    return OpenAIEngine(model=model_id)


def chat_anthropic(model_id: str):
    """Create an Anthropic engine for the given model ID."""
    from kani.engines.anthropic import AnthropicEngine

    return AnthropicEngine(model=model_id)


def chat_google(model_id: str):
    """Create a Google AI engine for the given model ID."""
    from kani.engines.google import GoogleAIEngine

    return GoogleAIEngine(model=model_id)


def chat_huggingface(model_id: str):
    """Create a Hugging Face engine for the given model ID, wrapped in a model-specific parser if one is known."""
    from kani.engines.huggingface import HuggingEngine

    engine = HuggingEngine(model_id=model_id)
    # HF: wrap in model-specific parser if available
    if parser := model_specific.parser_for_hf_model(engine.model_id):
        return parser(engine, show_reasoning_in_stream=True)
    return engine


# ---- CLI registry ----
class CLIProvider(NamedTuple):
    """A CLI provider entry: a name, its aliases, and a factory that builds an engine from a model ID."""

    # this is a NamedTuple so we don't need to import a special type in external pkgs
    name: str
    aliases: Sequence[str]
    entrypoint: Callable[[str], BaseEngine]


CLI_PROVIDERS = [
    # openai
    CLIProvider(name="openai", aliases=["oai"], entrypoint=chat_openai),
    # anthropic
    CLIProvider(name="anthropic", aliases=["ant", "claude"], entrypoint=chat_anthropic),
    # google
    CLIProvider(name="google", aliases=["g", "gemini"], entrypoint=chat_google),
    # huggingface
    CLIProvider(name="huggingface", aliases=["hf"], entrypoint=chat_huggingface),
]
"""
Default CLI providers.

Extension packages can define these in __init__.py as a 3-tuple of (name, aliases, factory) and they will be
automatically discovered.
"""


def get_cli_providers_including_extensions() -> Iterable[tuple[str, Sequence[str], Callable[[str], BaseEngine]]]:
    """
    Yield all possible CLI provider 3-tuples.

    Yields the built-in ``CLI_PROVIDERS`` first, then imports every module found under ``kani.ext`` and yields
    from its ``CLI_PROVIDERS`` attribute if it defines one. Extensions that fail to import are skipped.
    """
    yield from CLI_PROVIDERS

    try:
        import kani.ext
    except ImportError:
        # no extension namespace available: only the built-in providers exist
        return

    for _finder, name, _ispkg in pkgutil.iter_modules(kani.ext.__path__, kani.ext.__name__ + "."):
        try:
            mod = importlib.import_module(name)
        except ImportError:
            # best-effort discovery: a broken extension must not break the CLI, but leave a trace for debugging
            logging.getLogger(__name__).debug("Could not import %s while looking for CLI providers", name, exc_info=True)
            continue
        mod_cli_providers = getattr(mod, "CLI_PROVIDERS", None)
        if mod_cli_providers is not None:
            yield from mod_cli_providers


def fmt_cli_providers() -> str:
    """Return a list of available CLI providers and their aliases"""
    out = []
    for name, aliases, _entrypoint in get_cli_providers_including_extensions():
        if aliases:
            out.append(f"* {name} (aliases: {', '.join(aliases)})")
        else:
            out.append(f"* {name}")
    return "\n".join(out)


def create_engine_from_cli_arg(arg: str):
    """Create an engine instance from a CLI arg <provider>:<model_id>

    :param arg: The provider name (or one of its aliases) and the model ID, separated by the first ``:``.
    :returns: Whatever the matching provider's entrypoint returns for the model ID.
    :raises ValueError: If *arg* contains no ``:`` separator, or if the provider is not a known name or alias.
    """
    # split on the first colon only, since model IDs may themselves contain colons
    provider, separator, model_id = arg.partition(":")
    if not separator:
        raise ValueError(
            f"Invalid model argument: {arg!r}. Expected the form <provider>:<model_id>. Valid providers:\n"
            f"{fmt_cli_providers()}"
        )
    for name, aliases, entrypoint in get_cli_providers_including_extensions():
        if provider == name or provider in aliases:
            return entrypoint(model_id)
    raise ValueError(f"Invalid model provider: {provider!r}. Valid options:\n{fmt_cli_providers()}")