428 lines
13 KiB
Python
428 lines
13 KiB
Python
"""SQL query interface (the small relational-database side).

Already served as of Phase 1:
- `search_drugs`: substring search over brand name / generic name / YJ code
- `list_categories`: L1/L2 categories with drug_count
- `list_drugs_in_category`: generic name → brand names

Phase 2 will later hook up drug_interaction / drug_restriction / drug_dosing.
"""
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from collections import OrderedDict
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from taxonomy import load_taxonomy
|
||
from db import session
|
||
|
||
# The plugin is self-contained: drug_category.md sits in the same directory
# as queries.py.
_TAXONOMY_PATH = Path(__file__).resolve().parent / "drug_category.md"
_TAXONOMY_CACHE = None


def _taxonomy():
    """Return the taxonomy, loading it via ``load_taxonomy`` on first use.

    The loaded value is memoised in the module-level ``_TAXONOMY_CACHE``;
    later calls return the cached object without calling the loader again.
    """
    global _TAXONOMY_CACHE
    if _TAXONOMY_CACHE is not None:
        return _TAXONOMY_CACHE
    _TAXONOMY_CACHE = load_taxonomy(_TAXONOMY_PATH)
    return _TAXONOMY_CACHE
|
||
|
||
# 12 字母数字 → YJ code 候选;前几位即足够触发自动 kind=yj 的判断
|
||
_YJ_RE = re.compile(r"^[0-9A-Z]{4,12}$")
|
||
|
||
|
||
@dataclass(frozen=True)
class DrugHit:
    """One immutable search result as returned by ``search_drugs_in_db``."""

    yj_full: str
    yj_code: str
    brand_name: str  # multiple brand names are joined with "/"
    generic_name: str
    category_code: str
    category_name: str
    score: float  # relevance score, nominally 50-100
|
||
|
||
|
||
def _detect_kind(q: str) -> str:
    """Guess the search kind for *q*: ``"yj"`` or ``"any"``.

    The query is upper-cased before being tested against ``_YJ_RE``, so a
    lower-case alphanumeric query of matching length is also classified as
    ``"yj"``.
    """
    looks_like_yj = bool(_YJ_RE.match(q.upper()))
    return "yj" if looks_like_yj else "any"
|
||
|
||
|
||
def _score_expr(q_lower: str, q_like: str) -> str:
|
||
"""Postgres expression returning relevance score 50–100."""
|
||
# NB: doubles each pattern; psycopg expands %s positionally so caller
|
||
# must pass q_lower / q_like in matching repetitions.
|
||
return (
|
||
"GREATEST("
|
||
" CASE WHEN lower(brand_name) = %s THEN 100.0 "
|
||
" WHEN lower(brand_name) LIKE %s || '%%' THEN 90.0 "
|
||
" WHEN brand_name ILIKE %s THEN 70.0 ELSE 0 END,"
|
||
" CASE WHEN lower(generic_name_jp) = %s THEN 95.0 "
|
||
" WHEN lower(generic_name_jp) LIKE %s || '%%' THEN 85.0 "
|
||
" WHEN generic_name_jp ILIKE %s THEN 65.0 ELSE 0 END,"
|
||
" CASE WHEN yj_code = %s THEN 100.0 ELSE 0 END"
|
||
")"
|
||
)
|
||
|
||
|
||
def _escape_like(text: str) -> str:
    """Escape LIKE metacharacters (backslash, %, _) so *text* matches literally."""
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def search_drugs_in_db(
    query: str,
    *,
    kind: str = "auto",
    limit: int = 20,
) -> list[DrugHit]:
    """Drop-in replacement for the in-memory ``CorpusIndex.search``.

    Args:
        query: Free-text search string.  Leading/trailing whitespace is
            stripped; an empty or ``None`` query returns ``[]`` without
            touching the database.
        kind: One of ``"auto"``, ``"brand"``, ``"generic"``, ``"yj"`` or
            ``"any"``.  ``"auto"`` is resolved through ``_detect_kind``; any
            unrecognised value is handled like ``"any"``.
        limit: Maximum number of rows, bound to the SQL ``LIMIT``.

    Returns:
        ``DrugHit`` list (at most ``limit``) ordered by relevance score
        descending.

    The query text is matched literally: ``%``, ``_`` and backslash inside it
    are escaped before being embedded in LIKE / ILIKE patterns, so they no
    longer behave as wildcards or escape characters.
    """
    q = (query or "").strip()
    if not q:
        return []
    if kind == "auto":
        kind = _detect_kind(q)

    q_lower = q.lower()
    q_upper = q.upper()

    # Equality comparisons receive the raw text; everything that ends up in a
    # LIKE / ILIKE pattern receives the escaped text (backslash is the default
    # LIKE escape character in PostgreSQL).
    prefix_lower = _escape_like(q_lower)       # used as: LIKE %s || '%'
    prefix_upper = _escape_like(q_upper)       # used as: LIKE %s || '%'
    contains = f"%{_escape_like(q)}%"          # substring ILIKE pattern
    yj_prefix = f"{prefix_upper}%"             # prefix LIKE pattern for YJ codes

    if kind == "yj":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN yj_code = %s THEN 100.0
                        WHEN yj_full LIKE %s || '%%' THEN 95.0
                        ELSE 80.0 END AS score
            FROM drug_master
            WHERE yj_code LIKE %s OR yj_full LIKE %s
            ORDER BY score DESC, yj_full ASC
            LIMIT %s
        """
        params = (q_upper, prefix_upper, yj_prefix, yj_prefix, limit)
    elif kind == "brand":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN lower(brand_name) = %s THEN 100.0
                        WHEN lower(brand_name) LIKE %s || '%%' THEN 90.0
                        ELSE 70.0 END AS score
            FROM drug_master
            WHERE brand_name ILIKE %s
            ORDER BY score DESC, length(brand_name) ASC, yj_full ASC
            LIMIT %s
        """
        params = (q_lower, prefix_lower, contains, limit)
    elif kind == "generic":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN lower(generic_name_jp) = %s THEN 95.0
                        WHEN lower(generic_name_jp) LIKE %s || '%%' THEN 85.0
                        ELSE 65.0 END AS score
            FROM drug_master
            WHERE generic_name_jp ILIKE %s
            ORDER BY score DESC, length(generic_name_jp) ASC, yj_full ASC
            LIMIT %s
        """
        params = (q_lower, prefix_lower, contains, limit)
    else:  # "any" (also the fallback for unrecognised kinds)
        sql = f"""
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   {_score_expr(q_lower, contains)} AS score
            FROM drug_master
            WHERE brand_name ILIKE %s OR generic_name_jp ILIKE %s
               OR yj_code LIKE %s OR yj_full LIKE %s
            ORDER BY score DESC, length(brand_name) ASC, yj_full ASC
            LIMIT %s
        """
        # Placeholder order inside _score_expr: brand =, brand LIKE,
        # brand ILIKE, generic =, generic LIKE, generic ILIKE, yj_code =.
        # Then the WHERE clause: brand ILIKE, generic ILIKE, yj_code LIKE,
        # yj_full LIKE.  Finally LIMIT.
        params = (
            q_lower, prefix_lower, contains,
            q_lower, prefix_lower, contains,
            q_upper,
            contains, contains, yj_prefix, yj_prefix,
            limit,
        )

    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        rows = cur.fetchall()

    # NULL text columns become "" and a NULL score becomes 0.0.
    return [
        DrugHit(
            yj_full=r[0],
            yj_code=r[1],
            brand_name=r[2] or "",
            generic_name=r[3] or "",
            category_code=r[4] or "",
            category_name=r[5] or "",
            score=float(r[6] or 0),
        )
        for r in rows
    ]
|
||
|
||
|
||
# ---- 类别导航 ------------------------------------------------------------
|
||
|
||
|
||
def list_categories_with_counts() -> list[dict]:
    """List every L1 / L2 category together with the drug count of each L2.

    Category names come from drug_category.md (rather than PMDA's free-text
    ``category_name``, which is worded per drug and hard to aggregate);
    ``drug_count`` is the actual number of ``drug_master`` rows.  L2 entries
    with no rows are omitted.  The result is sorted by ``l1_code``, and each
    ``"l2"`` list by ``code``.
    """
    taxonomy = _taxonomy()
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT category_code, COUNT(*) FROM drug_master "
            "WHERE category_code IS NOT NULL "
            "GROUP BY category_code"
        )
        per_code: dict[str, int] = {code: n for code, n in cur.fetchall()}

    groups: dict[str, dict] = {}
    for code, node in taxonomy.items():
        n_drugs = per_code.get(code, 0)
        if n_drugs == 0:
            continue
        if node.l1_code not in groups:
            groups[node.l1_code] = {
                "l1_code": node.l1_code,
                "l1_name": node.l1_name,
                "l2": [],
            }
        groups[node.l1_code]["l2"].append(
            {"code": code, "name": node.name, "drug_count": n_drugs}
        )

    # Inner lists are ordered by L2 code, the outer list by L1 code.
    ordered = []
    for l1_code in sorted(groups):
        group = groups[l1_code]
        group["l2"].sort(key=lambda item: item["code"])
        ordered.append(group)
    return ordered
|
||
|
||
|
||
def list_drugs_in_category(
    l2_code: str,
    *,
    limit_generics: int = 50,
    brands_per_generic: int = 5,
) -> dict:
    """Return the "generic name → [brand names]" listing for one L2 category.

    The JSON shape is the one ``_corpus_tools.list_drugs_in_category`` used to
    yield, so the agent prompt does not need to change.

    At most ``limit_generics`` generics are returned (the number left out is
    reported under ``"_more_generics"``), each with at most
    ``brands_per_generic`` drugs followed by a ``{"_more": ...}`` marker when
    some were cut.  ``"l2_name"`` is ``""`` when the code is not in the
    taxonomy.
    """
    taxonomy = _taxonomy()
    node = taxonomy.get(l2_code)
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT generic_name_jp, yj_full, brand_name "
            "FROM drug_master WHERE category_code = %s "
            "ORDER BY generic_name_jp, yj_full",
            (l2_code,),
        )
        rows = cur.fetchall()

    # Group rows by generic name, keeping first-seen order; a missing generic
    # name is collected under a placeholder label.
    grouped: "OrderedDict[str, list[dict]]" = OrderedDict()
    for generic, yj_full, brand in rows:
        label = generic if generic else "(一般名不明)"
        if label not in grouped:
            grouped[label] = []
        grouped[label].append({"brand": brand if brand else "", "yj_full": yj_full})

    generics: list[dict] = []
    for label in list(grouped)[:limit_generics]:
        members = grouped[label]
        visible = list(members[:brands_per_generic])
        hidden = len(members) - len(visible)
        if hidden > 0:
            visible.append({"_more": f"+{hidden} more brands"})
        generics.append({"generic": label, "drugs": visible})

    result = {
        "l2_code": l2_code,
        "l2_name": node.name if node else "",
        "generics": generics,
    }
    omitted = len(grouped) - limit_generics
    if omitted > 0:
        result["_more_generics"] = omitted
    return result
|
||
|
||
|
||
# ---- fact 查询:drug_master / interaction / restriction / dosing ----------
|
||
|
||
|
||
@dataclass(frozen=True)
class DrugMasterRow:
    """Immutable record built from one ``drug_master`` row.

    Field order matches the SELECT column order in ``drug_master_get``, which
    constructs instances positionally.
    """

    yj_code: str
    yj_full: str
    brand_name: str
    generic_name_jp: str
    category_code: str
    category_name: str
    regulation: str | None
    manufacturer: str | None
    revision_date: str | None  # ISO date string (YYYY-MM-DD)
|
||
|
||
|
||
def drug_master_get(yj_code: str) -> DrugMasterRow | None:
    """Fetch the ``drug_master`` row whose ``yj_code`` equals *yj_code*.

    Returns ``None`` when no row matches.  ``revision_date`` is formatted by
    the database as ``YYYY-MM-DD``.
    """
    query = (
        "SELECT yj_code, yj_full, brand_name, generic_name_jp, "
        " category_code, category_name, regulation, manufacturer, "
        " to_char(revision_date, 'YYYY-MM-DD') "
        "FROM drug_master WHERE yj_code = %s"
    )
    with session() as conn, conn.cursor() as cur:
        cur.execute(query, (yj_code,))
        record = cur.fetchone()
    return DrugMasterRow(*record) if record else None
|
||
|
||
|
||
@dataclass(frozen=True)
class InteractionRow:
    """Immutable record built from one ``drug_interaction`` row.

    Field order matches the SELECT column order in ``drug_interaction_query``,
    which constructs instances positionally.
    """

    id: str
    drug_a_yj: str
    drug_b_yj: str | None
    drug_b_class: str | None
    severity: str
    mechanism: str | None
    clinical_effect: str | None
    source_section: str
    source_drug_yj: str
|
||
|
||
|
||
def drug_interaction_query(
    drug_a_yj: str | None = None,
    drug_b_yj: str | None = None,
    *,
    severity: str | None = None,
    keyword: str | None = None,
    limit: int = 50,
) -> list[InteractionRow]:
    """Query ``drug_interaction`` rows.

    Search conditions (all supplied filters are AND-ed):
      * ``drug_a_yj`` alone: every interaction of drug A (any drug B).
      * ``drug_a_yj`` + ``drug_b_yj``: both directions (A→B or B→A).
      * ``drug_b_yj`` alone: rows whose ``drug_b_yj`` equals it.
      * ``severity``: exact match.
      * ``keyword``: case-insensitive substring match over ``drug_b_class``,
        ``mechanism`` and ``clinical_effect``.  The keyword is matched
        literally; ``%``, ``_`` and backslash in it are escaped.

    Returns ``[]`` without touching the database when no filter is given.
    Rows are ordered by ``severity``, then ``drug_b_class`` (NULLs last), and
    capped at ``limit``.
    """
    where: list[str] = []
    params: list = []
    if drug_a_yj and drug_b_yj:
        where.append("((drug_a_yj=%s AND drug_b_yj=%s) OR "
                     "(drug_a_yj=%s AND drug_b_yj=%s))")
        params += [drug_a_yj, drug_b_yj, drug_b_yj, drug_a_yj]
    elif drug_a_yj:
        where.append("drug_a_yj = %s")
        params.append(drug_a_yj)
    elif drug_b_yj:
        where.append("drug_b_yj = %s")
        params.append(drug_b_yj)
    if severity:
        where.append("severity = %s")
        params.append(severity)
    if keyword:
        where.append("(drug_b_class ILIKE %s OR mechanism ILIKE %s "
                     " OR clinical_effect ILIKE %s)")
        # Escape LIKE metacharacters (backslash first) so the keyword is a
        # literal substring rather than a pattern.
        literal = (
            keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        kw = f"%{literal}%"
        params += [kw, kw, kw]
    if not where:
        # Refuse an unfiltered scan of the whole table.
        return []
    sql = (
        "SELECT id, drug_a_yj, drug_b_yj, drug_b_class, severity, "
        " mechanism, clinical_effect, source_section, source_drug_yj "
        "FROM drug_interaction WHERE " + " AND ".join(where) +
        " ORDER BY severity, drug_b_class NULLS LAST LIMIT %s"
    )
    params.append(limit)
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        return [InteractionRow(*r) for r in cur.fetchall()]
|
||
|
||
|
||
@dataclass(frozen=True)
class RestrictionRow:
    """Immutable record built from one ``drug_restriction`` row.

    Field order matches the SELECT column order in ``drug_restriction_query``,
    which constructs instances positionally.
    """

    id: str
    drug_yj: str
    condition_type: str
    condition_text: str
    condition_params: dict
    severity: str
    source_section: str
|
||
|
||
|
||
def drug_restriction_query(
    drug_yj: str | None = None,
    *,
    condition_type: str | None = None,
    severity: str | None = None,
    keyword: str | None = None,
    limit: int = 50,
) -> list[RestrictionRow]:
    """Query ``drug_restriction`` rows.

    All supplied filters are AND-ed: ``drug_yj``, ``condition_type`` and
    ``severity`` are exact matches; ``keyword`` is a case-insensitive
    substring match on ``condition_text``.  The keyword is matched literally;
    ``%``, ``_`` and backslash in it are escaped.

    Returns ``[]`` without touching the database when no filter is given.
    Rows are ordered by ``severity`` then ``condition_type`` and capped at
    ``limit``.
    """
    where: list[str] = []
    params: list = []
    if drug_yj:
        where.append("drug_yj = %s")
        params.append(drug_yj)
    if condition_type:
        where.append("condition_type = %s")
        params.append(condition_type)
    if severity:
        where.append("severity = %s")
        params.append(severity)
    if keyword:
        where.append("condition_text ILIKE %s")
        # Escape LIKE metacharacters (backslash first) so the keyword is a
        # literal substring rather than a pattern.
        literal = (
            keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        params.append(f"%{literal}%")
    if not where:
        # Refuse an unfiltered scan of the whole table.
        return []
    sql = (
        "SELECT id, drug_yj, condition_type, condition_text, condition_params, "
        " severity, source_section "
        "FROM drug_restriction WHERE " + " AND ".join(where) +
        " ORDER BY severity, condition_type LIMIT %s"
    )
    params.append(limit)
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        return [RestrictionRow(*r) for r in cur.fetchall()]
|
||
|
||
|
||
@dataclass(frozen=True)
class DosingRow:
    """Immutable record built from one ``drug_dosing`` row.

    Field order matches the SELECT column order in ``drug_dosing_query``,
    which constructs instances positionally.
    """

    id: str
    drug_yj: str
    indication_code: str | None
    patient_segment: str
    segment_params: dict
    dose_amount: float | None
    dose_unit: str | None
    frequency: str | None
    duration: str | None
    adjustment_text: str
    source_section: str
|
||
|
||
|
||
def drug_dosing_query(
    drug_yj: str,
    *,
    patient_segment: str | None = None,
    limit: int = 30,
) -> list[DosingRow]:
    """Query ``drug_dosing`` rows for one drug.

    ``drug_yj`` is always applied as an exact-match filter;
    ``patient_segment`` adds a second exact-match filter when given.  Rows are
    ordered by ``patient_segment`` then ``indication_code`` (NULLs last) and
    capped at ``limit``.
    """
    conditions = ["drug_yj = %s"]
    args: list = [drug_yj]
    if patient_segment:
        conditions.append("patient_segment = %s")
        args.append(patient_segment)

    select_clause = (
        "SELECT id, drug_yj, indication_code, patient_segment, segment_params, "
        " dose_amount, dose_unit, frequency, duration, adjustment_text, "
        " source_section "
        "FROM drug_dosing WHERE "
    )
    order_clause = " ORDER BY patient_segment, indication_code NULLS LAST LIMIT %s"
    sql = select_clause + " AND ".join(conditions) + order_clause
    args.append(limit)

    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, args)
        fetched = cur.fetchall()
        return [DosingRow(*record) for record in fetched]
|