Source code for kani.streaming

import asyncio
import contextlib
from collections.abc import AsyncIterable

from kani.engines.base import BaseCompletion, Completion
from kani.models import ChatMessage, ChatRole


class StreamManager:
    """
    This class is responsible for managing a stream returned by an engine. It should not be constructed manually.

    To consume tokens from a stream, use this class as so::

        # CHAT ROUND:
        stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?")
        async for token in stream:
            print(token, end="")
        msg = await stream.message()

        # FULL ROUND:
        async for stream in ai.full_round_stream("What is the airspeed velocity of an unladen swallow?"):
            async for token in stream:
                print(token, end="")
            msg = await stream.message()

    After a stream finishes, its contents will be available as a :class:`.ChatMessage`. You can retrieve the final
    message or :class:`.BaseCompletion` with::

        msg = await stream.message()
        completion = await stream.completion()

    The final :class:`.ChatMessage` may contain non-yielded tokens (e.g. a request for a function call). If the final
    message or completion is requested before the stream is iterated over, the stream manager will consume the entire
    stream.

    If the underlying stream (or the ``after`` callback) raises an exception, it propagates to whoever is iterating
    over the stream, and is also re-raised by any concurrent or later call to :meth:`completion` / :meth:`message`.

    .. tip::
        For compatibility and ease of refactoring, awaiting the stream itself will also return the message, i.e.:

        .. code-block:: python

            msg = await ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?")

        (note the ``await`` that is not present in the above examples).
    """

    def __init__(
        self,
        stream_iter: AsyncIterable[str | BaseCompletion],
        role: ChatRole,
        *,
        after=None,
        lock: asyncio.Lock | None = None,
    ):
        """
        :param stream_iter: The async iterable that generates elements of the stream.
        :param role: The role of the message that will be returned eventually.
        :param after: An async function to call (and await) with the generated completion as its argument after the
            stream is fully consumed.
        :param lock: A lock to hold for the duration of the stream run.
        """
        self.role = role
        """The role of the message that this stream will return."""

        # private
        self._stream_iter = stream_iter
        self._after = after
        # nullcontext also works with ``async with``, so the no-lock case needs no special handling later
        self._lock = lock if lock is not None else contextlib.nullcontext()

        # results
        self._completion = None  # the final completion result
        self._exception = None  # the exception raised while running the stream, if any
        self._awaited = False  # whether or not this stream has already been consumed
        self._finished = asyncio.Event()  # whether or not the stream has finished (successfully or not)

    # ==== stream components ====
    async def _stream_impl_outer(self):
        # simple wrapper to lock the lock too, but *async*
        async with self._lock:
            async for elem in self._stream_impl():
                yield elem

    async def _stream_impl(self):
        """
        Wrap the underlying stream iterable to handle mixed yield types and build completion when the stream finishes
        without setting the completion.

        ``self._finished`` is set whenever this generator exits - normally, by raising, or by being closed - so that
        anything waiting on :meth:`completion` is woken up instead of waiting forever.
        """
        yielded_tokens = []

        try:
            # for each token or completion yielded by the engine,
            async for elem in self._stream_iter:
                if self._completion is not None:
                    raise RuntimeError(
                        "Expected `BaseCompletion` to be final yield of stream iterable but got another value after!"
                    )

                # re-yield if str
                if isinstance(elem, str):
                    yield elem
                    yielded_tokens.append(elem)
                # save if completion
                elif isinstance(elem, BaseCompletion):
                    self._completion = elem
                # panic otherwise
                else:
                    raise TypeError(
                        "Expected yielded value from stream iterable to be `str` or `BaseCompletion` but got"
                        f" {type(elem)!r}!"
                    )

            # if the stream is complete but we did not get a completion, we'll construct one here as the
            # concatenation of all the yielded tokens
            if self._completion is None:
                content = "".join(yielded_tokens)
                self._completion = Completion(message=ChatMessage(role=self.role, content=content.strip()))

            # run the callback, if any
            if self._after is not None:
                await self._after(self._completion)
        except Exception as e:
            # remember the failure so that waiters in completion() re-raise it; the iterating caller still sees it
            self._exception = e
            raise
        finally:
            # allow anything waiting on the stream to finish to progress, even if the stream failed or was closed
            self._finished.set()

    def __aiter__(self) -> AsyncIterable[str]:
        """Iterate over tokens yielded from the engine."""
        # enforce that it can only be iterated over once
        if self._awaited:
            raise RuntimeError(
                "This stream has already been consumed. If you are consuming both the stream and the final Completion,"
                " make sure you iterate over the stream first."
            )
        self._awaited = True
        # delegate to a wrapper for an async context
        return self._stream_impl_outer()

    # ==== final result getters ====
    def __await__(self):
        """Awaiting the StreamManager is equivalent to awaiting :meth:`message`."""
        return self.message().__await__()

    async def completion(self) -> BaseCompletion:
        """
        Get the final :class:`.BaseCompletion` generated by the model.

        :raises Exception: The exception raised by the underlying stream or the ``after`` callback, if any.
        :raises RuntimeError: If the stream was closed or cancelled before it produced a completion.
        """
        # if we are getting the completion but no one has consumed our stream yet, just dummy do it so we build
        # the completion
        if not self._awaited:
            async for _ in self:
                pass

        # otherwise, wait for the stream to be complete then return the saved completion
        await self._finished.wait()
        if self._exception is not None:
            # same semantics as asyncio.Future.result(): every waiter sees the original failure
            raise self._exception
        if self._completion is None:
            # the stream ended without error and without a completion (e.g. it was closed or cancelled early)
            raise RuntimeError(
                "This stream was closed before it finished, so no completion is available. If you are consuming both"
                " the stream and the final Completion, make sure you iterate over the entire stream first."
            )
        return self._completion

    async def message(self) -> ChatMessage:
        """Get the final :class:`.ChatMessage` generated by the model."""
        completion = await self.completion()
        return completion.message
class DummyStream(StreamManager):
    """
    Function calling helper: a stream over a message we already have.

    Iterating replays the message's text as a single token (only when the message has content), then ends with a
    :class:`.Completion` wrapping the message. :meth:`message` returns the wrapped message directly.
    """

    def __init__(self, message: ChatMessage):
        super().__init__(self._replay(message), role=message.role)
        self._message = message

    @staticmethod
    async def _replay(msg: ChatMessage):
        """Yield the text of *msg* (if it has content), followed by a :class:`.Completion` wrapping *msg*."""
        if msg.content is None:
            yield Completion(msg)
            return
        yield msg.text
        yield Completion(msg)

    async def message(self) -> ChatMessage:
        """Return the wrapped message without needing to consume the stream."""
        return self._message