"""The CLI utilities allow you to play with a chat session directly from a terminal."""
import asyncio
import importlib
import logging
import os
import pkgutil
import textwrap
from concurrent.futures import Future
from threading import Thread
from typing import AsyncIterable, Callable, Iterable, NamedTuple, Sequence, overload
from kani import _optional, model_specific
from kani.engines import BaseEngine
from kani.kani import Kani
from kani.models import ChatRole
from kani.streaming import StreamManager
from kani.utils.message_formatters import assistant_message_contents_thinking, assistant_message_thinking
[docs]
async def chat_in_terminal_async(
    kani: Kani,
    *,
    rounds: int = 0,
    stopword: str = None,
    echo: bool = False,
    ai_first: bool = False,
    width: int = None,
    show_function_args: bool = False,
    show_function_returns: bool = False,
    verbose: bool = False,
    stream: bool = True,
):
    """Async version of :func:`.chat_in_terminal`.

    Use in environments when there is already an asyncio loop running (e.g. Google Colab).

    If the environment variable ``KANI_DEBUG`` is set, debug logging is enabled. Otherwise, if
    ``KANI_DEBUG_LOGGERS`` is set, it is read as a comma-separated list of logger names; logging is configured at
    INFO level and only the named loggers are switched to DEBUG.

    :param kani: The kani to chat with.
    :param rounds: The number of chat rounds to play (0 for infinite). A model-first round (see ``ai_first``) counts
        as a round.
    :param stopword: Break out of the chat loop if the user's (stripped) message equals this string.
    :param echo: Whether to print the user's message back to stdout after it is read.
    :param ai_first: If true, the first round is run without prompting the user, so the model generates a completion
        before the user sends anything.
    :param width: The maximum width of the printed outputs (``None`` for unlimited).
    :param show_function_args: Passed as ``show_args`` to the assistant message formatters.
    :param show_function_returns: Whether to print FUNCTION-role messages.
    :param verbose: Equivalent to setting ``echo``, ``show_function_args``, and ``show_function_returns`` to True.
    :param stream: Whether to print tokens as they are generated (true) or print each message once it is complete.
    """
    debug_loggers = os.getenv("KANI_DEBUG_LOGGERS")
    if os.getenv("KANI_DEBUG") is not None:
        logging.basicConfig(level=logging.DEBUG)
    elif debug_loggers is not None:
        logging.basicConfig(level=logging.INFO)
        for logger_name in debug_loggers.split(","):
            # tolerate "a, b" style lists: surrounding whitespace is never part of a logger name
            logging.getLogger(logger_name.strip()).setLevel(logging.DEBUG)

    if verbose:
        echo = show_function_args = show_function_returns = True

    try:
        round_num = 0
        while not rounds or round_num < rounds:
            if ai_first and round_num == 0:
                # the model speaks first: play the opening round without a user message
                query = None
            else:
                # get user query
                raw_query = await ainput("USER: ")
                query = query_str = raw_query.strip()

                # stopword: compare the text the user typed, since the multimodal parser below may replace the
                # query with a non-string value that can never equal the stopword
                if stopword and query_str == stopword:
                    break

                if _optional.has_multimodal_core:
                    # find @path/to/file.png parts and replace them with FileImageParts
                    query = await _optional.multimodal_cli.parts_from_cli_query(query_str)
                    # echo & multimodal echo (the display helpers decide whether to show text via show_text)
                    if _optional.multimodal_cli._is_notebook:
                        _optional.multimodal_cli.display_media_ipython(query, show_text=echo)
                    else:
                        _optional.multimodal_cli.display_media(query, show_text=echo)
                elif echo:
                    print_width(query_str, width=width, prefix="USER: ")

            # count the round only after the ai_first check above, so that round 0 is observable
            round_num += 1

            # print completion(s)
            if stream:
                # NOTE: the loop variable must not be called `stream`, or it would clobber the flag for later rounds
                async for msg_stream in kani.full_round_stream(query):
                    # assistant
                    if msg_stream.role == ChatRole.ASSISTANT:
                        await print_stream(msg_stream, width=width, prefix="AI: ")
                        msg = await msg_stream.message()
                        # the content was already streamed; only the extra "thinking" text (if any) remains
                        text = assistant_message_thinking(msg, show_args=show_function_args)
                        if text:
                            print_width(text, width=width, prefix="AI: ")
                    # function
                    elif msg_stream.role == ChatRole.FUNCTION and show_function_returns:
                        msg = await msg_stream.message()
                        print_width(msg.text, width=width, prefix="FUNC: ")
            # completions only
            else:
                async for msg in kani.full_round(query):
                    # assistant
                    if msg.role == ChatRole.ASSISTANT:
                        text = assistant_message_contents_thinking(msg, show_args=show_function_args)
                        print_width(text, width=width, prefix="AI: ")
                    # function
                    elif msg.role == ChatRole.FUNCTION and show_function_returns:
                        print_width(msg.text, width=width, prefix="FUNC: ")
    except (KeyboardInterrupt, asyncio.CancelledError):
        # we won't close the engine here since it's common enough that people close the session in colab
        # and if the process is closing then this will clean itself up anyway
        # await kani.engine.close()
        return
# Typing-only stub: it spells out the keyword arguments that the ``chat_in_terminal`` implementation accepts as
# ``**kwargs`` and forwards to ``chat_in_terminal_async``, so editors and type checkers can show them.
# The defaults listed here mirror those declared on ``chat_in_terminal_async``.
# NOTE(review): type checkers usually expect at least two ``@overload`` variants - confirm this passes your checker.
@overload
def chat_in_terminal(
    kani: Kani,
    *,
    rounds: int = 0,
    stopword: str = None,
    echo: bool = False,
    ai_first: bool = False,
    width: int = None,
    show_function_args: bool = False,
    show_function_returns: bool = False,
    verbose: bool = False,
    stream: bool = True,
): ...
[docs]
def chat_in_terminal(kani: Kani, **kwargs):
    """Chat with a kani right in your terminal.

    Useful for playing with kani, quick prompt engineering, or demoing the library.

    If the environment variable ``KANI_DEBUG`` is set, debug logging will be enabled.

    If ``kani-multimodal-core`` is installed, you can send multimodal media to a compatible engine with a file path
    or URL after an ``@`` symbol (e.g. "Describe this image: @image.png").
    Use quotes (e.g. ``@"path/to/my image.png"``) for paths with spaces in their names.

    When called while an asyncio loop is already running, this tries to patch the loop with ``nest_asyncio``; if that
    package cannot be imported, a warning is printed and the function returns without starting a chat.

    .. warning::

        This function is only a development utility and should not be used in production.

    :param int rounds: The number of chat rounds to play (defaults to 0 for infinite).
    :param str stopword: Break out of the chat loop if the user sends this message.
    :param bool echo: Whether to echo the user's input to stdout after they send a message (e.g. to save in interactive
        notebook outputs; default false)
    :param bool ai_first: Whether the user should send the first message (default) or the model should generate a
        completion before prompting the user for a message.
    :param int width: The maximum width of the printed outputs (default unlimited).
    :param bool show_function_args: Whether to print the arguments the model is calling functions with for each call
        (default false).
    :param bool show_function_returns: Whether to print the results of each function call (default false).
    :param bool verbose: Equivalent to setting ``echo``, ``show_function_args``, and ``show_function_returns`` to True.
    :param bool stream: Whether or not to print tokens as soon as they are generated by the model (default true).
    """
    # asyncio.get_running_loop() raises RuntimeError when no loop is running in this thread
    try:
        asyncio.get_running_loop()
        inside_running_loop = True
    except RuntimeError:
        inside_running_loop = False

    if inside_running_loop:
        try:
            # google colab comes with nest_asyncio pre-installed; patching the loop lets the plain
            # asyncio.run call below work even though a loop is already running
            import nest_asyncio

            nest_asyncio.apply()
        except ImportError:
            print(
                "WARNING: It looks like you're in an environment with a running asyncio loop (e.g. Google Colab).\n"
                "You should use `await chat_in_terminal_async(...)` instead or install `nest-asyncio`."
            )
            return

    try:
        asyncio.run(chat_in_terminal_async(kani, **kwargs))
    except KeyboardInterrupt:
        # ^C is the normal way to leave the chat; exit quietly
        return
# ===== format helpers =====
async def strip_stream(stream: StreamManager) -> AsyncIterable[str]:
    """
    Re-yield the tokens of *stream* with the leading and trailing whitespace of the whole stream removed.

    Whitespace between non-whitespace content is preserved: it is held back and emitted in front of the next token
    that has content, so whitespace at the very end of the stream is never yielded.
    """
    pending_ws = ""
    seen_content = False
    async for chunk in stream:
        if not seen_content:
            # still at the start of the stream: drop leading whitespace until real content shows up
            chunk = chunk.lstrip()
            seen_content = bool(chunk)

        body = chunk.rstrip()
        tail = chunk[len(body) :]

        if not body:
            # whitespace-only (or empty) chunk: keep it buffered in case more content follows
            pending_ws += tail
            continue

        yield pending_ws + body
        pending_ws = tail
[docs]
def print_width(msg: str, width: int = None, prefix: str = ""):
    """
    Print the given message such that the width of each line is less than *width*.

    If *prefix* and *width* are provided, indents each line after the first by the length of the prefix.

    The formatting itself is delegated to ``format_width``; this function only prints its result.

    .. code-block:: pycon

        >>> print_width("Hello world I am a potato", width=15, prefix="USER: ")
        USER: Hello
              world I
              am a
              potato
    """
    formatted = format_width(msg, width, prefix)
    print(formatted)
[docs]
async def print_stream(stream: StreamManager, width: int = None, prefix: str = ""):
    """
    Print tokens from a stream to the terminal, with the width of each line less than *width*.

    If *prefix* and *width* are provided, indents each line after the first by the length of the prefix.

    This is a helper function intended to be used with :meth:`.Kani.chat_round_stream` or
    :meth:`.Kani.full_round_stream`.
    """
    printed_anything = False
    async for chunk in format_stream(stream, width, prefix):
        # flush on every chunk so tokens appear as soon as they arrive
        print(chunk, end="", flush=True)
        printed_anything = True

    if not printed_anything:
        return
    # terminate the line we have been writing to
    print()
async def ainput(string: str) -> str:
    """Async counterpart of :func:`input`: show *string* as the prompt and return the line that was read.

    Any :class:`Exception` raised by ``input`` in the reader thread is re-raised to the awaiting caller.
    """
    # doing just .to_thread causes problems when we ^C, so the blocking read runs on our own daemon thread and hands
    # its outcome back through a concurrent Future
    pending = Future()
    pending.set_running_or_notify_cancel()

    def read_line():
        try:
            line = input(string)
        except Exception as exc:
            pending.set_exception(exc)
        else:
            pending.set_result(line)

    reader = Thread(target=read_line, daemon=True)
    reader.start()
    return await asyncio.wrap_future(pending)
# ==== CLI engine defs ====
def chat_openai(model_id: str):
    """CLI entrypoint for the ``openai`` provider: build an ``OpenAIEngine`` for *model_id*."""
    # imported lazily so the engine module is only loaded when this provider is actually selected
    from kani.engines.openai import OpenAIEngine

    engine = OpenAIEngine(model=model_id)
    return engine
def chat_anthropic(model_id: str):
    """CLI entrypoint for the ``anthropic`` provider: build an ``AnthropicEngine`` for *model_id*."""
    # imported lazily so the engine module is only loaded when this provider is actually selected
    from kani.engines.anthropic import AnthropicEngine

    engine = AnthropicEngine(model=model_id)
    return engine
def chat_google(model_id: str):
    """CLI entrypoint for the ``google`` provider: build a ``GoogleAIEngine`` for *model_id*."""
    # imported lazily so the engine module is only loaded when this provider is actually selected
    from kani.engines.google import GoogleAIEngine

    engine = GoogleAIEngine(model=model_id)
    return engine
def chat_huggingface(model_id: str):
    """CLI entrypoint for the ``huggingface`` provider.

    Builds a ``HuggingEngine`` for *model_id*. If ``model_specific.parser_for_hf_model`` returns a parser for the
    engine's model, the engine is returned wrapped in that parser (with ``show_reasoning_in_stream=True``);
    otherwise the bare engine is returned.
    """
    # imported lazily so the engine module is only loaded when this provider is actually selected
    from kani.engines.huggingface import HuggingEngine

    base_engine = HuggingEngine(model_id=model_id)

    # HF: wrap in model-specific parser if available
    parser_factory = model_specific.parser_for_hf_model(base_engine.model_id)
    if not parser_factory:
        return base_engine
    return parser_factory(base_engine, show_reasoning_in_stream=True)
# ---- CLI registry ----
class CLIProvider(NamedTuple):  # this is a NamedTuple so we don't need to import a special type in external pkgs
    """A chat provider selectable from the command line with an argument of the form ``<provider>:<model_id>``.

    Because this is a tuple, consumers unpack it positionally as ``(name, aliases, entrypoint)``.
    """

    # primary provider name, matched against the ``<provider>`` part of the CLI argument
    name: str
    # alternative names accepted in place of ``name``
    aliases: Sequence[str]
    # called with the ``<model_id>`` part of the CLI argument; returns the engine to use
    entrypoint: Callable[[str], BaseEngine]
# Each entry is CLIProvider(name, aliases, entrypoint); providers are looked up in the order listed here.
CLI_PROVIDERS = [
    CLIProvider("openai", ["oai"], chat_openai),
    CLIProvider("anthropic", ["ant", "claude"], chat_anthropic),
    CLIProvider("google", ["g", "gemini"], chat_google),
    CLIProvider("huggingface", ["hf"], chat_huggingface),
]
"""
Default CLI providers. Extension packages can define these in __init__.py as a 3-tuple of (name, aliases, factory) and
they will be automatically discovered.
"""
def get_cli_providers_including_extensions() -> Iterable[tuple[str, Sequence[str], Callable[[str], BaseEngine]]]:
    """
    Yield every known CLI provider as a ``(name, aliases, entrypoint)`` 3-tuple.

    The built-in ``CLI_PROVIDERS`` come first. Then each module found under the ``kani.ext`` namespace is imported
    and, if it defines a module-level ``CLI_PROVIDERS``, its entries are yielded as well. Import failures (of
    ``kani.ext`` itself or of any single extension module) are ignored.
    """
    yield from CLI_PROVIDERS

    try:
        import kani.ext as ext_namespace

        module_prefix = ext_namespace.__name__ + "."
        for _finder, module_name, _is_pkg in pkgutil.iter_modules(ext_namespace.__path__, module_prefix):
            try:
                extension = importlib.import_module(module_name)
                extension_providers = getattr(extension, "CLI_PROVIDERS", None)
                if extension_providers is None:
                    # this extension does not contribute any CLI providers
                    continue
                yield from extension_providers
            except ImportError:
                # a broken or partially-installed extension must not hide the other providers
                continue
    except ImportError:
        # no extension packages installed at all
        return
def fmt_cli_providers() -> str:
    """Return a bulleted, newline-separated listing of all CLI providers (including extensions) and their aliases.

    Providers without aliases are listed by name only.
    """
    bullet_lines = [
        f"* {name} (aliases: {', '.join(aliases)})" if aliases else f"* {name}"
        for name, aliases, _entrypoint in get_cli_providers_including_extensions()
    ]
    return "\n".join(bullet_lines)
def create_engine_from_cli_arg(arg: str):
    """Create an engine instance from a CLI arg <provider>:<model_id>

    Only the first ``:`` separates the provider from the model ID, so model IDs may themselves contain colons.

    :param arg: The CLI argument, in the form ``<provider>:<model_id>``.
    :returns: Whatever the matching provider's entrypoint returns when called with the model ID.
    :raises ValueError: If *arg* contains no ``:`` separator, or if the provider part matches neither the name nor
        an alias of any known provider.
    """
    provider, separator, model_id = arg.partition(":")
    if not separator:
        # without this check a missing colon surfaces as an opaque tuple-unpacking error
        raise ValueError(f"Invalid model argument: {arg!r}. Expected an argument in the form <provider>:<model_id>.")

    for name, aliases, entrypoint in get_cli_providers_including_extensions():
        if provider == name or provider in aliases:
            return entrypoint(model_id)
    raise ValueError(f"Invalid model provider: {provider!r}. Valid options:\n{fmt_cli_providers()}")