qwen_agent/skills/developing/pmda-drug-info/pmda_server.py
2026-06-12 11:03:30 +08:00

828 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
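PMDA drug information MCP server — real PostgreSQL (PG) / OpenSearch (OS) query version (replaces the original mock).
The plugin is self-contained and does not depend on any mygpt.* module. Configuration is via environment variables:
PMDA_PG_HOST / PMDA_PG_PORT / PMDA_PG_DB / PMDA_PG_USER / PMDA_PG_PASSWORD
PMDA_OPENSEARCH_URL (or OPENSEARCH_URL) / PMDA_OS_INDEX
Modeled on the behavior of the 10 tools in hu-sandbox/pmda/agent/tools.py (98/100 v2e baseline).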
"""
import asyncio
import sys
from dataclasses import asdict
from decimal import Decimal
from typing import Any, Dict, List, Optional, Sequence, Tuple
from mcp_common import (
create_error_response,
create_initialize_response,
create_ping_response,
create_tools_list_response,
load_tools_from_json,
handle_mcp_streaming,
)
from db import session
from queries import (
drug_dosing_query,
drug_interaction_query,
drug_master_get,
drug_restriction_query,
list_categories_with_counts,
list_drugs_in_category as _sql_list_drugs_in_category,
search_drugs_in_db,
)
from os_client import client as os_client, INDEX_NAME as OS_INDEX_NAME
# ---------------------------------------------------------------------------
# Plain-text rendering (agent-friendly tool output)
#
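# Tool results are returned as plain text (rather than JSON) to reduce the agent's parsing burden and token noise.
# CITATION enforcement is still guaranteed by engineering (it does not rely on the LLM's diligence):
# 1. `_CITE_INSTRUCTION_TEXT` is injected at the head of every result that has citable sources (the first thing the LLM sees)
# 2. Each record ends with a `CITATION:` line mirroring `_cite._tag` (the LLM copies it directly, no traversal needed)
# 3. `read_drug_chapter` sandwich-wraps the raw markdown (the tag sits physically next to the chapter text)
# When a query has 0 hits, an English no-results message is returned and it does **not** contain the CITATION
# instructions — this avoids nudging the agent into fabricating a citation when there is no source.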
# ---------------------------------------------------------------------------
# Instruction banner that the renderers in this module put at the head of a
# result whose records may end with a `CITATION:` line. It tells the model to
# copy each record's pre-built tag verbatim, right after the fact it supports,
# and never to build or invent a tag itself.
_CITE_INSTRUCTION_TEXT = (
    "=== CITATION INSTRUCTIONS ===\n"
    "Each record below ends with a `CITATION:` line — a pre-built "
    "`<CITATION file=\"...\" filename=\"...\" />` tag the frontend PDF-highlight "
    "pipeline depends on. When you use a fact from a record, copy that record's "
    "`CITATION:` tag VERBATIM (byte-for-byte) immediately AFTER the paragraph or "
    "bullet that states the fact. NEVER collect citations at the end. At most ONE "
    "tag per unique file. Do NOT add, modify, reorder, remove attributes, or build "
    "a tag yourself. Records without a `CITATION:` line carry no clickable source — "
    "do NOT fabricate one. An answer that uses these facts but contains zero "
    "`<CITATION>` tags is a FAILED answer.\n"
    "=============================="
)
def _no_results(what: str) -> str:
"""English no-results message — intentionally omits CITATION instructions.
Returned when a query matches 0 rows, so the agent tells the user nothing was
found instead of being pushed to emit a citation for a non-existent source.
"""
return (
f"No matching {what} were found in the PMDA package-insert database.\n"
)
def _fmt(value: Any) -> str:
"""Render a single field value as compact text (Decimal → number)."""
if isinstance(value, Decimal):
value = float(value)
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value)
def _tag_of(data: dict) -> Optional[str]:
"""Pull the pre-built ``<CITATION ... />`` tag out of a record."""
return data.get("cite_emit") or (data.get("_cite") or {}).get("_tag")
def _render_records(
records: Sequence[dict],
*,
what: str,
header_title: str,
field_specs: Sequence[Tuple[str, str]],
title_key: Optional[str] = None,
with_citation: bool = True,
) -> str:
"""Render a flat list of record dicts into agent-friendly plain text.
Empty ``records`` → English no-results message (no CITATION instructions).
Otherwise: optional citation-instruction header, a ``header_title`` line, then
one block per record. ``field_specs`` is ``[(key, label), ...]`` controlling
field order/display; empty values are skipped. ``title_key`` (if given) is the
record's headline; each record's ``_citation`` text and CITATION tag are
appended when present.
"""
if not records:
return _no_results(what)
parts: List[str] = []
if with_citation:
parts.append(_CITE_INSTRUCTION_TEXT)
parts.append(header_title)
for idx, rec in enumerate(records, 1):
title = _fmt(rec.get(title_key)) if title_key and rec.get(title_key) else ""
lines = [f"[{idx}] {title}".rstrip()]
for key, label in field_specs:
value = rec.get(key)
if value in (None, "", [], {}):
continue
lines.append(f" {label}: {_fmt(value)}")
if rec.get("_citation"):
lines.append(f" 出典: {rec['_citation']}")
if with_citation:
tag = _tag_of(rec)
if tag:
lines.append(f" CITATION: {tag}")
parts.append("\n".join(lines))
return "\n\n".join(parts)
def _render_categories(data: Sequence[dict]) -> str:
"""Render the L1/L2 category tree (navigation only — no citation source)."""
if not data:
return _no_results("categories")
lines: List[str] = ["Drug categories:"]
for l1 in data:
lines.append(f"\n{l1.get('l1_code', '')} {l1.get('l1_name', '')}".rstrip())
for l2 in l1.get("l2", []):
lines.append(
f" - {l2.get('code', '')} {l2.get('name', '')} "
f"({l2.get('drug_count', 0)} drugs)"
)
return "\n".join(lines)
def _render_drugs_in_category(data: dict) -> str:
"""Render generic → [brand] listing for one L2 category (navigation only)."""
generics = data.get("generics") or []
if not generics:
return _no_results("drugs in this category")
header = f"Category {data.get('l2_code', '')} {data.get('l2_name', '')}".rstrip()
lines: List[str] = [header]
for entry in generics:
lines.append(f"\n{entry.get('generic', '')}".rstrip())
for drug in entry.get("drugs", []):
if "_more" in drug:
lines.append(f" - {drug['_more']}")
else:
lines.append(
f" - {drug.get('brand', '')} (yj_full={drug.get('yj_full', '')})"
)
if data.get("_more_generics"):
lines.append(f"\n(+{data['_more_generics']} more generics)")
return "\n".join(lines)
def _render_section_hits(
    *, keyword: str, section_filter: str, total: int, hits: Sequence[dict]
) -> str:
    """Format section-text search hits, one block per drug, with snippets.

    Args:
        keyword: The search keyword, echoed in the title line.
        section_filter: Optional section filter, echoed when non-empty.
        total: Total number of matching drugs reported by the search.
        hits: Hit dicts read for ``brand``, ``generic``, ``l2``, ``yj_full``,
            ``matches`` (each with ``section_title`` and ``snippet``),
            ``_citation_template`` and the citation tag.

    Returns:
        The instruction banner, a title line, one block per hit and, when
        ``total`` exceeds the number of hits shown, a trailing overflow note.
    """
    displayed = len(hits)
    scope = f' in sections like "{section_filter}"' if section_filter else ""
    banner = f'Found {total} drug(s) matching "{keyword}"{scope} (showing {displayed}):'

    chunks: List[str] = [_CITE_INSTRUCTION_TEXT, banner]
    for number, hit in enumerate(hits, 1):
        # Strip the " / " separator when brand and/or generic are empty.
        label = f"[{number}] {hit.get('brand', '')} / {hit.get('generic', '')}".rstrip(" /")
        category = hit.get("l2", "")
        if category:
            label = f"{label} ({category})"
        body = [label, f" yj_full: {hit.get('yj_full', '')}"]

        for match in hit.get("matches", []):
            body.append(f"{match.get('section_title', '')}")
            excerpt = (match.get("snippet") or "").strip()
            body.extend(f" {row}" for row in excerpt.splitlines())

        template = hit.get("_citation_template")
        if template:
            body.append(f" 出典テンプレ: {template}")
        tag = _tag_of(hit)
        if tag:
            body.append(f" CITATION: {tag}")
        chunks.append("\n".join(body))

    remaining = total - displayed
    if remaining > 0:
        chunks.append(f"(+{remaining} more drugs not shown)")
    return "\n\n".join(chunks)
def _render_chapters(
*, yj_full: str, brand: str, generic: str, sections: Sequence[dict]
) -> str:
"""Render the chapter index for one drug; each chapter carries its own tag."""
has_cite = any(_tag_of(s) for s in sections)
parts: List[str] = []
if has_cite:
parts.append(_CITE_INSTRUCTION_TEXT)
parts.append(
f"{brand} / {generic} (yj_full={yj_full}) — {len(sections)} section(s):".lstrip(
" /"
)
)
block: List[str] = []
for s in sections:
block.append(
f" - {s.get('section_title', '')} "
f"(line {s.get('line_num', 0)}, {s.get('text_len', 0)} chars)"
)
tag = _tag_of(s)
if tag:
block.append(f" CITATION: {tag}")
parts.append("\n".join(block))
return "\n\n".join(parts)
# ---------------------------------------------------------------------------
# Citation formatters (kept consistent with tools.py)
# ---------------------------------------------------------------------------
# Module-level lazy caches: each stays None until its _load_* function runs
# its query for the first time, then holds the resulting dict.
# yj_code -> (brand_name, yj_full); filled by _load_drug_lookup().
_DRUG_LOOKUP: Optional[dict] = None
# yj_full -> (vector_file_id, filename, section_to_page); filled by _load_vf_lookup().
_VF_LOOKUP: Optional[dict] = None
# yj_full -> first "/"-separated segment of brand_name; filled by
# _load_brand_by_yj_full_lookup().
_BRAND_BY_YJ_FULL: Optional[dict] = None
def _load_drug_lookup() -> dict:
    """Return the in-process ``yj_code -> (brand_name, yj_full)`` cache.

    The first call reads every row of ``drug_master``; later calls return the
    cached dict. A missing brand becomes ``""`` and a missing ``yj_full``
    falls back to the row's ``yj_code``.

    Returns:
        The cached lookup dict.
    """
    global _DRUG_LOOKUP
    if _DRUG_LOOKUP is not None:
        return _DRUG_LOOKUP

    with session() as conn, conn.cursor() as cur:
        cur.execute("SELECT yj_code, brand_name, yj_full FROM drug_master")
        table = {}
        for yj_code, brand_name, yj_full in cur.fetchall():
            # Later rows overwrite earlier ones that share the same yj_code.
            table[yj_code] = ((brand_name or ""), (yj_full or yj_code))
        _DRUG_LOOKUP = table
    return _DRUG_LOOKUP
def _load_brand_by_yj_full_lookup() -> dict:
    """Return the in-process ``yj_full -> display brand`` cache.

    ``drug_master.brand_name`` can hold several brands joined into one string;
    only a single representative name is needed for display, so the first
    "/"-separated segment is kept. Rows with an empty ``yj_full`` are ignored.

    NOTE(review): the example brand string in the original comment uses a
    full-width "/" between brands, while the split below is on ASCII "/" —
    confirm which separator ``brand_name`` really uses.

    Returns:
        The cached lookup dict.
    """
    global _BRAND_BY_YJ_FULL
    if _BRAND_BY_YJ_FULL is not None:
        return _BRAND_BY_YJ_FULL

    with session() as conn, conn.cursor() as cur:
        cur.execute("SELECT yj_full, brand_name FROM drug_master")
        first_brands = {}
        for yj_full, brand in cur.fetchall():
            if not yj_full:
                continue
            first_brands[yj_full] = (brand or "").split("/", 1)[0].strip()
        _BRAND_BY_YJ_FULL = first_brands
    return _BRAND_BY_YJ_FULL
def _load_vf_lookup() -> dict:
    """Return the ``yj_full -> (vector_file_id, filename, section_to_page)`` cache.

    The first call reads ``pmda_drug_vf``; later calls return the cached dict.
    If the query fails (for example because the table has not been migrated
    yet) the lookup stays empty and citations degrade to text-only: no
    ``<CITATION>`` tag is emitted. The empty result is still cached so that a
    missing table does not trigger one failing query per rendered record, but
    the failure is reported on stderr instead of being swallowed silently.

    Rows whose ``vector_file_id`` is NULL are skipped; otherwise ``str(None)``
    would later surface as a ``file="None"`` citation tag.

    Returns:
        The cached lookup dict (possibly empty).
    """
    global _VF_LOOKUP
    if _VF_LOOKUP is None:
        out: dict = {}
        try:
            with session() as conn, conn.cursor() as cur:
                cur.execute(
                    "SELECT yj_full, vector_file_id, filename, section_to_page "
                    "FROM pmda_drug_vf"
                )
                for yj_full, vf_id, fname, s2p in cur.fetchall():
                    if vf_id is None:
                        continue
                    out[yj_full] = (str(vf_id), fname or "", s2p or {})
        except Exception as exc:
            # Deliberate best-effort: downstream tools simply skip `_cite`.
            # Log to stderr so the degradation is visible to operators.
            print(
                "[pmda-drug-info] pmda_drug_vf lookup unavailable; citations "
                f"degrade to text-only: {type(exc).__name__}: {exc}",
                file=sys.stderr,
            )
        _VF_LOOKUP = out
    return _VF_LOOKUP
def _citation(drug_yj: str, section: Optional[str]) -> str:
    """Build the human-readable ``[出典: ...]`` source line for a fact.

    Args:
        drug_yj: ``yj_code`` used to look up the brand and ``yj_full``; when it
            is unknown the brand is empty and ``drug_yj`` itself is shown.
        section: Source section text; a placeholder is used when it is empty.

    Returns:
        The bracketed source string.
    """
    brand, yj_full = _load_drug_lookup().get(drug_yj, ("", drug_yj))
    chapter = section if section else "(章不明)"
    return f"[出典: {brand} (yj_full={yj_full}) / {chapter}]"
def _citation_tag(cite: dict) -> str:
"""Build the ``<CITATION file="..." filename="..." />`` string.
精简版: **只输出 2 个属性 file + filename** — 减轻 LLM 负担 / 减少
输出 token / 减少幻觉表面积。前端 PDF 高亮链路实际只用 file_id +
text(段落正文),不依赖 page/yj_full/brand/section,所以 tag 里
不再带这些(`_cite` 字典里仍保留, 给前端可选展示)。
工程化预制, 让 LLM 直接照搬, 避免 LLM 自己拼字符串幻觉 file= 文件名。
"""
from html import escape as _esc
parts = []
if cite.get("file_id"):
parts.append(f'file="{_esc(str(cite["file_id"]), quote=True)}"')
if cite.get("filename"):
# 用 basename, 前端 chip 显示干净 — 完整 path 留在 _cite.filename
bn = cite["filename"].rsplit("/", 1)[-1]
parts.append(f'filename="{_esc(bn, quote=True)}"')
return f"<CITATION {' '.join(parts)} />"
def _cite_struct_by_yj_full(yj_full: str, section: Optional[str]) -> Optional[dict]:
    """Build the ``_cite`` dict for a drug identified by ``yj_full``.

    Keys of the returned dict:
        - ``yj_full``: the identifier passed in.
        - ``brand``: display brand from the brand lookup.
        - ``file_id`` / ``filename``: taken from the vector-file lookup.
        - ``page``: always ``0`` in this version.
        - ``section``: only present when ``section`` is non-empty.
        - ``_tag``: the pre-built ``<CITATION ... />`` string for the LLM to
          copy verbatim.

    Both lookups must succeed. When either the brand or the vector-file entry
    is missing, ``None`` is returned so that callers attach no ``_cite`` and
    no attribute-less tag is ever emitted.

    Args:
        yj_full: Drug identifier used as the key in both lookups.
        section: Optional source section text.

    Returns:
        The citation dict, or ``None`` when a lookup misses.
    """
    display_brand = _load_brand_by_yj_full_lookup().get(yj_full)
    if not display_brand:
        return None

    # file_id / filename are mandatory for a usable tag.
    vf_entry = _load_vf_lookup().get(yj_full)
    if not vf_entry:
        return None
    file_id, file_name, _ = vf_entry

    cite: dict = dict(
        yj_full=yj_full,
        brand=display_brand,
        file_id=file_id,
        filename=file_name,
        page=0,
    )
    if section:
        cite["section"] = section
    cite["_tag"] = _citation_tag(cite)
    return cite
def _cite_struct(drug_yj: str, section: Optional[str]) -> Optional[dict]:
    """Build the ``_cite`` dict for a drug identified by ``yj_code``.

    The code is first mapped to a ``yj_full`` through the drug lookup (an
    unknown code is used as-is), then delegated to
    ``_cite_struct_by_yj_full``.

    Args:
        drug_yj: ``yj_code`` of the drug.
        section: Optional source section text.

    Returns:
        The citation dict, or ``None`` when no citation can be built; callers
        can still show the human-readable ``[出典: ...]`` text.
    """
    resolved_yj_full = _load_drug_lookup().get(drug_yj, ("", drug_yj))[1]
    return _cite_struct_by_yj_full(resolved_yj_full, section)
# ---------------------------------------------------------------------------
# Tool implementations (10 tools)
# ---------------------------------------------------------------------------
def _tool_search_drugs(query: str, kind: str = "auto", limit: int = 10) -> str:
    """Search drugs in the database and render the matches as plain text.

    Args:
        query: Search text passed to ``search_drugs_in_db``.
        kind: Search mode passed through unchanged.
        limit: Maximum number of rows requested.

    Returns:
        Rendered records (with CITATION tags where available) or the
        no-results message.
    """
    found = []
    for row in search_drugs_in_db(query, kind=kind, limit=limit):
        item: dict = {
            "yj_full": row.yj_full,
            "yj_code": row.yj_code,
            "brand": row.brand_name,
            "generic": row.generic_name,
            "category": f"{row.category_code} {row.category_name}".strip(),
            "score": row.score,
        }
        cite = _cite_struct_by_yj_full(row.yj_full, section=None)
        if cite is not None:
            # `cite_emit` mirrors the tag at top level for the renderer/LLM.
            item.update({"_cite": cite, "cite_emit": cite["_tag"]})
        found.append(item)

    shown_fields = [
        ("generic", "generic"),
        ("yj_full", "yj_full"),
        ("yj_code", "yj_code"),
        ("category", "category"),
        ("score", "score"),
    ]
    return _render_records(
        found,
        what="drugs",
        header_title=f"Found {len(found)} drug(s):",
        title_key="brand",
        field_specs=shown_fields,
    )
def _tool_list_categories() -> str:
    """Render the category tree returned by ``list_categories_with_counts``."""
    category_tree = list_categories_with_counts()
    return _render_categories(category_tree)
def _tool_list_drugs_in_category(l2_code: str, limit_generics: int = 50) -> str:
    """Render the generic/brand listing of one L2 category.

    Args:
        l2_code: Category code passed to the SQL listing query.
        limit_generics: Maximum number of generics requested.

    Returns:
        The rendered listing text.
    """
    listing = _sql_list_drugs_in_category(l2_code, limit_generics=limit_generics)
    return _render_drugs_in_category(listing)
def _tool_get_drug_master(yj_code: str) -> str:
    """Fetch one drug master record and render it as plain text.

    Args:
        yj_code: Identifier passed to ``drug_master_get``.

    Returns:
        The rendered record (with its citation when available), or the
        no-results message when no row is found.
    """
    row = drug_master_get(yj_code)
    if row is None:
        return _no_results("drug master record")
    result = asdict(row)
    result["_citation"] = f"[出典: {row.brand_name} (yj_full={row.yj_full}) / 添付文書冒頭]"
    # Resolve the tag from the row's own yj_full (the same value printed in the
    # text citation above) rather than going back through the
    # yj_code -> yj_full cache, which keeps only one yj_full per yj_code and
    # could therefore point the tag at a different document. Fall back to
    # yj_code when yj_full is empty, as the cache does.
    cite = _cite_struct_by_yj_full(row.yj_full or row.yj_code, section=None)
    if cite is not None:
        result["_cite"] = cite
        result["cite_emit"] = cite["_tag"]  # top-level mirror for LLM
    return _render_records(
        [result],
        what="drug master record",
        header_title="Drug master record:",
        title_key="brand_name",
        field_specs=[
            ("generic_name_jp", "generic"),
            ("yj_full", "yj_full"),
            ("yj_code", "yj_code"),
            ("category_code", "category_code"),
            ("category_name", "category_name"),
            ("regulation", "regulation"),
            ("manufacturer", "manufacturer"),
            ("revision_date", "revision_date"),
        ],
    )
def _tool_get_drug_interactions(
    drug_a_yj: Optional[str] = None,
    drug_b_yj: Optional[str] = None,
    severity: Optional[str] = None,
    keyword: Optional[str] = None,
    limit: int = 30,
) -> str:
    """Query drug interaction facts and render them as plain text.

    Args:
        drug_a_yj: Optional filter passed to ``drug_interaction_query``.
        drug_b_yj: Optional filter passed through.
        severity: Optional filter passed through.
        keyword: Optional filter passed through.
        limit: Maximum number of rows requested.

    Returns:
        Rendered records, each with its source line and (when available) a
        CITATION tag, or the no-results message.
    """
    filters = dict(
        drug_a_yj=drug_a_yj,
        drug_b_yj=drug_b_yj,
        severity=severity,
        keyword=keyword,
        limit=limit,
    )
    rendered = []
    for row in drug_interaction_query(**filters):
        record = asdict(row)
        # The source of an interaction is the drug whose insert states it.
        record["_citation"] = _citation(row.source_drug_yj, row.source_section)
        cite = _cite_struct(row.source_drug_yj, row.source_section)
        if cite is not None:
            record.update({"_cite": cite, "cite_emit": cite["_tag"]})
        rendered.append(record)

    shown_fields = [
        ("drug_a_yj", "drug_a_yj"),
        ("drug_b_yj", "drug_b_yj"),
        ("drug_b_class", "drug_b_class"),
        ("mechanism", "mechanism"),
        ("clinical_effect", "clinical_effect"),
        ("source_section", "source_section"),
    ]
    return _render_records(
        rendered,
        what="drug interactions",
        header_title=f"Found {len(rendered)} drug interaction(s):",
        title_key="severity",
        field_specs=shown_fields,
    )
def _tool_get_drug_restrictions(
    drug_yj: Optional[str] = None,
    condition_type: Optional[str] = None,
    severity: Optional[str] = None,
    keyword: Optional[str] = None,
    limit: int = 30,
) -> str:
    """Query drug restriction facts and render them as plain text.

    Args:
        drug_yj: Optional filter passed to ``drug_restriction_query``.
        condition_type: Optional filter passed through.
        severity: Optional filter passed through.
        keyword: Optional filter passed through.
        limit: Maximum number of rows requested.

    Returns:
        Rendered records, each with its source line and (when available) a
        CITATION tag, or the no-results message.
    """
    filters = dict(
        drug_yj=drug_yj,
        condition_type=condition_type,
        severity=severity,
        keyword=keyword,
        limit=limit,
    )
    rendered = []
    for row in drug_restriction_query(**filters):
        record = asdict(row)
        record["_citation"] = _citation(row.drug_yj, row.source_section)
        cite = _cite_struct(row.drug_yj, row.source_section)
        if cite is not None:
            record.update({"_cite": cite, "cite_emit": cite["_tag"]})
        rendered.append(record)

    shown_fields = [
        ("drug_yj", "drug_yj"),
        ("condition_text", "condition_text"),
        ("severity", "severity"),
        ("source_section", "source_section"),
    ]
    return _render_records(
        rendered,
        what="drug restrictions",
        header_title=f"Found {len(rendered)} drug restriction(s):",
        title_key="condition_type",
        field_specs=shown_fields,
    )
def _tool_get_drug_dosing(
    drug_yj: str,
    patient_segment: Optional[str] = None,
    limit: int = 20,
) -> str:
    """Query dosing entries for a drug and render them as plain text.

    Args:
        drug_yj: Drug identifier passed to ``drug_dosing_query``.
        patient_segment: Optional filter passed through.
        limit: Maximum number of rows requested.

    Returns:
        Rendered entries, each with its source line and (when available) a
        CITATION tag, or the no-results message.
    """
    entries = []
    for row in drug_dosing_query(
        drug_yj=drug_yj, patient_segment=patient_segment, limit=limit
    ):
        record = asdict(row)
        if row.dose_amount is not None:
            # Amount and unit are merged into one readable "dose" field.
            unit = row.dose_unit or ""
            record["dose"] = f"{_fmt(row.dose_amount)}{unit}".strip()
        record["_citation"] = _citation(row.drug_yj, row.source_section)
        cite = _cite_struct(row.drug_yj, row.source_section)
        if cite is not None:
            record.update({"_cite": cite, "cite_emit": cite["_tag"]})
        entries.append(record)

    count = len(entries)
    plural_suffix = "y" if count == 1 else "ies"
    shown_fields = [
        ("indication_code", "indication_code"),
        ("dose", "dose"),
        ("frequency", "frequency"),
        ("duration", "duration"),
        ("adjustment_text", "adjustment"),
        ("source_section", "source_section"),
    ]
    return _render_records(
        entries,
        what="dosing entries",
        header_title=f"Found {count} dosing entr{plural_suffix}:",
        title_key="patient_segment",
        field_specs=shown_fields,
    )
def _section_search_body(keyword: str, section_filter: str, size: int) -> dict:
    """Build the OpenSearch request body for a section-text keyword search."""
    body: dict = {
        "size": size,
        "_source": ["yj_full", "brand_names", "generic_name",
                    "l2_code", "l2_name", "section_title", "line_num"],
        "query": {"bool": {"must": [{"match": {"text": keyword}}]}},
        # Collapse on yj_full so each drug appears once; up to two matching
        # sections per drug come back as inner hits with one highlight fragment.
        "collapse": {
            "field": "yj_full",
            "inner_hits": {
                "name": "matches",
                "size": 2,
                "_source": ["section_title", "line_num"],
                "highlight": {"fields": {"text": {"fragment_size": 160, "number_of_fragments": 1}}},
            },
        },
        # Distinct-drug count, shown as the total next to the hits displayed.
        "aggs": {"total_drugs": {"cardinality": {"field": "yj_full"}}},
    }
    if section_filter:
        body["query"]["bool"]["filter"] = [
            {"wildcard": {"section_title.raw": f"*{section_filter}*"}}
        ]
    return body


def _section_hit_entry(hit: dict) -> dict:
    """Convert one collapsed OpenSearch hit into the dict the renderer expects."""
    src = hit.get("_source") or {}
    inner = hit.get("inner_hits", {}).get("matches", {}).get("hits", {}).get("hits", [])
    brand = (src.get("brand_names") or [""])[0]
    yj_full = src.get("yj_full") or ""

    # Each snippet carries the CITATION tag built for its own section, so the
    # tag travels with the text when the LLM copies a snippet.
    matches = []
    seen = set()
    for ih in inner:
        ih_src = ih.get("_source") or {}
        title = ih_src.get("section_title") or ""
        if title in seen:
            continue  # one snippet per distinct section title
        seen.add(title)
        hl = ih.get("highlight", {}).get("text", [""])
        snippet_text = hl[0] if hl else ""
        inner_cite = _cite_struct_by_yj_full(yj_full, title)
        inner_tag = (inner_cite or {}).get("_tag", "")
        matches.append({
            "section_title": title,
            "snippet": snippet_text + (f"\n{inner_tag}" if inner_tag else ""),
        })

    # Hit-level _cite uses the first matched section (legacy compatibility).
    cite_section = matches[0]["section_title"] if matches else None
    cite = _cite_struct_by_yj_full(yj_full, cite_section)
    entry: dict = {
        "yj_full": yj_full,
        "brand": brand,
        "generic": src.get("generic_name") or "",
        "l2": f"{src.get('l2_code') or ''} {src.get('l2_name') or ''}".strip(),
        "matches": matches,
        "_citation_template": f"[出典: {brand} (yj_full={yj_full}) / <該当章>]",
    }
    if cite is not None:
        entry["_cite"] = cite
        entry["cite_emit"] = cite["_tag"]  # top-level mirror for LLM
    return entry


def _tool_search_section_text(
    keyword: str,
    section_filter: str = "",
    limit: int = 30,
) -> str:
    """Full-text search over package-insert sections, one result per drug.

    Args:
        keyword: Text to match against section bodies. An empty, blank or
            ``None`` keyword yields the no-results message without querying.
        section_filter: Optional substring that section titles must contain
            (applied as a wildcard filter on ``section_title.raw``).
        limit: Requested number of drugs; clamped to 1..100. A value that
            cannot be converted to ``int`` (e.g. JSON ``null``) falls back to
            the default of 30 instead of raising.

    Returns:
        Rendered hits with per-section snippets and CITATION tags, or the
        no-results message.
    """
    # Tool arguments come from JSON, so an explicit null may arrive here.
    keyword = keyword or ""
    if not keyword.strip():
        return _no_results("sections")
    try:
        requested = int(limit)
    except (TypeError, ValueError):
        requested = 30
    size = min(max(1, requested), 100)

    body = _section_search_body(keyword, section_filter, size)
    resp = os_client().search(index=OS_INDEX_NAME, body=body)
    total = int(resp["aggregations"]["total_drugs"]["value"])
    hits_out = [_section_hit_entry(h) for h in resp["hits"]["hits"]]
    if not hits_out:
        return _no_results("sections")
    return _render_section_hits(
        keyword=keyword,
        section_filter=section_filter,
        total=total,
        hits=hits_out,
    )
def _tool_list_drug_chapters(yj_full: str) -> str:
    """List every chapter of one drug, ordered by ``line_num`` ascending.

    Each chapter is reported with its ``section_title``, ``line_num`` and the
    length of its text. At most 200 sections are requested.

    Args:
        yj_full: Drug identifier matched with a term query.

    Returns:
        The rendered chapter index, or the no-results message when the drug
        has no indexed sections.
    """
    body = {
        "size": 200,
        "_source": ["section_title", "line_num", "brand_names", "generic_name", "text"],
        "query": {"term": {"yj_full": yj_full}},
        "sort": [{"line_num": "asc"}],
    }
    resp = os_client().search(index=OS_INDEX_NAME, body=body)
    hits = resp["hits"]["hits"]
    if not hits:
        return _no_results("chapters")
    # `or` fallbacks (instead of .get defaults) also cover fields that are
    # present but null, matching how the other OpenSearch tools read hits.
    head = hits[0].get("_source") or {}
    brand = (head.get("brand_names") or [""])[0]
    generic = head.get("generic_name") or ""
    sections = []
    for h in hits:
        src = h.get("_source") or {}
        section_title = src.get("section_title") or ""
        entry: dict = {
            "section_title": section_title,
            "line_num": src.get("line_num") or 0,
            "text_len": len(src.get("text") or ""),
        }
        cite = _cite_struct_by_yj_full(yj_full, section_title)
        if cite is not None:
            entry["_cite"] = cite
            entry["cite_emit"] = cite["_tag"]  # top-level mirror for LLM
        sections.append(entry)
    return _render_chapters(
        yj_full=yj_full, brand=brand, generic=generic, sections=sections
    )
def _tool_read_drug_chapter(yj_full: str, section_title: str) -> str:
    """Return the markdown text of one chapter (at most 8000 characters).

    When a citation can be built, the text is sandwiched between a header and
    a footer that both contain the CITATION tag, so the tag sits right next to
    the chapter text the LLM reads.

    Args:
        yj_full: Drug identifier matched with a term query.
        section_title: Exact section title matched on ``section_title.raw``.

    Returns:
        The (possibly wrapped) chapter text, or a hint message when no section
        with non-empty text matches.
    """
    query = {
        "size": 1,
        "_source": ["text"],
        "query": {
            "bool": {
                "must": [
                    {"term": {"yj_full": yj_full}},
                    {"term": {"section_title.raw": section_title}},
                ]
            }
        },
    }
    resp = os_client().search(index=OS_INDEX_NAME, body=query)
    hits = resp["hits"]["hits"]

    text = ""
    if hits:
        text = (hits[0].get("_source") or {}).get("text", "")
    if not text:
        # A parameter mismatch rather than "no data": keep the actionable hint.
        return (
            f'No section titled "{section_title}" exactly matches yj_full={yj_full}.\n'
            "Hint: pass a sections[].section_title returned by list_drug_chapters "
            "verbatim."
        )

    excerpt = text[:8000]
    cite = _cite_struct_by_yj_full(yj_full, section_title)
    tag = (cite or {}).get("_tag", "")
    if not tag:
        return excerpt

    # Sandwich: header (rule reminder + tag), chapter body, footer (tag again).
    header = (
        "<!-- CITATION rule: copy the tag below verbatim into your answer -->\n"
        f"{tag}\n\n"
    )
    footer = (
        "\n\n---\n"
        "If you use the content above in your answer, you MUST include this "
        "tag verbatim:\n"
        f"{tag}\n"
    )
    return header + excerpt + footer
# ---------------------------------------------------------------------------
# MCP dispatch
# ---------------------------------------------------------------------------
# Maps each MCP tool name to a callable that takes the `arguments` dict of a
# `tools/call` request, picks out the parameters the tool needs (applying the
# defaults shown here when a key is absent) and returns the tool's text output.
_TOOL_DISPATCH = {
    "search_drugs": lambda args: _tool_search_drugs(
        query=args.get("query", ""),
        kind=args.get("kind", "auto"),
        limit=args.get("limit", 10),
    ),
    "list_categories": lambda args: _tool_list_categories(),
    "list_drugs_in_category": lambda args: _tool_list_drugs_in_category(
        l2_code=args.get("l2_code", ""),
        limit_generics=args.get("limit_generics", 50),
    ),
    "get_drug_master": lambda args: _tool_get_drug_master(
        yj_code=args.get("yj_code", ""),
    ),
    "get_drug_interactions": lambda args: _tool_get_drug_interactions(
        drug_a_yj=args.get("drug_a_yj"),
        drug_b_yj=args.get("drug_b_yj"),
        severity=args.get("severity"),
        keyword=args.get("keyword"),
        limit=args.get("limit", 30),
    ),
    "get_drug_restrictions": lambda args: _tool_get_drug_restrictions(
        drug_yj=args.get("drug_yj"),
        condition_type=args.get("condition_type"),
        severity=args.get("severity"),
        keyword=args.get("keyword"),
        limit=args.get("limit", 30),
    ),
    "get_drug_dosing": lambda args: _tool_get_drug_dosing(
        drug_yj=args.get("drug_yj", ""),
        patient_segment=args.get("patient_segment"),
        limit=args.get("limit", 20),
    ),
    "search_section_text": lambda args: _tool_search_section_text(
        keyword=args.get("keyword", ""),
        section_filter=args.get("section_filter", ""),
        limit=args.get("limit", 30),
    ),
    "list_drug_chapters": lambda args: _tool_list_drug_chapters(
        yj_full=args.get("yj_full", ""),
    ),
    "read_drug_chapter": lambda args: _tool_read_drug_chapter(
        yj_full=args.get("yj_full", ""),
        section_title=args.get("section_title", ""),
    ),
}
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle one JSON-RPC request of the MCP protocol.

    Supported methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call``. Unknown methods and unknown tool names produce a
    ``-32601`` error response.

    A tool that raises is reported inside a normal result whose text starts
    with ``Error:`` and whose ``isError`` flag is set, which is how MCP marks
    tool execution errors. Any other failure becomes a ``-32603`` error
    response.

    Args:
        request: The decoded JSON-RPC request.

    Returns:
        The JSON-RPC response dict.
    """
    try:
        method = request.get("method")
        # `or {}` also covers an explicit JSON null, not just a missing key.
        params = request.get("params") or {}
        request_id = request.get("id")
        if method == "initialize":
            return create_initialize_response(request_id, "pmda-drug-info")
        if method == "ping":
            return create_ping_response(request_id)
        if method == "tools/list":
            tools = load_tools_from_json("pmda_tools.json")
            return create_tools_list_response(request_id, tools)
        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments") or {}
            if tool_name not in _TOOL_DISPATCH:
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
            try:
                result_text = _TOOL_DISPATCH[tool_name](arguments)
            except Exception as e:
                # Tool execution error: returned as a result so the model can
                # read it, and flagged so clients can tell it from success.
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "content": [{"type": "text", "text": f"Error: {type(e).__name__}: {e}"}],
                        "isError": True,
                    },
                }
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {"content": [{"type": "text", "text": result_text}]},
            }
        return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        # `request` itself may be malformed (not a dict); never raise from here.
        failed_id = request.get("id") if isinstance(request, dict) else None
        return create_error_response(failed_id, -32603, f"Internal error: {e}")
async def main() -> None:
    """Run the server by handing ``handle_request`` to ``handle_mcp_streaming``."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    # Script entry point: run main() on a new asyncio event loop.
    asyncio.run(main())