feat: handle gemini tool call tags
This commit is contained in:
@@ -6,3 +6,4 @@
|
|||||||
{"type": "connection", "event": "disconnect", "ts": "2026-04-04T01:15:12.826042+00:00", "profile": "default"}
|
{"type": "connection", "event": "disconnect", "ts": "2026-04-04T01:15:12.826042+00:00", "profile": "default"}
|
||||||
{"type": "connection", "event": "connect", "ts": "2026-04-04T01:15:30.842533+00:00", "profile": "default"}
|
{"type": "connection", "event": "connect", "ts": "2026-04-04T01:15:30.842533+00:00", "profile": "default"}
|
||||||
{"type": "connection", "event": "hello", "details": {"version": "1.0"}, "ts": "2026-04-04T01:15:30.845025+00:00", "profile": "default"}
|
{"type": "connection", "event": "hello", "details": {"version": "1.0"}, "ts": "2026-04-04T01:15:30.845025+00:00", "profile": "default"}
|
||||||
|
{"type": "tool_call", "tool": "browser_stop", "params": {"profile": "gcu-browser-worker:3"}, "result": {"ok": true, "status": "not_running", "profile": "gcu-browser-worker:3"}, "ok": true, "duration_ms": 0.01, "ts": "2026-04-04T01:29:04.294954+00:00", "profile": "default"}
|
||||||
|
|||||||
@@ -458,6 +458,57 @@ def _is_stream_transient_error(exc: BaseException) -> bool:
|
|||||||
return isinstance(exc, transient_types)
|
return isinstance(exc, transient_types)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_tool_calls(
|
||||||
|
text: str,
|
||||||
|
) -> tuple[list["ToolCallEvent"], str]:
|
||||||
|
"""Extract hallucinated tool calls from ``<tool_code>`` blocks in LLM text.
|
||||||
|
|
||||||
|
Some models (notably Gemini) emit tool invocations as text instead of using
|
||||||
|
the structured function-calling API. This function parses those blocks and
|
||||||
|
returns ``(tool_call_events, cleaned_text)`` where *cleaned_text* has the
|
||||||
|
``<tool_code>`` blocks removed.
|
||||||
|
|
||||||
|
Expected format::
|
||||||
|
|
||||||
|
<tool_code>
|
||||||
|
{
|
||||||
|
"tool_name": { ...args }
|
||||||
|
}
|
||||||
|
</tool_code>
|
||||||
|
"""
|
||||||
|
from framework.llm.stream_events import ToolCallEvent
|
||||||
|
|
||||||
|
pattern = re.compile(r"<tool_code>\s*(.*?)\s*</tool_code>", re.DOTALL)
|
||||||
|
events: list[ToolCallEvent] = []
|
||||||
|
cleaned = text
|
||||||
|
|
||||||
|
for match in pattern.finditer(text):
|
||||||
|
raw = match.group(1).strip()
|
||||||
|
try:
|
||||||
|
payload = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("[_extract_text_tool_calls] failed to parse JSON: %s", raw[:200])
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for tool_name, tool_args in payload.items():
|
||||||
|
call_id = f"synth_{hashlib.md5(f'{tool_name}:{json.dumps(tool_args, sort_keys=True)}'.encode()).hexdigest()[:12]}"
|
||||||
|
events.append(
|
||||||
|
ToolCallEvent(
|
||||||
|
tool_use_id=call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_args if isinstance(tool_args, dict) else {},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if events:
|
||||||
|
cleaned = pattern.sub("", text).strip()
|
||||||
|
|
||||||
|
return events, cleaned
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
class LiteLLMProvider(LLMProvider):
|
||||||
"""
|
"""
|
||||||
LiteLLM-based LLM provider for multi-provider support.
|
LiteLLM-based LLM provider for multi-provider support.
|
||||||
@@ -1918,6 +1969,35 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
f"(last_role={last_role}). Returning empty result."
|
f"(last_role={last_role}). Returning empty result."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Gemini sometimes outputs tool calls as text in
|
||||||
|
# <tool_code>{"name": {...args}}</tool_code> blocks
|
||||||
|
# instead of using the function-calling API. Extract
|
||||||
|
# these as real ToolCallEvents and strip them from the
|
||||||
|
# text so the rest of the system treats them normally.
|
||||||
|
if accumulated_text and "<tool_code>" in accumulated_text:
|
||||||
|
extracted, cleaned = _extract_text_tool_calls(accumulated_text)
|
||||||
|
if extracted:
|
||||||
|
logger.info(
|
||||||
|
"[stream] extracted %d hallucinated tool call(s) from text",
|
||||||
|
len(extracted),
|
||||||
|
)
|
||||||
|
accumulated_text = cleaned
|
||||||
|
# Emit a corrected TextDeltaEvent so the caller's
|
||||||
|
# accumulated_text is overwritten with the cleaned text.
|
||||||
|
yield TextDeltaEvent(content="", snapshot=cleaned)
|
||||||
|
# Insert synthetic ToolCallEvents before FinishEvent.
|
||||||
|
finish_idx = next(
|
||||||
|
(i for i, ev in enumerate(tail_events) if isinstance(ev, FinishEvent)),
|
||||||
|
len(tail_events),
|
||||||
|
)
|
||||||
|
for tc_ev in reversed(extracted):
|
||||||
|
tail_events.insert(finish_idx, tc_ev)
|
||||||
|
# Update TextEndEvent if present.
|
||||||
|
for _i, _ev in enumerate(tail_events):
|
||||||
|
if isinstance(_ev, TextEndEvent):
|
||||||
|
tail_events[_i] = TextEndEvent(full_text=cleaned)
|
||||||
|
break
|
||||||
|
|
||||||
# Success (or empty after exhausted retries) — flush events.
|
# Success (or empty after exhausted retries) — flush events.
|
||||||
for event in tail_events:
|
for event in tail_events:
|
||||||
yield event
|
yield event
|
||||||
|
|||||||
Reference in New Issue
Block a user