执行摘要
- 一句话:修复 MistralToolParser 流式 tool call 数组越界
- 推荐动作:值得合入并关注后续是否有类似问题的回归报告。合并后建议有人 cherry-pick 到相关 release 分支。
功能与动机
修复 Issue #33916:在使用 --tool-call-parser=mistral 进行流式 tool call 时,出现 IndexError: list index out of range。根本原因是 MistralToolParser 未正确填充 streamed_args_for_tool 和 prev_tool_call_arr,使得 serving 层在判断 finish_reason 时试图访问不存在的索引。
实现拆解
- 移除硬编码 hack:在
_extract_tool_calls_streaming 方法中,删除原先无条件设置 self.prev_tool_call_arr = [{"arguments": {}}] 的代码段。该 hack 无法正确对应多个 tool call 的场景。
- 在 tool call 开始时追加列表项:在
_generate_delta_tool_call 方法中检测到新 tool call 开始时(PARSING_NAME 状态),追加 self.streamed_args_for_tool.append("") 和 self.prev_tool_call_arr.append({}),确保列表长度与 tool call 数量一致。
- 在解析过程中动态更新:在
_generate_delta_tool_call 的 PARSING_NAME 分支中,当工具名解析完成时,写入 prev_tool_call_arr[current_tool_id]["name"];在 PARSING_ARGUMENTS 分支中,累计 delta_arguments 到 streamed_args_for_tool[current_tool_id],并同步更新 prev_tool_call_arr[current_tool_id]["arguments"]。
- 同步 pre-v11 tokenizer 路径:在
_extract_tool_calls_streaming_pre_v11_tokenizer 方法中,于新 tool call 创建时追加列表项,在名称解析完成后更新 name,并调用新增的 _track_streamed_args_pre_v11 方法(与上述逻辑类似)来处理参数累积。
- 更新测试断言:在
test_tool_parsers/test_mistral_tool_parser.py 的通用测试辅助函数 _test_extract_tool_calls_streaming 中,新增断言验证 streamed_args_for_tool 和 prev_tool_call_arr 长度与元素一致性,并模拟 serving 层的 unstreamed-args 检查逻辑。同时,在 test_extract_tool_calls_streaming_v11_no_tools 中添加断言确保无 tool call 时列表为空。
关键文件:
vllm/tool_parsers/mistral_tool_parser.py(模块 工具解析器;类别 source;类型 core-logic;符号 _generate_delta_tool_call, _extract_tool_calls_streaming, _extract_tool_calls_streaming_pre_v11_tokenizer, _track_streamed_args_pre_v11): 核心修复文件,修改了流式 tool call 解析的列表追加和更新逻辑,移除了硬编码 hack。
tests/tool_parsers/test_mistral_tool_parser.py(模块 工具解析器;类别 test;类型 test-coverage): 新增内部状态一致性断言,覆盖有 tool call 和无 tool call 的验证场景。
关键符号:_generate_delta_tool_call, _extract_tool_calls_streaming, _extract_tool_calls_streaming_pre_v11_tokenizer, _track_streamed_args_pre_v11
关键源码片段
vllm/tool_parsers/mistral_tool_parser.py
核心修复文件,修改了流式 tool call 解析的列表追加和更新逻辑,移除了硬编码 hack。
# vllm/tool_parsers/mistral_tool_parser.py # 关键变更片段
def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
# ... 前略 ...
if self.streaming_state not in [
StreamingState.PARSING_NAME,
StreamingState.PARSING_ARGUMENTS,
] and delta_text.startswith(self.bot_token):
self.current_tool_id += 1
# 新增:在每个新 tool call 开始时,追加空记录
self.streamed_args_for_tool.append("")
self.prev_tool_call_arr.append({})
self.streaming_state = StreamingState.PARSING_NAME
delta_text = delta_text.replace(self.bot_token, "", 1)
if self.streaming_state == StreamingState.PARSING_NAME:
# ... 省略 ...
if "{" in delta_text:
# 工具名解析完成,更新 prev_tool_call_arr 中的 name
self.prev_tool_call_arr[self.current_tool_id]["name"] = (
self.current_tool_name
)
self.streaming_state = StreamingState.PARSING_ARGUMENTS
if self.streaming_state == StreamingState.PARSING_ARGUMENTS:
# ... 省略 ...
# 新增:累积参数并同步更新两条列表
self.streamed_args_for_tool[self.current_tool_id] += delta_arguments
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = (
self.streamed_args_for_tool[self.current_tool_id]
)
# ... 返回 delta 等后续逻辑 ...
return []
# 对应 pre-v11 tokenizer 路径也有类似修改,核心模式一致。
tests/tool_parsers/test_mistral_tool_parser.py
新增内部状态一致性断言,覆盖有 tool call 和无 tool call 的验证场景。
# tests/tool_parsers/test_mistral_tool_parser.py # 新增断言片段
# 在 _test_extract_tool_calls_streaming 函数结尾新增:
if expected_tool_calls:
# 验证内部状态列表长度与期望 tool call 数量一致
assert len(tool_parser.streamed_args_for_tool) == len(expected_tool_calls)
assert len(tool_parser.prev_tool_call_arr) == len(expected_tool_calls)
for i in range(len(expected_tool_calls)):
# prev_tool_call_arr 中的 arguments 应等于 streamed_args_for_tool[i]
assert (
tool_parser.prev_tool_call_arr[i]["arguments"]
== tool_parser.streamed_args_for_tool[i]
)
# streamed_args_for_tool[i] 应等于实际累积的参数字符串
assert tool_parser.streamed_args_for_tool[i] == function_args_strs[i]
# prev_tool_call_arr 中的 name 应等于期望的工具名
assert (
tool_parser.prev_tool_call_arr[i]["name"]
== expected_tool_calls[i].function.name
)
# 模拟 serving 层的 unstreamed-args 检查(剩余 JSON 片段应为空)
index = len(tool_parser.prev_tool_call_arr) - 1
args = tool_parser.prev_tool_call_arr[index].get("arguments", {})
expected_call = (
args if isinstance(args, str) else json.dumps(args, ensure_ascii=False)
)
actual_call = tool_parser.streamed_args_for_tool[index]
remaining_call = expected_call.replace(actual_call, "", 1)
assert remaining_call == ""
else:
# 无 tool call 时两个列表应均为空
assert len(tool_parser.streamed_args_for_tool) == 0
assert len(tool_parser.prev_tool_call_arr) == 0
评论区精华
无人工审核评论。仅 bot(gemini-code-assist)总结变更内容,sfeng33 直接 approve。
风险与影响
关联脉络
参与讨论