qwen_agent/skills/developing/pmda-drug-info/queries.py
2026-06-12 11:03:30 +08:00

428 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""SQL 查询接口(关系小库侧)。
Phase 1 已承接:
- `search_drugs` 的 商品名 / 一般名 / YJ 子串检索
- `list_categories` 的 L1/L2 + drug_count
- `list_drugs_in_category` 的 一般名 → 販売名
后续 Phase 2 会接 drug_interaction / drug_restriction / drug_dosing。
"""
from __future__ import annotations
import re
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from taxonomy import load_taxonomy
from db import session
# The plugin is self-contained: drug_category.md lives next to queries.py.
_TAXONOMY_PATH = Path(__file__).resolve().with_name("drug_category.md")
# Filled on first use by _taxonomy(); None means "not loaded yet".
_TAXONOMY_CACHE = None


def _taxonomy():
    """Return the taxonomy parsed from ``_TAXONOMY_PATH``, loading it lazily.

    The result of ``load_taxonomy`` is stored in the module-level
    ``_TAXONOMY_CACHE`` so the file is read only while the cache is ``None``.
    """
    global _TAXONOMY_CACHE
    if _TAXONOMY_CACHE is not None:
        return _TAXONOMY_CACHE
    _TAXONOMY_CACHE = load_taxonomy(_TAXONOMY_PATH)
    return _TAXONOMY_CACHE
# Candidate YJ code (or a prefix of one): 4-12 upper-case alphanumerics.
# A leading digit is required because YJ codes begin with the numeric
# therapeutic-classification prefix; without it every 4-12 letter ASCII word
# (e.g. "ASPIRIN" after upper-casing) would be auto-detected as kind=yj and
# never searched against brand/generic names.
_YJ_RE = re.compile(r"^[0-9][0-9A-Z]{3,11}$")
@dataclass(frozen=True)
class DrugHit:
    """Immutable search hit as returned by ``search_drugs_in_db``."""

    yj_full: str
    yj_code: str
    # Several product names may be packed into one string, separated by "/".
    brand_name: str
    generic_name: str
    category_code: str
    category_name: str
    # Relevance score; the original note gives the nominal range as 50-100.
    score: float
def _detect_kind(q: str) -> str:
    """Guess the search kind for *q*.

    The query is upper-cased before being tested against ``_YJ_RE``, so the
    check ignores letter case. Returns ``"yj"`` when the pattern matches and
    ``"any"`` otherwise.
    """
    looks_like_yj = _YJ_RE.match(q.upper()) is not None
    return "yj" if looks_like_yj else "any"
def _score_expr(q_lower: str, q_like: str) -> str:
    """Return the SQL relevance-score expression used by the combined search.

    The expression is ``GREATEST(...)`` over three ``CASE`` blocks (brand
    name, generic name, YJ code). Each block yields a fixed score between
    65.0 and 100.0 when it matches and 0 when it does not.

    The returned text contains seven ``%s`` placeholders, in this order:
    brand ``=``, brand ``LIKE`` prefix, brand ``ILIKE``, generic ``=``,
    generic ``LIKE`` prefix, generic ``ILIKE``, ``yj_code =``. The caller must
    supply parameters in exactly that order.

    ``q_lower`` and ``q_like`` are not used to build the text; they are kept
    in the signature for existing call sites.
    """
    brand_case = (
        " CASE WHEN lower(brand_name) = %s THEN 100.0 "
        " WHEN lower(brand_name) LIKE %s || '%%' THEN 90.0 "
        " WHEN brand_name ILIKE %s THEN 70.0 ELSE 0 END"
    )
    generic_case = (
        " CASE WHEN lower(generic_name_jp) = %s THEN 95.0 "
        " WHEN lower(generic_name_jp) LIKE %s || '%%' THEN 85.0 "
        " WHEN generic_name_jp ILIKE %s THEN 65.0 ELSE 0 END"
    )
    yj_case = " CASE WHEN yj_code = %s THEN 100.0 ELSE 0 END"
    return "GREATEST(" + ",".join((brand_case, generic_case, yj_case)) + ")"
def _escape_like(text: str) -> str:
    """Backslash-escape LIKE/ILIKE metacharacters so *text* matches literally."""
    # Backslash must be doubled first, otherwise the escapes added for "%" and
    # "_" would themselves be escaped again.
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def search_drugs_in_db(
    query: str,
    *,
    kind: str = "auto",
    limit: int = 20,
) -> list[DrugHit]:
    """Search ``drug_master`` by brand name, generic name or YJ code.

    Drop-in replacement for the in-memory ``CorpusIndex.search``.

    Args:
        query: Free-text search term. ``None``, empty or whitespace-only
            input returns an empty list without touching the database.
        kind: One of ``"auto"``, ``"brand"``, ``"generic"``, ``"yj"``.
            ``"auto"`` is resolved through ``_detect_kind``. Any other value
            (including the ``"any"`` produced by auto-detection) runs the
            combined brand / generic / YJ search.
        limit: Maximum number of rows, passed to SQL ``LIMIT``.

    Returns:
        ``DrugHit`` objects ordered by relevance score, highest first.

    The query text is matched literally: ``%``, ``_`` and ``\\`` in *query*
    are escaped before being placed in LIKE/ILIKE patterns, so they no longer
    act as wildcards. Equality comparisons still receive the raw text.
    """
    q = (query or "").strip()
    if not q:
        return []
    if kind == "auto":
        kind = _detect_kind(q)

    # Raw values: used only with "=" comparisons.
    q_lower = q.lower()
    q_upper = q.upper()
    # Escaped values: used only inside LIKE/ILIKE patterns. PostgreSQL's
    # default LIKE escape character is the backslash, so no ESCAPE clause is
    # needed.
    contains = f"%{_escape_like(q)}%"
    lower_prefix = _escape_like(q_lower)  # the SQL appends '%' itself
    upper_stem = _escape_like(q_upper)  # the SQL appends '%' itself
    upper_prefix = f"{upper_stem}%"

    if kind == "yj":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN yj_code = %s THEN 100.0
                        WHEN yj_full LIKE %s || '%%' THEN 95.0
                        ELSE 80.0 END AS score
            FROM drug_master
            WHERE yj_code LIKE %s OR yj_full LIKE %s
            ORDER BY score DESC, yj_full ASC
            LIMIT %s
        """
        params = (q_upper, upper_stem, upper_prefix, upper_prefix, limit)
    elif kind == "brand":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN lower(brand_name) = %s THEN 100.0
                        WHEN lower(brand_name) LIKE %s || '%%' THEN 90.0
                        ELSE 70.0 END AS score
            FROM drug_master
            WHERE brand_name ILIKE %s
            ORDER BY score DESC, length(brand_name) ASC, yj_full ASC
            LIMIT %s
        """
        params = (q_lower, lower_prefix, contains, limit)
    elif kind == "generic":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN lower(generic_name_jp) = %s THEN 95.0
                        WHEN lower(generic_name_jp) LIKE %s || '%%' THEN 85.0
                        ELSE 65.0 END AS score
            FROM drug_master
            WHERE generic_name_jp ILIKE %s
            ORDER BY score DESC, length(generic_name_jp) ASC, yj_full ASC
            LIMIT %s
        """
        params = (q_lower, lower_prefix, contains, limit)
    else:  # "any": combined search over brand, generic and YJ columns
        sql = f"""
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   {_score_expr(q_lower, contains)} AS score
            FROM drug_master
            WHERE brand_name ILIKE %s OR generic_name_jp ILIKE %s
               OR yj_code LIKE %s OR yj_full LIKE %s
            ORDER BY score DESC, length(brand_name) ASC, yj_full ASC
            LIMIT %s
        """
        # Placeholder order inside _score_expr:
        #   brand =, brand LIKE, brand ILIKE,
        #   generic =, generic LIKE, generic ILIKE, yj_code =
        # followed by the WHERE clause:
        #   brand ILIKE, generic ILIKE, yj_code LIKE, yj_full LIKE
        params = (
            q_lower, lower_prefix, contains,
            q_lower, lower_prefix, contains,
            q_upper,
            contains, contains, upper_prefix, upper_prefix,
            limit,
        )

    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        rows = cur.fetchall()

    # NULL text columns become "" and a NULL score becomes 0.0.
    return [
        DrugHit(
            yj_full=r[0],
            yj_code=r[1],
            brand_name=r[2] or "",
            generic_name=r[3] or "",
            category_code=r[4] or "",
            category_name=r[5] or "",
            score=float(r[6] or 0),
        )
        for r in rows
    ]
# ---- 类别导航 ------------------------------------------------------------
def list_categories_with_counts() -> list[dict]:
    """Return the L1 / L2 category tree with a drug count per L2 entry.

    Category names come from the taxonomy loaded by ``_taxonomy()`` rather
    than from the free-text ``category_name`` column. The original note says
    that column is worded differently from drug to drug and is hard to
    aggregate. ``drug_count`` is the number of ``drug_master`` rows having
    that ``category_code``.

    L2 entries with no rows are omitted, and an L1 group only appears when it
    has at least one L2 entry. Groups are sorted by ``l1_code`` and the
    entries inside each group by ``code``.
    """
    taxonomy = _taxonomy()
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT category_code, COUNT(*) FROM drug_master "
            "WHERE category_code IS NOT NULL "
            "GROUP BY category_code"
        )
        drug_counts: dict[str, int] = dict(cur.fetchall())

    groups: dict[str, dict] = {}
    for code, node in taxonomy.items():
        n_drugs = drug_counts.get(code, 0)
        if n_drugs == 0:
            continue
        parent = node.l1_code
        if parent not in groups:
            groups[parent] = {
                "l1_code": parent,
                "l1_name": node.l1_name,
                "l2": [],
            }
        groups[parent]["l2"].append(
            {"code": code, "name": node.name, "drug_count": n_drugs}
        )

    ordered: list[dict] = []
    for parent in sorted(groups):
        group = groups[parent]
        group["l2"] = sorted(group["l2"], key=lambda item: item["code"])
        ordered.append(group)
    return ordered
def list_drugs_in_category(
    l2_code: str,
    *,
    limit_generics: int = 50,
    brands_per_generic: int = 5,
) -> dict:
    """List "generic name -> [brand names]" for one L2 category.

    The JSON shape matches what ``_corpus_tools.list_drugs_in_category``
    used to return, so the agent prompt does not need to change.

    Args:
        l2_code: Category code matched exactly against ``category_code``.
        limit_generics: Maximum number of generic-name groups to include.
        brands_per_generic: Maximum number of drugs listed per group.

    Returns:
        A dict with ``l2_code``, ``l2_name`` (empty when the code is not in
        the taxonomy) and ``generics``. A group with more drugs than
        ``brands_per_generic`` gets a trailing ``{"_more": ...}`` marker.
        When groups were cut off, ``_more_generics`` holds how many.
    """
    node = _taxonomy().get(l2_code)
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT generic_name_jp, yj_full, brand_name "
            "FROM drug_master WHERE category_code = %s "
            "ORDER BY generic_name_jp, yj_full",
            (l2_code,),
        )
        records = cur.fetchall()

    # Group by generic name, keeping the order in which names first appear.
    grouped: dict[str, list[dict]] = {}
    for generic_name, yj_full, brand_name in records:
        key = generic_name if generic_name else "(一般名不明)"
        grouped.setdefault(key, []).append(
            {"brand": brand_name or "", "yj_full": yj_full}
        )

    generics: list[dict] = []
    for generic_name, brands in list(grouped.items())[:limit_generics]:
        visible = brands[:brands_per_generic]
        hidden = len(brands) - len(visible)
        if hidden > 0:
            visible = visible + [{"_more": f"+{hidden} more brands"}]
        generics.append({"generic": generic_name, "drugs": visible})

    result = {
        "l2_code": l2_code,
        "l2_name": node.name if node else "",
        "generics": generics,
    }
    overflow = len(grouped) - limit_generics
    if overflow > 0:
        result["_more_generics"] = overflow
    return result
# ---- fact 查询 (drug_master / interaction / restriction / dosing) ----------
@dataclass(frozen=True)
class DrugMasterRow:
    """Immutable copy of one ``drug_master`` row, as built by ``drug_master_get``."""

    yj_code: str
    yj_full: str
    brand_name: str
    generic_name_jp: str
    category_code: str
    category_name: str
    regulation: str | None
    manufacturer: str | None
    # ISO date string (YYYY-MM-DD), or None.
    revision_date: str | None
def drug_master_get(yj_code: str) -> DrugMasterRow | None:
    """Fetch one ``drug_master`` row by exact ``yj_code``.

    ``revision_date`` is formatted in SQL as ``YYYY-MM-DD``. Returns ``None``
    when no row has the given code.
    """
    sql = (
        "SELECT yj_code, yj_full, brand_name, generic_name_jp, "
        " category_code, category_name, regulation, manufacturer, "
        " to_char(revision_date, 'YYYY-MM-DD') "
        "FROM drug_master WHERE yj_code = %s"
    )
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, (yj_code,))
        record = cur.fetchone()
    return DrugMasterRow(*record) if record else None
@dataclass(frozen=True)
class InteractionRow:
    """Immutable copy of one ``drug_interaction`` row.

    Field order matches the column order selected by
    ``drug_interaction_query``, which builds instances positionally.
    """

    id: str
    drug_a_yj: str
    drug_b_yj: str | None
    drug_b_class: str | None
    severity: str
    mechanism: str | None
    clinical_effect: str | None
    source_section: str
    source_drug_yj: str
def drug_interaction_query(
    drug_a_yj: str | None = None,
    drug_b_yj: str | None = None,
    *,
    severity: str | None = None,
    keyword: str | None = None,
    limit: int = 50,
) -> list[InteractionRow]:
    """Query ``drug_interaction`` with optional filters combined by AND.

    Filter rules:
        * ``drug_a_yj`` only: rows whose ``drug_a_yj`` equals it (any drug_b).
        * ``drug_b_yj`` only: rows whose ``drug_b_yj`` equals it.
        * both: the pair in either direction (A->B or B->A).
        * ``severity``: exact match.
        * ``keyword``: ILIKE ``%keyword%`` against ``drug_b_class``,
          ``mechanism`` or ``clinical_effect``.

    Returns an empty list, without querying, when no filter is given.
    Results are ordered by ``severity`` then ``drug_b_class`` (NULLs last)
    and capped at ``limit`` rows.
    """
    clauses: list[str] = []
    args: list = []

    def add(clause: str, *values) -> None:
        # Register a WHERE fragment together with its bound values.
        clauses.append(clause)
        args.extend(values)

    if drug_a_yj and drug_b_yj:
        add(
            "((drug_a_yj=%s AND drug_b_yj=%s) OR "
            "(drug_a_yj=%s AND drug_b_yj=%s))",
            drug_a_yj, drug_b_yj, drug_b_yj, drug_a_yj,
        )
    elif drug_a_yj:
        add("drug_a_yj = %s", drug_a_yj)
    elif drug_b_yj:
        add("drug_b_yj = %s", drug_b_yj)

    if severity:
        add("severity = %s", severity)

    if keyword:
        pattern = f"%{keyword}%"
        add(
            "(drug_b_class ILIKE %s OR mechanism ILIKE %s "
            " OR clinical_effect ILIKE %s)",
            pattern, pattern, pattern,
        )

    if not clauses:
        return []

    sql = "".join((
        "SELECT id, drug_a_yj, drug_b_yj, drug_b_class, severity, ",
        " mechanism, clinical_effect, source_section, source_drug_yj ",
        "FROM drug_interaction WHERE ",
        " AND ".join(clauses),
        " ORDER BY severity, drug_b_class NULLS LAST LIMIT %s",
    ))
    args.append(limit)

    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, args)
        records = cur.fetchall()
    return [InteractionRow(*record) for record in records]
@dataclass(frozen=True)
class RestrictionRow:
    """Immutable copy of one ``drug_restriction`` row.

    Field order matches the column order selected by
    ``drug_restriction_query``, which builds instances positionally.
    """

    id: str
    drug_yj: str
    condition_type: str
    condition_text: str
    condition_params: dict
    severity: str
    source_section: str
def drug_restriction_query(
    drug_yj: str | None = None,
    *,
    condition_type: str | None = None,
    severity: str | None = None,
    keyword: str | None = None,
    limit: int = 50,
) -> list[RestrictionRow]:
    """Query ``drug_restriction`` with optional filters combined by AND.

    ``drug_yj``, ``condition_type`` and ``severity`` are exact matches.
    ``keyword`` is matched as ILIKE ``%keyword%`` against ``condition_text``.
    Falsy filters are ignored. With no active filter the function returns an
    empty list without querying. Results are ordered by ``severity`` then
    ``condition_type`` and capped at ``limit`` rows.
    """
    candidates = [
        ("drug_yj = %s", drug_yj),
        ("condition_type = %s", condition_type),
        ("severity = %s", severity),
        ("condition_text ILIKE %s", f"%{keyword}%" if keyword else None),
    ]
    # Keep only the filters the caller actually supplied.
    active = [(clause, value) for clause, value in candidates if value]
    if not active:
        return []

    where_sql = " AND ".join(clause for clause, _ in active)
    args: list = [value for _, value in active]
    args.append(limit)

    sql = (
        "SELECT id, drug_yj, condition_type, condition_text, condition_params, "
        " severity, source_section "
        "FROM drug_restriction WHERE " + where_sql +
        " ORDER BY severity, condition_type LIMIT %s"
    )
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, args)
        records = cur.fetchall()
    return [RestrictionRow(*record) for record in records]
@dataclass(frozen=True)
class DosingRow:
    """Immutable copy of one ``drug_dosing`` row.

    Field order matches the column order selected by ``drug_dosing_query``,
    which builds instances positionally.
    """

    id: str
    drug_yj: str
    indication_code: str | None
    patient_segment: str
    segment_params: dict
    dose_amount: float | None
    dose_unit: str | None
    frequency: str | None
    duration: str | None
    adjustment_text: str
    source_section: str
def drug_dosing_query(
    drug_yj: str,
    *,
    patient_segment: str | None = None,
    limit: int = 30,
) -> list[DosingRow]:
    """Return ``drug_dosing`` rows for one drug.

    ``drug_yj`` is always applied as an exact-match filter.
    ``patient_segment`` adds a second exact-match filter when truthy.
    Results are ordered by ``patient_segment`` then ``indication_code``
    (NULLs last) and capped at ``limit`` rows.
    """
    conditions = ["drug_yj = %s"]
    args: list = [drug_yj]
    if patient_segment:
        conditions.append("patient_segment = %s")
        args.append(patient_segment)
    args.append(limit)

    where_sql = " AND ".join(conditions)
    sql = (
        "SELECT id, drug_yj, indication_code, patient_segment, segment_params, "
        " dose_amount, dose_unit, frequency, duration, adjustment_text, "
        " source_section "
        f"FROM drug_dosing WHERE {where_sql}"
        " ORDER BY patient_segment, indication_code NULLS LAST LIMIT %s"
    )
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, args)
        records = cur.fetchall()
    return [DosingRow(*record) for record in records]