# Source code for kani.prompts.base
import abc
from typing import Iterable
from kani.ai_function import AIFunction
from kani.models import ChatRole
from kani.prompts.types import PipelineMsgT, PredicateFilterT, RoleFilterT
class PipelineStep(abc.ABC):
    """
    Abstract base for a single transformation step in a :class:`.PromptPipeline`.

    Subclassing this and adding the step to a pipeline by hand is possible, but it is rarely required --
    :meth:`.PromptPipeline.apply` is usually the better tool.
    """

    def execute(self, msgs: list[PipelineMsgT], functions: list[AIFunction]):
        """Run this step's transformation against the pipeline's messages and functions.

        The base class only defines the hook; concrete steps must override it.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def explain(self) -> str:
        """Describe in prose what this step does.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def explain_example_kwargs(self) -> dict[str, bool]:
        """Keyword arguments to forward to ``examples.build_conversation`` so relevant examples get included.

        The base step asks for nothing special, so the result is an empty dict.
        """
        return dict()

    def __repr__(self):
        # render every instance attribute as ``name=repr(value)``, in insertion order
        fields = ", ".join(
            "{}={!r}".format(name, value) for name, value in self.__dict__.items()
        )
        cls_name = type(self).__name__
        return f"{cls_name}({fields})"
class FilterMixin:
    """Reusable role/predicate filtering for pipeline steps.

    A message passes when its role is accepted by ``role`` and, if a ``predicate`` is set, ``predicate(msg)`` is
    truthy.
    """

    def __init__(self, role: RoleFilterT = None, predicate: PredicateFilterT = None):
        # ``role`` may be a single ChatRole, a collection of roles, or a falsy value meaning "any role"
        self.role = role
        # optional callable; when set, it must return a truthy value for a message to pass
        self.predicate = predicate

    def filtered(self, msgs: list[PipelineMsgT]) -> Iterable[PipelineMsgT]:
        """Lazily yield, in their original order, the messages of *msgs* that pass :meth:`matches_filter`."""
        yield from (candidate for candidate in msgs if self.matches_filter(candidate))

    def matches_filter(self, msg: PipelineMsgT) -> bool:
        """Return True if *msg* passes the role filter and, when one is set, the predicate."""
        if self.matches_role(msg.role):
            # the predicate is only consulted once the role check has already passed
            check = self.predicate
            return check is None or bool(check(msg))
        return False

    def matches_role(self, role: ChatRole) -> bool:
        """Return True if *role* is accepted by this filter's ``role`` setting.

        Only the role is examined here; the predicate is ignored.
        """
        accepted = self.role
        if not isinstance(accepted, ChatRole):
            # a collection of roles means a membership test; a falsy value (None / empty) accepts every role
            return (role in accepted) if accepted else True
        return role == accepted

    # ---- explain helpers ----
    def explain_note(self, join_sep="or", plural=True) -> str:
        """Build a short phrase naming the messages this step applies to.

        Examples: "messages", or "system messages that match the given predicate".

        :param join_sep: Conjunction used between role names when the filter lists several roles.
        :param plural: Whether to use the plural noun ("messages") or the singular one ("message").
        """
        noun = "messages" if plural else "message"
        accepted = self.role
        if isinstance(accepted, ChatRole):
            noun = f"{accepted.value} {noun}"
        elif accepted:
            role_names = natural_join([r.value for r in accepted], join_sep)
            noun = f"{role_names} {noun}"
        if self.predicate is None:
            return noun
        verb = "match" if plural else "matches"
        return f"{noun} that {verb} the given predicate"

    def explain_example_kwargs(self) -> dict[str, bool]:
        """Ask for function-call examples whenever this filter's role setting accepts FUNCTION messages."""
        wants_functions = self.matches_role(ChatRole.FUNCTION)
        return {"function_call": True} if wants_functions else {}
def natural_join(elems: Iterable[str], sep: str) -> str:
    """Join strings the way they would be listed in an English sentence.

    Two items are joined with just *sep*; three or more get a serial (Oxford) comma before the final *sep*::

        natural_join(["a"], "or")            -> "a"
        natural_join(["a", "b"], "or")       -> "a or b"
        natural_join(["a", "b", "c"], "or")  -> "a, b, or c"

    :param elems: The strings to join. Any iterable is accepted (list, tuple, set, generator); it is consumed
        once, in iteration order.
    :param sep: The conjunction placed before the last element (e.g. "or", "and").
    :returns: The joined string, or an empty string if *elems* is empty.
    """
    # materialize once so non-sequence iterables (sets, generators) support len() and slicing below
    items = list(elems)
    conj = f" {sep} "
    if len(items) < 3:
        return conj.join(items)
    return ", ".join(items[:-1]) + f",{conj}{items[-1]}"