155 lines
4.6 KiB
Python
155 lines
4.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
测试统一使用OpenAI格式的一致性
|
||
"""
|
||
|
||
import json
|
||
from fastapi_app import convert_messages_to_openai_format
|
||
from openai_converter import OpenAIConverter
|
||
|
||
|
||
def test_unified_openai_format():
|
||
"""测试统一使用OpenAI格式"""
|
||
print("=== 测试统一OpenAI格式 ===")
|
||
|
||
# 测试数据
|
||
test_messages = [
|
||
{
|
||
"role": "assistant",
|
||
"reasoning_content": "这是思考过程",
|
||
"content": "这是回答内容",
|
||
"function_call": {
|
||
"name": "test_function",
|
||
"arguments": '{"param1": "value1"}'
|
||
}
|
||
},
|
||
{
|
||
"role": "function",
|
||
"name": "test_function",
|
||
"content": "工具执行结果"
|
||
}
|
||
]
|
||
|
||
print("原始qwen-agent格式:")
|
||
print(json.dumps(test_messages, indent=2, ensure_ascii=False))
|
||
|
||
# 使用新的统一函数
|
||
openai_messages = convert_messages_to_openai_format(test_messages)
|
||
|
||
print("\n转换后的OpenAI格式:")
|
||
print(json.dumps(openai_messages, indent=2, ensure_ascii=False))
|
||
|
||
# 验证格式
|
||
assert len(openai_messages) == 2
|
||
|
||
# 验证第一个消息
|
||
first_msg = openai_messages[0]
|
||
assert first_msg["role"] == "assistant"
|
||
assert "reasoning_content" in first_msg
|
||
assert "tool_calls" in first_msg
|
||
assert first_msg["tool_calls"][0]["function"]["name"] == "test_function"
|
||
|
||
# 验证第二个消息
|
||
second_msg = openai_messages[1]
|
||
assert second_msg["role"] == "tool"
|
||
assert "tool_call_id" in second_msg
|
||
|
||
print("✅ 统一OpenAI格式测试通过")
|
||
|
||
|
||
def test_stream_and_nonstream_consistency():
|
||
"""测试流式和非流式响应的一致性"""
|
||
print("\n=== 测试流式和非流式一致性 ===")
|
||
|
||
converter = OpenAIConverter()
|
||
|
||
# 模拟agent返回的消息
|
||
agent_messages = [
|
||
{
|
||
"role": "assistant",
|
||
"content": "让我帮你查询文件",
|
||
"reasoning_content": "用户想查询文件信息",
|
||
"function_call": {
|
||
"name": "list_files",
|
||
"arguments": '{"path": "/tmp"}'
|
||
}
|
||
}
|
||
]
|
||
|
||
# 非流式处理
|
||
openai_messages = converter.convert_messages_to_openai_format(agent_messages)
|
||
|
||
# 创建非流式响应
|
||
non_stream_response = converter.create_openai_response(
|
||
"test-model",
|
||
openai_messages,
|
||
{"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
|
||
)
|
||
|
||
print("非流式响应:")
|
||
print(json.dumps(non_stream_response, indent=2, ensure_ascii=False))
|
||
|
||
# 验证响应结构
|
||
assert "choices" in non_stream_response
|
||
assert "usage" in non_stream_response
|
||
assert non_stream_response["choices"][0]["message"]["role"] == "assistant"
|
||
assert "tool_calls" in non_stream_response["choices"][0]["message"]
|
||
|
||
print("✅ 流式和非流式一致性测试通过")
|
||
|
||
|
||
def test_no_more_string_conversion():
|
||
"""测试不再有字符串转换"""
|
||
print("\n=== 测试无字符串转换 ===")
|
||
|
||
# 测试直接返回OpenAI格式,不转换为字符串
|
||
test_messages = [
|
||
{
|
||
"role": "assistant",
|
||
"content": "这是一个简单的回答",
|
||
"function_call": {
|
||
"name": "simple_tool",
|
||
"arguments": '{"input": "test"}'
|
||
}
|
||
}
|
||
]
|
||
|
||
result = convert_messages_to_openai_format(test_messages)
|
||
|
||
# 验证返回的是结构化数据,不是字符串
|
||
assert isinstance(result, list)
|
||
assert isinstance(result[0], dict)
|
||
assert "tool_calls" in result[0]
|
||
assert "function" in result[0]["tool_calls"][0]
|
||
|
||
# 验证没有旧的字符串标记
|
||
msg_str = str(result)
|
||
assert "[THINK]" not in msg_str
|
||
assert "[TOOL_CALL]" not in msg_str
|
||
assert "[ANSWER]" not in msg_str
|
||
|
||
print("返回的数据类型:", type(result))
|
||
print("消息结构:", json.dumps(result[0], indent=2, ensure_ascii=False))
|
||
print("✅ 无字符串转换测试通过")
|
||
|
||
|
||
def main():
|
||
"""运行所有测试"""
|
||
print("测试统一使用OpenAI格式")
|
||
print("=" * 50)
|
||
|
||
test_unified_openai_format()
|
||
test_stream_and_nonstream_consistency()
|
||
test_no_more_string_conversion()
|
||
|
||
print("\n" + "=" * 50)
|
||
print("✅ 所有测试通过")
|
||
print("\n统一后的优势:")
|
||
print("1. 流式和非流式都使用OpenAI格式")
|
||
print("2. 不再有字符串转换的开销")
|
||
print("3. 完全兼容OpenAI API协议")
|
||
print("4. 代码更加简洁和一致")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |