qwen_agent/skills/developing/pmda-drug-info/pmda_server.py
2026-05-11 18:55:25 +08:00

597 lines
20 KiB
Python

#!/usr/bin/env python3
"""
PMDA drug information MCP server.
Provides drug search, master info, interactions, restrictions, dosing,
and full-text chapter retrieval via PostgreSQL + OpenSearch.
"""
import asyncio
import datetime
import json
import os
import sys
import traceback
from decimal import Decimal
from typing import Any, Dict, List, Optional

import psycopg2
import psycopg2.extras
from opensearchpy import OpenSearch

from mcp_common import (
    create_error_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    load_tools_from_json,
    handle_mcp_streaming,
)
# ---------------------------------------------------------------------------
# Configuration from environment variables
# ---------------------------------------------------------------------------
def _env(name: str, default: str) -> str:
    """Return environment variable *name*, using *default* when it is unset OR empty.

    Deployment tooling frequently exports variables with an empty value
    (e.g. ``PMDA_OS_PORT=``).  ``os.getenv(name, default)`` would return ``""``
    in that case, which crashes ``int()`` for the port at import time and
    yields a blank host or index name.
    """
    return os.getenv(name) or default


PG_DSN = _env("PMDA_PG_DSN", "")  # empty -> _get_pg() raises a clear error
OS_HOST = _env("PMDA_OS_HOST", "localhost")
OS_PORT = int(_env("PMDA_OS_PORT", "9200"))
OS_INDEX = _env("PMDA_OS_INDEX", "pmda_sections")
def _json_default(o):
    """Convert values the stock JSON encoder rejects into JSON-friendly ones.

    Used as the ``default=`` hook of :func:`json.dumps`.  psycopg2 returns
    PostgreSQL ``numeric`` values as :class:`~decimal.Decimal` and
    ``date``/``time``/``timestamp`` values as :mod:`datetime` objects; the
    ``SELECT *`` queries in this module may return such columns depending on
    the table schema.

    Returns:
        ``float`` for ``Decimal``; an ISO-8601 string for date/time values.

    Raises:
        TypeError: for any other type (the contract ``json.dumps`` expects
            from a ``default`` hook).
    """
    if isinstance(o, Decimal):
        return float(o)
    # datetime.datetime subclasses datetime.date, so this covers all three.
    if isinstance(o, (datetime.date, datetime.time)):
        return o.isoformat()
    raise TypeError(f"non-serializable: {type(o).__name__}")
def _dump(obj) -> str:
    """Render *obj* as JSON text, leaving non-ASCII (Japanese) characters unescaped.

    Values json cannot encode natively are passed through ``_json_default``.
    """
    encoder_options = {"ensure_ascii": False, "default": _json_default}
    return json.dumps(obj, **encoder_options)
# ---------------------------------------------------------------------------
# Lazy database connections
# ---------------------------------------------------------------------------
# Both clients start unset and are created on first use by _get_pg() / _get_os().
_pg_conn = _os_client = None
# Cache filled by _load_drug_lookup(): yj_code -> (brand_name, yj_full).
_drug_lookup: Optional[Dict[str, tuple]] = None
def _get_pg():
    """Return the module-wide psycopg2 connection, opening it on demand.

    A connection is (re)opened when none exists yet or the cached one
    reports ``closed``; freshly opened connections are put in autocommit mode.

    Raises:
        RuntimeError: when PMDA_PG_DSN is empty.
    """
    global _pg_conn
    needs_connect = _pg_conn is None or _pg_conn.closed
    if not needs_connect:
        return _pg_conn
    if not PG_DSN:
        raise RuntimeError("PMDA_PG_DSN environment variable is not set")
    _pg_conn = psycopg2.connect(PG_DSN)
    _pg_conn.autocommit = True
    return _pg_conn
def _get_os() -> OpenSearch:
    """Return the shared OpenSearch client, constructing it on the first call.

    The client targets OS_HOST:OS_PORT over plain HTTP (``use_ssl=False``,
    ``verify_certs=False``).
    """
    global _os_client
    if _os_client is not None:
        return _os_client
    node = {"host": OS_HOST, "port": OS_PORT}
    _os_client = OpenSearch(hosts=[node], use_ssl=False, verify_certs=False)
    return _os_client
def _load_drug_lookup() -> Dict[str, tuple]:
    """Return the yj_code -> (brand_name, yj_full) map, reading drug_master once.

    The full table is read on the first call and cached in the module global
    ``_drug_lookup``; nothing in this module refreshes it afterwards.  A NULL
    brand_name becomes "" and a NULL yj_full falls back to the yj_code.
    """
    global _drug_lookup
    if _drug_lookup is None:
        with _get_pg().cursor() as cur:
            cur.execute("SELECT yj_code, brand_name, yj_full FROM drug_master")
            mapping: Dict[str, tuple] = {}
            for yj_code, brand_name, yj_full in cur.fetchall():
                mapping[yj_code] = (brand_name or "", yj_full or yj_code)
            _drug_lookup = mapping
    return _drug_lookup
def _citation(drug_yj: Optional[str], section: Optional[str]) -> str:
    """Format a source citation ``[出典: <brand> (yj_full=<id>) / <section>]``.

    Args:
        drug_yj: yj_code of the drug whose package insert is cited.  Callers
            pass column values straight from the database, so this may be
            None (NULL column) or "".
        section: chapter title; "(章不明)" is used when it is missing.

    A yj_code that is not in drug_master is cited with an empty brand and the
    code itself as yj_full.
    """
    # Normalise NULL to "" so the citation never contains the text "None".
    key = drug_yj or ""
    brand, yj_full = _load_drug_lookup().get(key, ("", key))
    chap = section or "(章不明)"
    return f"[出典: {brand} (yj_full={yj_full}) / {chap}]"
# ---------------------------------------------------------------------------
# Tool implementations
# ---------------------------------------------------------------------------
def _tool_search_drugs(query: str, kind: str = "auto", limit: int = 10) -> str:
    """Search drug_master by brand name, generic name, or YJ code.

    Args:
        query: text matched as a case-insensitive substring (ILIKE).
        kind: ``"yj"`` searches yj_code/yj_full only; ``"brand"`` or
            ``"generic"`` search that name column and rank by
            ``similarity()``; any other value ("auto") searches all four
            columns and ranks by the better of the two name similarities.
        limit: maximum number of rows to return.

    Returns:
        JSON array of ``{yj_full, yj_code, brand, generic, category, score}``.
        NULL columns are rendered as ``""``; ``score`` is ``0.0`` when the
        query computes none (the ``"yj"`` kind).
    """
    like = f"%{query}%"
    select = (
        "SELECT yj_full, yj_code, brand_name, generic_name, "
        "category_code, category_name"
    )
    if kind == "yj":
        sql = select + " FROM drug_master WHERE yj_code ILIKE %s OR yj_full ILIKE %s LIMIT %s"
        args = (like, like, limit)
    elif kind == "brand":
        sql = (
            select + ", similarity(brand_name, %s) AS score "
            "FROM drug_master "
            "WHERE brand_name ILIKE %s ORDER BY score DESC LIMIT %s"
        )
        args = (query, like, limit)
    elif kind == "generic":
        sql = (
            select + ", similarity(generic_name, %s) AS score "
            "FROM drug_master "
            "WHERE generic_name ILIKE %s ORDER BY score DESC LIMIT %s"
        )
        args = (query, like, limit)
    else:  # auto
        sql = (
            select + ", "
            "GREATEST("
            " similarity(brand_name, %s),"
            " similarity(generic_name, %s)"
            ") AS score "
            "FROM drug_master "
            "WHERE brand_name ILIKE %s OR generic_name ILIKE %s "
            " OR yj_code ILIKE %s OR yj_full ILIKE %s "
            "ORDER BY score DESC LIMIT %s"
        )
        args = (query, query, like, like, like, like, limit)

    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(sql, args)
        rows = cur.fetchall()

    def text(row, column: str) -> str:
        # RealDictCursor always includes every selected column, so
        # row.get(column, "") would still return None for a NULL value.
        return row.get(column) or ""

    return _dump([
        {
            "yj_full": text(r, "yj_full"),
            "yj_code": text(r, "yj_code"),
            "brand": text(r, "brand_name"),
            "generic": text(r, "generic_name"),
            "category": f"{text(r, 'category_code')} {text(r, 'category_name')}".strip(),
            "score": float(r.get("score") or 0.0),
        }
        for r in rows
    ])
def _tool_list_categories() -> str:
    """List all L1/L2 drug categories with the number of drugs under each.

    A drug is counted under a category when its ``drug_master.category_code``
    starts with the category's code.  This is the same prefix rule that
    ``_tool_list_drugs_in_category`` uses.  An L1 category therefore
    aggregates the drugs of its sub-categories instead of counting only drugs
    assigned the L1 code verbatim (an equality join leaves those at 0).

    Returns:
        JSON array of ``{category_code, category_name, level, drug_count}``
        ordered by ``category_code``.
    """
    # LEFT(...) = code is a prefix test that avoids LIKE wildcards entirely
    # (no '%' escaping concerns, and '_' in a code would not act as a wildcard).
    sql = (
        "SELECT c.category_code, c.category_name, c.level, "
        "COUNT(m.yj_code) AS drug_count "
        "FROM drug_category c "
        "LEFT JOIN drug_master m "
        "ON LEFT(m.category_code, LENGTH(c.category_code)) = c.category_code "
        "WHERE c.level IN ('L1', 'L2') "
        "GROUP BY c.category_code, c.category_name, c.level "
        "ORDER BY c.category_code"
    )
    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(sql)
        rows = cur.fetchall()
    return _dump([dict(r) for r in rows])
def _tool_list_drugs_in_category(l2_code: str, limit_generics: int = 50) -> str:
    """List drugs whose category_code starts with *l2_code*, grouped by generic name.

    The prefix test is a case-insensitive ILIKE, so an empty code matches
    every drug.  Returns a JSON array (ordered by generic_name, at most
    *limit_generics* groups) of
    ``{"generic_name": ..., "brands": [{yj_code, brand_name, yj_full}, ...]}``.
    """
    sql = (
        "SELECT generic_name, json_agg(json_build_object("
        " 'yj_code', yj_code, 'brand_name', brand_name, 'yj_full', yj_full"
        ")) AS brands "
        "FROM drug_master "
        "WHERE category_code ILIKE %s "
        "GROUP BY generic_name "
        "ORDER BY generic_name LIMIT %s"
    )
    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(sql, (f"{l2_code}%", limit_generics))
        groups = cur.fetchall()
    payload = [
        {"generic_name": group["generic_name"], "brands": group["brands"]}
        for group in groups
    ]
    return _dump(payload)
def _tool_get_drug_master(yj_code: str) -> str:
    """Return the drug_master row for *yj_code* as JSON, plus a citation.

    All columns of the first matching row are returned unchanged, with an
    extra ``_citation`` key pointing at the start of the package insert.
    When no row matches, ``{"error": ...}`` is returned instead.
    """
    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(
            "SELECT * FROM drug_master WHERE yj_code = %s LIMIT 1",
            (yj_code,),
        )
        row = cur.fetchone()
    if not row:
        return _dump({"error": f"yj_code {yj_code} not found"})
    d = dict(row)
    # NULL columns arrive as None; use the same fallbacks as _citation()
    # (empty brand, yj_code as yj_full) instead of printing "None".
    brand = d.get("brand_name") or ""
    yj_full = d.get("yj_full") or yj_code
    d["_citation"] = f"[出典: {brand} (yj_full={yj_full}) / 添付文書冒頭]"
    return _dump(d)
def _tool_get_drug_interactions(
    drug_a_yj: Optional[str] = None,
    drug_b_yj: Optional[str] = None,
    severity: Optional[str] = None,
    keyword: Optional[str] = None,
    limit: int = 30,
) -> str:
    """Search the drug_interaction table.

    Filters (all optional, AND-combined):
        drug_a_yj only: rows whose drug_a_yj equals it.
        drug_b_yj only: rows mentioning it in either drug_a_yj or drug_b_yj.
        both: rows for that pair in either orientation (A->B or B->A).
        severity: exact match.
        keyword: case-insensitive substring of drug_b_class, mechanism or
            clinical_effect.
    At most *limit* rows are returned (no ORDER BY, so in no defined order).

    Returns:
        JSON array of interaction records, each with a ``_citation``.
    """
    conditions: List[str] = []
    params: List[Any] = []
    if drug_a_yj and drug_b_yj:
        # Pair lookup.  A row may hold the pair with either drug in the
        # drug_a_yj column (the drug_b-only filter below also checks both
        # columns).  AND-ing the two single-drug filters would only ever
        # match the A->B orientation, so both orientations are spelled out.
        conditions.append(
            "((drug_a_yj = %s AND drug_b_yj = %s) OR (drug_a_yj = %s AND drug_b_yj = %s))"
        )
        params.extend([drug_a_yj, drug_b_yj, drug_b_yj, drug_a_yj])
    elif drug_a_yj:
        conditions.append("drug_a_yj = %s")
        params.append(drug_a_yj)
    elif drug_b_yj:
        conditions.append("(drug_b_yj = %s OR drug_a_yj = %s)")
        params.extend([drug_b_yj, drug_b_yj])
    if severity:
        conditions.append("severity = %s")
        params.append(severity)
    if keyword:
        conditions.append("(drug_b_class ILIKE %s OR mechanism ILIKE %s OR clinical_effect ILIKE %s)")
        k = f"%{keyword}%"
        params.extend([k, k, k])
    # Only fixed SQL fragments are interpolated; all values go through params.
    where = " AND ".join(conditions) if conditions else "1=1"
    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(
            f"SELECT * FROM drug_interaction WHERE {where} LIMIT %s",
            (*params, limit),
        )
        rows = cur.fetchall()
    return _dump([
        {
            "drug_a_yj": r.get("drug_a_yj"),
            "drug_b_yj": r.get("drug_b_yj"),
            "drug_b_class": r.get("drug_b_class"),
            "severity": r.get("severity"),
            "mechanism": r.get("mechanism"),
            "clinical_effect": r.get("clinical_effect"),
            "source_drug_yj": r.get("source_drug_yj"),
            "source_section": r.get("source_section"),
            # `or ""`: a NULL source_drug_yj must not become "yj_full=None".
            "_citation": _citation(r.get("source_drug_yj") or "", r.get("source_section")),
        }
        for r in rows
    ])
def _tool_get_drug_restrictions(
    drug_yj: Optional[str] = None,
    condition_type: Optional[str] = None,
    severity: Optional[str] = None,
    keyword: Optional[str] = None,
    limit: int = 30,
) -> str:
    """Query drug_restriction with optional, AND-combined filters.

    drug_yj, condition_type and severity are exact matches; keyword is a
    case-insensitive substring match on condition_text.  With no filters the
    query is unrestricted.  At most *limit* rows are returned (no ORDER BY).
    Each result carries a ``_citation`` built from its drug_yj and
    source_section.
    """
    exact_filters = (
        ("drug_yj = %s", drug_yj),
        ("condition_type = %s", condition_type),
        ("severity = %s", severity),
    )
    clauses = [clause for clause, value in exact_filters if value]
    values = [value for _, value in exact_filters if value]
    if keyword:
        clauses.append("condition_text ILIKE %s")
        values.append(f"%{keyword}%")
    where = " AND ".join(clauses) or "1=1"
    sql = f"SELECT * FROM drug_restriction WHERE {where} LIMIT %s"
    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(sql, (*values, limit))
        records = cur.fetchall()
    return _dump([
        {
            "drug_yj": rec.get("drug_yj"),
            "condition_type": rec.get("condition_type"),
            "condition_text": rec.get("condition_text"),
            "condition_params": rec.get("condition_params"),
            "severity": rec.get("severity"),
            "source_section": rec.get("source_section"),
            "_citation": _citation(rec.get("drug_yj", ""), rec.get("source_section")),
        }
        for rec in records
    ])
def _tool_get_drug_dosing(
    drug_yj: str,
    patient_segment: Optional[str] = None,
    limit: int = 20,
) -> str:
    """Return dosing rows for one drug, optionally narrowed to a patient segment.

    Both filters are exact matches; at most *limit* rows are returned (no
    ORDER BY).  Every result gets a ``_citation`` for *drug_yj* and the row's
    source_section.
    """
    sql = "SELECT * FROM drug_dosing WHERE drug_yj = %s"
    bind = [drug_yj]
    if patient_segment:
        sql += " AND patient_segment = %s"
        bind.append(patient_segment)
    sql += " LIMIT %s"
    bind.append(limit)
    fields = (
        "patient_segment",
        "segment_params",
        "indication_code",
        "dose_amount",
        "dose_unit",
        "frequency",
        "duration",
        "adjustment_text",
        "source_section",
    )
    with _get_pg().cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(sql, tuple(bind))
        rows = cur.fetchall()
    results = []
    for row in rows:
        item = {name: row.get(name) for name in fields}
        item["_citation"] = _citation(drug_yj, row.get("source_section"))
        results.append(item)
    return _dump(results)
def _summarize_section_hit(hit: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one collapsed (one-per-drug) OpenSearch hit into the tool's output shape."""
    src = hit.get("_source") or {}
    inner = hit.get("inner_hits", {}).get("matches", {}).get("hits", {}).get("hits", [])
    matches = []
    seen_titles = set()
    for ih in inner:
        title = (ih.get("_source") or {}).get("section_title") or ""
        # Inner hits sharing a chapter title are reported once (first wins).
        if title in seen_titles:
            continue
        seen_titles.add(title)
        fragments = ih.get("highlight", {}).get("text", [""])
        matches.append({"section_title": title, "snippet": fragments[0] if fragments else ""})
    brand = (src.get("brand_names") or [""])[0]
    yj_full = src.get("yj_full") or ""
    return {
        "yj_full": yj_full,
        "brand": brand,
        "generic": src.get("generic_name") or "",
        "l2": f"{src.get('l2_code') or ''} {src.get('l2_name') or ''}".strip(),
        "matches": matches,
        "_citation_template": f"[出典: {brand} (yj_full={yj_full}) / <該当章>]",
    }


def _tool_search_section_text(
    keyword: str,
    section_filter: str = "",
    limit: int = 30,
) -> str:
    """Full-text search over package-insert sections in the OpenSearch index.

    Hits are collapsed to one per drug (``yj_full``).  For each drug, up to
    two matching sections are returned with a highlighted snippet (fragment
    size 160).

    Args:
        keyword: text for a ``match`` query on the ``text`` field.  Blank or
            missing input returns an empty result without querying.
        section_filter: optional substring the section title must contain
            (wildcard on ``section_title.raw``).
        limit: number of drugs to return, clamped to 1..100.

    Returns:
        JSON object ``{keyword, section_filter, total_drugs, shown, hits}``,
        plus ``_more_count`` when more drugs matched than are shown.
        ``total_drugs`` comes from a ``cardinality`` aggregation, which
        OpenSearch computes approximately.
    """
    keyword = keyword or ""
    section_filter = section_filter or ""
    if not keyword.strip():
        # Same keys as the normal payload so callers can rely on every field.
        return _dump({
            "keyword": keyword,
            "section_filter": section_filter or None,
            "total_drugs": 0,
            "shown": 0,
            "hits": [],
        })
    # int() also accepts numeric strings sent by JSON clients.
    size = min(max(1, int(limit)), 100)
    body: Dict[str, Any] = {
        "size": size,
        "_source": ["yj_full", "brand_names", "generic_name", "l2_code", "l2_name", "section_title", "line_num"],
        "query": {"bool": {"must": [{"match": {"text": keyword}}]}},
        "collapse": {
            "field": "yj_full",
            "inner_hits": {
                "name": "matches",
                "size": 2,
                "_source": ["section_title", "line_num"],
                "highlight": {"fields": {"text": {"fragment_size": 160, "number_of_fragments": 1}}},
            },
        },
        "aggs": {"total_drugs": {"cardinality": {"field": "yj_full"}}},
    }
    if section_filter:
        body["query"]["bool"]["filter"] = [
            {"wildcard": {"section_title.raw": f"*{section_filter}*"}}
        ]
    resp = _get_os().search(index=OS_INDEX, body=body)
    total = int(resp["aggregations"]["total_drugs"]["value"])
    hits_out = [_summarize_section_hit(h) for h in resp["hits"]["hits"]]
    out = {
        "keyword": keyword,
        "section_filter": section_filter or None,
        "total_drugs": total,
        "shown": len(hits_out),
        "hits": hits_out,
    }
    if total > len(hits_out):
        out["_more_count"] = total - len(hits_out)
    return _dump(out)
def _tool_list_drug_chapters(yj_full: str) -> str:
    """List the chapter (section) titles of one drug's package insert.

    Fetches up to 200 section documents with this yj_full, ordered by
    line_num.  It returns the brand and generic name (from the first
    document) and one entry per document.  ``text_len`` is always 0 because
    the section text is not requested by this query.  When no document
    matches, ``{"error": ...}`` is returned.
    """
    query = {
        "size": 200,
        "_source": ["yj_full", "brand_names", "generic_name", "section_title", "line_num"],
        "query": {"term": {"yj_full": yj_full}},
        "sort": [{"line_num": {"order": "asc"}}],
    }
    found = _get_os().search(index=OS_INDEX, body=query)["hits"]["hits"]
    if not found:
        return _dump({"error": f"yj_full {yj_full} の章節が見つかりません。"})
    sources = [hit.get("_source") or {} for hit in found]
    first = sources[0]
    chapter_list = [
        {
            "section_title": s.get("section_title", ""),
            "line_num": s.get("line_num", 0),
            "text_len": 0,
        }
        for s in sources
    ]
    return _dump({
        "yj_full": yj_full,
        "brand": (first.get("brand_names") or [""])[0],
        "generic": first.get("generic_name", ""),
        "n_sections": len(chapter_list),
        "sections": chapter_list,
    })
def _tool_read_drug_chapter(yj_full: str, section_title: str, max_chars: int = 8000) -> str:
    """Return the verbatim text of one chapter of a drug's package insert.

    Lookup order:
      1. exact title match (term query on the title's exact-match sub-field);
      2. phrase match on the analysed title.  A hit whose title equals
         *section_title* is preferred over longer titles that merely contain
         the phrase.

    Args:
        yj_full: drug identifier as stored in the ``yj_full`` field.
        section_title: chapter title, ideally copied from list_drug_chapters.
        max_chars: the text is cut to this many characters (default 8000).

    Returns:
        The chapter text as a plain string (not JSON).  If no matching chapter
        with text is found, a JSON ``{"error", "hint"}`` object is returned.
    """
    client = _get_os()
    source_fields = ["text", "section_title"]
    drug_filter = {"term": {"yj_full": yj_full}}

    def leading_text(resp: Dict[str, Any]) -> str:
        """Text of the first hit in *resp*, truncated to *max_chars* ("" if none)."""
        hits = resp["hits"]["hits"]
        if not hits:
            return ""
        return ((hits[0].get("_source") or {}).get("text") or "")[:max_chars]

    # 1) Exact match.  Which exact-match sub-field exists depends on the index
    #    mapping (_tool_search_section_text filters on ``section_title.raw``).
    #    A term query on a sub-field the mapping lacks simply matches nothing,
    #    so both names are tried.
    exact_resp = client.search(index=OS_INDEX, body={
        "size": 1,
        "_source": source_fields,
        "query": {
            "bool": {
                "must": [drug_filter],
                "should": [
                    {"term": {"section_title.keyword": section_title}},
                    {"term": {"section_title.raw": section_title}},
                ],
                "minimum_should_match": 1,
            }
        },
    })
    text = leading_text(exact_resp)
    if text:
        return text

    # 2) Phrase match.  This also matches titles that merely contain the
    #    phrase, so fetch several candidates and move exact titles to the
    #    front.  The sort is stable, so relevance order is otherwise kept.
    phrase_resp = client.search(index=OS_INDEX, body={
        "size": 10,
        "_source": source_fields,
        "query": {
            "bool": {
                "must": [drug_filter, {"match_phrase": {"section_title": section_title}}]
            }
        },
    })
    wanted = (section_title or "").strip()
    phrase_resp["hits"]["hits"].sort(
        key=lambda h: ((h.get("_source") or {}).get("section_title") or "").strip() != wanted
    )
    text = leading_text(phrase_resp)
    if text:
        return text

    # Not found: point the caller at list_drug_chapters.
    return _dump({
        "error": f"section_title {section_title!r} は {yj_full} に存在しません。",
        "hint": "list_drug_chapters で取得した sections[].section_title をそのまま渡してください。",
    })
# ---------------------------------------------------------------------------
# MCP request handler
# ---------------------------------------------------------------------------
# Map tool names to their implementation functions
# Tool name -> adapter that takes the MCP ``arguments`` dict and calls the
# matching _tool_* implementation.  Each implementation returns a str.  Keys
# absent from the dict fall back to the defaults written here (or None where
# no default is given).  The lambdas look up the _tool_* functions at call
# time.
# NOTE(review): these keys presumably must match the tool names published by
# tools/list (loaded from pmda_tools.json); keep the two in sync.
_TOOL_DISPATCH = {
    "search_drugs": lambda args: _tool_search_drugs(
        query=args.get("query", ""),
        kind=args.get("kind", "auto"),
        limit=args.get("limit", 10),
    ),
    "list_categories": lambda args: _tool_list_categories(),  # takes no arguments
    "list_drugs_in_category": lambda args: _tool_list_drugs_in_category(
        l2_code=args.get("l2_code", ""),
        limit_generics=args.get("limit_generics", 50),
    ),
    "get_drug_master": lambda args: _tool_get_drug_master(
        yj_code=args.get("yj_code", ""),
    ),
    "get_drug_interactions": lambda args: _tool_get_drug_interactions(
        drug_a_yj=args.get("drug_a_yj"),
        drug_b_yj=args.get("drug_b_yj"),
        severity=args.get("severity"),
        keyword=args.get("keyword"),
        limit=args.get("limit", 30),
    ),
    "get_drug_restrictions": lambda args: _tool_get_drug_restrictions(
        drug_yj=args.get("drug_yj"),
        condition_type=args.get("condition_type"),
        severity=args.get("severity"),
        keyword=args.get("keyword"),
        limit=args.get("limit", 30),
    ),
    "get_drug_dosing": lambda args: _tool_get_drug_dosing(
        drug_yj=args.get("drug_yj", ""),
        patient_segment=args.get("patient_segment"),
        limit=args.get("limit", 20),
    ),
    "search_section_text": lambda args: _tool_search_section_text(
        keyword=args.get("keyword", ""),
        section_filter=args.get("section_filter", ""),
        limit=args.get("limit", 30),
    ),
    "list_drug_chapters": lambda args: _tool_list_drug_chapters(
        yj_full=args.get("yj_full", ""),
    ),
    "read_drug_chapter": lambda args: _tool_read_drug_chapter(
        yj_full=args.get("yj_full", ""),
        section_title=args.get("section_title", ""),
    ),
}
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one JSON-RPC request from the MCP transport and build its response.

    Supported methods are ``initialize``, ``ping``, ``tools/list`` (schemas
    read from pmda_tools.json) and ``tools/call``.  Responses work as follows:

    * Unknown methods and unknown tools get JSON-RPC error -32601.
    * An exception raised inside a tool becomes a normal result with
      ``isError: true``, following the MCP convention for tool execution
      errors, so the model can read the message.
    * Any other unexpected failure gets JSON-RPC error -32603.

    Tracebacks are written to stderr.
    """
    try:
        method = request.get("method")
        # `or {}` also covers an explicit JSON null, which .get(key, {}) would return as None.
        params = request.get("params") or {}
        request_id = request.get("id")
        if method == "initialize":
            return create_initialize_response(request_id, "pmda-drug-info")
        if method == "ping":
            return create_ping_response(request_id)
        if method == "tools/list":
            tools = load_tools_from_json("pmda_tools.json")
            return create_tools_list_response(request_id, tools)
        if method != "tools/call":
            return create_error_response(request_id, -32601, f"Unknown method: {method}")

        tool_name = params.get("name")
        arguments = params.get("arguments") or {}
        if tool_name not in _TOOL_DISPATCH:
            return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
        try:
            result_text = _TOOL_DISPATCH[tool_name](arguments)
        except Exception as e:
            # stderr rather than stdout: stdout may be the protocol channel.
            traceback.print_exc(file=sys.stderr)
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "content": [{"type": "text", "text": f"Error: {str(e)}"}],
                    "isError": True,
                },
            }
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "content": [{"type": "text", "text": result_text}]
            },
        }
    except Exception as e:
        traceback.print_exc(file=sys.stderr)
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
async def main() -> None:
    """Run the MCP streaming loop, routing every request to handle_request."""
    serve = handle_mcp_streaming(handle_request)
    await serve


if __name__ == "__main__":
    asyncio.run(main())