Source code for kani.model_specific.gpt_oss
import logging
import re
from kani.engines.huggingface.chat_template_pipeline import ChatTemplatePromptPipeline
from kani.models import FunctionCall, ToolCall
from kani.parts import ReasoningPart
from .base import BaseParser
from ..engines.base import BaseCompletion
# Matches exactly one Harmony special token such as ``<|channel|>`` and captures its name as ``type``.
SPECIAL_TOKEN_REGEX = re.compile(r"<\|(?P<type>\w+)\|>")
# Same token shape, but wrapped in a capturing group so that ``re.split`` keeps the tokens in its output.
SPECIAL_TOKEN_REGEX_2 = re.compile(r"(<\|\w+\|>)")
# Captures the recipient named in a message header, e.g. ``functions.get_weather`` in ``to=functions.get_weather``.
# (plain raw string: the pattern has no placeholders, so no f-string prefix is needed)
TO_REGEX = re.compile(r"to=([^\s<]+)")
# Captures the channel name that directly follows a ``<|channel|>`` token.
CHANNEL_REGEX = re.compile(r"<\|channel\|>(\w+)")
# Matches one message: an optional ``<|start|>role``, a lazily matched header, ``<|message|>``, and then the content
# up to the first ``<|end|>``, ``<|return|>``, ``<|call|>``, or ``$``. DOTALL lets header and content span newlines.
ASST_MSG_REGEX = re.compile(
    r"(<\|start\|>(?P<role>\w+))?"
    r"(?P<header>.*?)"
    r"<\|message\|>(?P<content>.*?)"
    r"(<\|end\|>|<\|return\|>|<\|call\|>|$)",
    re.DOTALL,
)

log = logging.getLogger(__name__)
# ===== PROMPT PIPELINE =====
def build_prompt_pipeline(tokenizer, **kwargs):
    """
    Build the chat-template prompt pipeline for the given tokenizer.

    Unless the caller supplies a truthy ``chat_template_reasoning_content_key``, that option is set to
    ``"thinking"`` before all keyword arguments are forwarded to ``ChatTemplatePromptPipeline``.
    """
    reasoning_key_option = "chat_template_reasoning_content_key"
    if not kwargs.get(reasoning_key_option):
        # missing, None, or otherwise falsy: fall back to the default key
        kwargs[reasoning_key_option] = "thinking"
    return ChatTemplatePromptPipeline(tokenizer, **kwargs)
# ===== OUTPUT PARSER =====
[docs]
class GPTOSSParser(BaseParser):
    r"""
    Parser for GPT-OSS model output, handling both reasoning segments and tool calls.

    Messages on the ``analysis`` channel are returned as :class:`.ReasoningPart`\ s, and messages addressed to
    ``functions.<name>`` are returned as tool calls.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base parser, with both tool-call delimiter tokens set to ``None``."""
        super().__init__(*args, tool_call_start_token=None, tool_call_end_token=None, **kwargs)

    # streaming uses a state machine driven by special tokens; full completions are parsed with a regex

    def parse_completion(self, completion: BaseCompletion) -> BaseCompletion:
        """
        Split the completion's text into its individual messages and rewrite the completion in place.

        For every message found by ``ASST_MSG_REGEX``: a recipient starting with ``functions.`` yields a tool call
        whose arguments are the raw message content, the ``analysis`` channel yields a :class:`.ReasoningPart`, and
        anything else is kept as plain text. The message's ``content`` and ``tool_calls`` are replaced with the
        collected results and the same completion object is returned.
        """
        log.debug(f"PARSING MSG: {completion.message.text}")
        message_parts = []
        tool_calls = []
        for segment in ASST_MSG_REGEX.finditer(completion.message.text):
            log.debug(f"PART: {segment[0]}")
            header = segment["header"]
            body = segment["content"]

            # pull the channel name and the recipient (if any) out of the message header
            channel_match = CHANNEL_REGEX.search(header)
            channel = channel_match[1] if channel_match else None
            recipient_match = TO_REGEX.search(header)
            recipient = recipient_match[1] if recipient_match else None

            # messages addressed to a function are tool calls, not content
            if recipient and recipient.startswith("functions."):
                call = FunctionCall(name=recipient.removeprefix("functions."), arguments=body)
                tool_calls.append(ToolCall.from_function_call(call))
                continue

            if channel == "analysis":
                message_parts.append(ReasoningPart(content=body))
                continue

            message_parts.append(body)

        completion.message.content = message_parts
        completion.message.tool_calls = tool_calls
        return completion

    async def stream(self, messages, functions=None, **hyperparams):
        """
        Stream from the wrapped engine, yielding only the text the stream state marks as visible to the user.

        Each string chunk is split on special tokens and fed piece by piece through a ``_GPTOSSStreamState``;
        non-string elements are passed through unchanged.
        """
        tracker = _GPTOSSStreamState(
            show_reasoning=self.show_reasoning_in_stream, reasoning_color=self.reasoning_in_stream_color
        )
        async for chunk in super().stream(messages, functions, **hyperparams):
            if not isinstance(chunk, str):
                yield chunk  # probably the inner completion
                continue
            # split keeps the special tokens as their own pieces; drop the empty strings it produces
            for piece in filter(None, SPECIAL_TOKEN_REGEX_2.split(chunk)):
                if visible := tracker.feed(piece):
                    yield visible
class _GPTOSSStreamState:
    """
    State machine tracking the position inside a streamed GPT-OSS response.

    Pieces of the stream are passed to :meth:`feed` one at a time; each piece is either a single special token
    (``<|...|>``) or the plain text between two special tokens. The machine remembers the last special token, the
    current channel and the current recipient, and uses them to decide which text is shown to the user.

    See https://cookbook.openai.com/articles/openai-harmony#special-tokens
    """

    # Special tokens that terminate a message. The header fields (channel, recipient) are reset on each of
    # them so that a finished message can never leak its header into the next one.
    _MESSAGE_TERMINATORS = ("end", "return", "call")

    def __init__(self, show_reasoning, reasoning_color):
        """
        :param show_reasoning: whether text on the ``analysis`` channel counts as visible to the user
        :param reasoning_color: if truthy, visible ``analysis`` text is wrapped in an ANSI color escape
        """
        self.show_reasoning = show_reasoning
        self.reasoning_color = reasoning_color
        # name of the last seen special token, e.g. "start", "channel", "constrain", "message", "end";
        # None until the first special token arrives
        self.state = None
        # channel and recipient of the message currently being streamed (None if unknown/absent)
        self.channel = None
        self.to = None
        # text seen since the last special token that was not yielded to the user
        self.buf = []

    def feed(self, part: str):
        """
        Consume one piece of the stream.

        :param part: a single special token, or a run of plain text
        :returns: the text to show to the user, or ``None`` if this piece produces no visible output
        """
        # a special token never produces output; it only moves the state machine
        if match := SPECIAL_TOKEN_REGEX.fullmatch(part):
            self.transition_states(match["type"])
            return None
        # message content the user may see: yield it right away
        if self.is_visible_to_user():
            if self.channel == "analysis" and self.reasoning_color:
                # reasoning_color only acts as an on/off switch here; the escape sequence itself is fixed
                return f"\033[0;37m{part}\033[0m"
            return part
        # default: header text or hidden content, keep it for the next state transition
        self.buf.append(part)
        return None

    def transition_states(self, new_state: str) -> None:
        """
        Finish the current state using the buffered text, then switch to *new_state*.

        :param new_state: the name of the special token that was just seen (without the ``<|`` and ``|>``)
        """
        buf_str = "".join(self.buf)
        # a recipient (to=...) may appear before the first token, after <|start|>, or after <|channel|>
        if self.state in (None, "start", "channel") and (match := TO_REGEX.search(buf_str)):
            self.to = match[1]
        # leaving the channel state: the channel name is the text before the first space
        if self.state == "channel":
            self.channel = buf_str.split(" ", 1)[0]
        # <|end|>, <|return|> and <|call|> all close a message, so each of them drops the header fields;
        # otherwise the recipient of a tool call would keep hiding the content of every later message
        if new_state in self._MESSAGE_TERMINATORS:
            self.channel = None
            self.to = None
        self.buf.clear()
        self.state = new_state
        log.debug("STREAM NEW STATE: %r", self)

    def is_visible_to_user(self) -> bool:
        """
        Return whether text received in the current state should be shown to the user.

        Text is visible if and only if the state is ``"message"``, the message has no recipient, and the channel
        is ``final`` or ``commentary`` (or ``analysis`` as well when ``show_reasoning`` is set).
        """
        if self.state != "message" or self.to is not None:
            return False
        if self.show_reasoning:
            return self.channel in ("final", "commentary", "analysis")
        return self.channel in ("final", "commentary")

    def __repr__(self):
        return f"<_GPTOSSStreamState {self.state=} {self.channel=} {self.to=} {self.buf=}>"