167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
PrePrompt Hook - 用户上下文加载器
|
||
|
||
在 system_prompt 加载时执行,通过 MCP 服务查询用户相关信息并注入到 prompt 中。
|
||
"""
|
||
import json
|
||
import os
|
||
import sys
|
||
|
||
import anyio
|
||
from mcp.client.session import ClientSession
|
||
from mcp.client.streamable_http import streamablehttp_client
|
||
|
||
# MCP 服务配置
|
||
MCP_BASE_URL = "http://prd-mcp.gbase.ai/mcp/iot/sse"
|
||
MCP_TIMEOUT = 30
|
||
|
||
|
||
async def get_employee_location(user_identifier: str, bot_id: str) -> dict | None:
|
||
"""
|
||
通过 MCP 服务查询员工位置信息
|
||
|
||
Args:
|
||
user_identifier: 用户标识(员工姓名或邮箱)
|
||
bot_id: Bot ID(当前未使用,保留以备将来扩展)
|
||
|
||
Returns:
|
||
员工位置信息字典,如果查询失败返回 None
|
||
"""
|
||
try:
|
||
async with streamablehttp_client(
|
||
url=MCP_BASE_URL,
|
||
timeout=MCP_TIMEOUT,
|
||
terminate_on_close=False,
|
||
) as (read_stream, write_stream, _):
|
||
async with ClientSession(read_stream, write_stream) as session:
|
||
# 初始化 MCP 会话
|
||
await session.initialize()
|
||
|
||
# 调用 find_employee_location 工具
|
||
result = await session.call_tool(
|
||
name="find_employee_location",
|
||
arguments={"name": user_identifier}
|
||
)
|
||
|
||
# 解析返回结果
|
||
if result.content:
|
||
for item in result.content:
|
||
if hasattr(item, 'text') and item.text:
|
||
try:
|
||
return json.loads(item.text)
|
||
except json.JSONDecodeError:
|
||
# 如果不是 JSON,直接返回文本
|
||
return {"data": item.text}
|
||
|
||
return None
|
||
|
||
except Exception:
|
||
# 发生错误时返回 None,不影响主流程
|
||
return None
|
||
|
||
|
||
def format_location_context(location_data: dict | None, user_identifier: str, bot_id: str) -> str:
|
||
"""
|
||
格式化位置信息为 Markdown 上下文(日语)
|
||
|
||
Args:
|
||
location_data: 从 MCP 查询返回的位置数据
|
||
user_identifier: 用户标识(不使用)
|
||
bot_id: Bot ID(不使用)
|
||
|
||
Returns:
|
||
格式化后的 Markdown 字符串,出错或无数据时返回空字符串
|
||
"""
|
||
# 出错或无数据时返回空字符串
|
||
if not location_data:
|
||
return ""
|
||
|
||
matched_count = location_data.get('matched_count', 0)
|
||
results = location_data.get('results', [])
|
||
|
||
# 没有匹配数据时返回空字符串
|
||
if matched_count == 0:
|
||
return ""
|
||
|
||
lines = []
|
||
|
||
# 添加说明:这是当前用户(USER_IDENTIFIER)的信息
|
||
lines.append(f"**Current User ({user_identifier}) Information**:")
|
||
lines.append("")
|
||
|
||
for idx, employee in enumerate(results, 1):
|
||
name = employee.get('name', 'Unknown')
|
||
sensor_id = employee.get('sensor_id', '')
|
||
confidence = employee.get('confidence', 0)
|
||
|
||
lines.append(f"- Name: {name}")
|
||
lines.append(f"- Sensor ID: {sensor_id}")
|
||
|
||
location_status = employee.get('location_status', '')
|
||
|
||
if location_status == 'success':
|
||
coordinates = employee.get('coordinates', {})
|
||
location = employee.get('location', {})
|
||
|
||
lines.append(f"- Location Status: Success")
|
||
lines.append(f"- Floor: {coordinates.get('floor', 'N/A')}")
|
||
lines.append(f"- Coordinates: ({coordinates.get('x', 0):.2f}, {coordinates.get('y', 0):.2f}, {coordinates.get('z', 0):.2f})")
|
||
|
||
if location:
|
||
building = location.get('building', '')
|
||
area = location.get('area', '')
|
||
room = location.get('room', '')
|
||
lines.append(f"- Detailed Location: {building} / {area} / {room}")
|
||
|
||
measurement_time = employee.get('measurement_time')
|
||
if measurement_time:
|
||
lines.append(f"- Measurement Time: {measurement_time}")
|
||
|
||
elif location_status == 'not_in_range':
|
||
lines.append(f"- Location Status: Out of Range")
|
||
error_message = employee.get('error_message', '')
|
||
if error_message:
|
||
lines.append(f"- Note: {error_message}")
|
||
|
||
else: # error
|
||
lines.append(f"- Location Status: Failed")
|
||
error_message = employee.get('error_message', 'Unknown Error')
|
||
lines.append(f"- Error: {error_message}")
|
||
|
||
lines.append("")
|
||
|
||
return "\n".join(lines)
|
||
|
||
|
||
async def async_main():
|
||
"""异步主函数"""
|
||
user_identifier = os.environ.get('USER_IDENTIFIER', '')
|
||
bot_id = os.environ.get('BOT_ID', '')
|
||
|
||
if not user_identifier:
|
||
return 0
|
||
|
||
# 查询员工位置信息
|
||
location_data = await get_employee_location(user_identifier, bot_id)
|
||
|
||
# 格式化并输出上下文(出错或无数据时返回空字符串)
|
||
context_info = format_location_context(location_data, user_identifier, bot_id)
|
||
if context_info:
|
||
print(context_info)
|
||
|
||
return 0
|
||
|
||
|
||
def main():
|
||
"""从环境变量读取参数并通过 MCP 服务查询用户上下文"""
|
||
try:
|
||
return anyio.run(async_main)
|
||
except Exception:
|
||
# 出错时返回空字符串(不输出任何内容)
|
||
return 0
|
||
|
||
|
||
if __name__ == '__main__':
|
||
sys.exit(main())
|