"""SQL query interface (relational database side).

Phase 1 covers:
- `search_drugs`: substring search by brand name / generic name / YJ code
- `list_categories`: L1/L2 categories with per-L2 drug_count
- `list_drugs_in_category`: generic name → brand names

Phase 2 adds the fact queries over drug_interaction / drug_restriction /
drug_dosing (bottom of this module).
"""
from __future__ import annotations

import re
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path

from taxonomy import load_taxonomy
from db import session

# The plugin is self-contained: drug_category.md sits in the same directory
# as this module (queries.py).
_TAXONOMY_PATH = Path(__file__).resolve().parent / "drug_category.md"
_TAXONOMY_CACHE = None


def _taxonomy():
    """Return the parsed drug_category.md taxonomy, loading it on first use."""
    global _TAXONOMY_CACHE
    if _TAXONOMY_CACHE is None:
        _TAXONOMY_CACHE = load_taxonomy(_TAXONOMY_PATH)
    return _TAXONOMY_CACHE


# YJ code candidates: up to 12 alphanumerics. YJ codes begin with a numeric
# therapeutic-category prefix, so a leading digit plus at least 3 more
# alphanumerics is enough to auto-select kind="yj". Requiring the leading
# digit keeps romaji/English drug names (e.g. "aspirin") from being routed
# to a YJ-code search that can never match them.
_YJ_RE = re.compile(r"^[0-9][0-9A-Z]{3,11}$")


def _escape_like(s: str) -> str:
    """Escape LIKE/ILIKE metacharacters so that *s* is matched literally.

    PostgreSQL's default LIKE escape character is the backslash, so the
    backslash, percent and underscore characters are each prefixed with one.
    """
    return s.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


@dataclass(frozen=True)
class DrugHit:
    """One drug_master row returned by ``search_drugs_in_db``."""

    yj_full: str
    yj_code: str
    brand_name: str  # multiple brand names are separated by "/"
    generic_name: str
    category_code: str
    category_name: str
    score: float  # relevance, 0-100 (higher is better)


def _detect_kind(q: str) -> str:
    """Auto-detect the search kind: "yj" for YJ-code-like input, else "any".

    The check is case-insensitive (the query is upper-cased first), so
    "1149019f1560" is treated as a YJ code.
    """
    if _YJ_RE.match(q.upper()):
        return "yj"
    return "any"


def _score_expr(q_lower: str, q_like: str) -> str:
    """Return a Postgres expression scoring a drug_master row's relevance.

    The expression is the GREATEST of three per-column scores:
    brand_name (100 exact / 90 prefix / 70 substring), generic_name_jp
    (95 / 85 / 65) and yj_code (100 exact); a column that does not match
    contributes 0, so the overall score is 0-100.

    The arguments are deliberately NOT interpolated into the SQL (that would
    allow SQL injection); they only name what the caller must bind. The
    expression contains 7 ``%s`` placeholders, bound positionally as:
    exact lower-cased query, LIKE-escaped lower-cased query (prefix match),
    %-wrapped LIKE-escaped query (substring match) -- once for brand_name and
    once for generic_name_jp -- then the upper-cased query for yj_code.
    """
    # NB: '%%' is psycopg's escape for a literal '%' in the SQL text.
    return (
        "GREATEST("
        "  CASE WHEN lower(brand_name) = %s THEN 100.0 "
        "       WHEN lower(brand_name) LIKE %s || '%%' THEN 90.0 "
        "       WHEN brand_name ILIKE %s THEN 70.0 ELSE 0 END,"
        "  CASE WHEN lower(generic_name_jp) = %s THEN 95.0 "
        "       WHEN lower(generic_name_jp) LIKE %s || '%%' THEN 85.0 "
        "       WHEN generic_name_jp ILIKE %s THEN 65.0 ELSE 0 END,"
        "  CASE WHEN yj_code = %s THEN 100.0 ELSE 0 END"
        ")"
    )


def search_drugs_in_db(
    query: str,
    *,
    kind: str = "auto",
    limit: int = 20,
) -> list[DrugHit]:
    """Drop-in replacement for the in-memory ``CorpusIndex.search``.

    Args:
        query: Free-text search string; surrounding whitespace is ignored.
            LIKE wildcards (``%``, ``_``) in it are matched literally.
        kind: One of "auto", "brand", "generic", "yj". "auto" picks "yj" for
            YJ-code-like input and otherwise runs the combined search over
            brand name, generic name and YJ code; any unrecognized value also
            runs that combined search.
        limit: Maximum number of hits to return.

    Returns:
        Up to ``limit`` DrugHit objects ordered by relevance score (highest
        first). A blank query returns ``[]`` without touching the database.
    """
    q = (query or "").strip()
    if not q:
        return []
    if kind == "auto":
        kind = _detect_kind(q)

    # Raw values for exact (=) comparisons.
    q_lower = q.lower()
    q_upper = q.upper()
    # LIKE patterns: escape user-supplied wildcards so they match literally.
    lower_pat = _escape_like(q_lower)  # used as  <pat> || '%'  (prefix)
    upper_pat = _escape_like(q_upper)
    upper_prefix = f"{upper_pat}%"  # YJ prefix search
    q_like = f"%{_escape_like(q)}%"  # substring search (ILIKE)

    if kind == "yj":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN yj_code = %s THEN 100.0
                        WHEN yj_full LIKE %s || '%%' THEN 95.0
                        ELSE 80.0 END AS score
            FROM drug_master
            WHERE yj_code LIKE %s OR yj_full LIKE %s
            ORDER BY score DESC, yj_full ASC
            LIMIT %s
        """
        params = (q_upper, upper_pat, upper_prefix, upper_prefix, limit)
    elif kind == "brand":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN lower(brand_name) = %s THEN 100.0
                        WHEN lower(brand_name) LIKE %s || '%%' THEN 90.0
                        ELSE 70.0 END AS score
            FROM drug_master
            WHERE brand_name ILIKE %s
            ORDER BY score DESC, length(brand_name) ASC, yj_full ASC
            LIMIT %s
        """
        params = (q_lower, lower_pat, q_like, limit)
    elif kind == "generic":
        sql = """
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   CASE WHEN lower(generic_name_jp) = %s THEN 95.0
                        WHEN lower(generic_name_jp) LIKE %s || '%%' THEN 85.0
                        ELSE 65.0 END AS score
            FROM drug_master
            WHERE generic_name_jp ILIKE %s
            ORDER BY score DESC, length(generic_name_jp) ASC, yj_full ASC
            LIMIT %s
        """
        params = (q_lower, lower_pat, q_like, limit)
    else:  # any
        sql = f"""
            SELECT yj_full, yj_code, brand_name, generic_name_jp,
                   category_code, category_name,
                   {_score_expr(q_lower, q_like)} AS score
            FROM drug_master
            WHERE brand_name ILIKE %s
               OR generic_name_jp ILIKE %s
               OR yj_code LIKE %s
               OR yj_full LIKE %s
            ORDER BY score DESC, length(brand_name) ASC, yj_full ASC
            LIMIT %s
        """
        # _score_expr placeholder order: brand =, brand LIKE, brand ILIKE,
        # generic =, generic LIKE, generic ILIKE, yj_code =
        # then WHERE: brand ILIKE, generic ILIKE, yj_code LIKE, yj_full LIKE
        params = (
            q_lower, lower_pat, q_like,
            q_lower, lower_pat, q_like,
            q_upper,
            q_like, q_like, upper_prefix, upper_prefix,
            limit,
        )

    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        rows = cur.fetchall()

    return [
        DrugHit(
            yj_full=r[0],
            yj_code=r[1],
            brand_name=r[2] or "",
            generic_name=r[3] or "",
            category_code=r[4] or "",
            category_name=r[5] or "",
            score=float(r[6] or 0),
        )
        for r in rows
    ]


# ---- Category navigation -------------------------------------------------


def list_categories_with_counts() -> list[dict]:
    """Return every L1 / L2 category that has drugs, with per-L2 drug counts.

    Category names and the L1/L2 hierarchy come from drug_category.md rather
    than PMDA's free-text ``category_name`` (which is phrased differently per
    drug and hard to aggregate); ``drug_count`` is the actual number of
    drug_master rows per category_code. L2 categories with no drugs are
    omitted, and an L1 appears only if at least one of its L2s has drugs.

    Returns:
        ``[{"l1_code", "l1_name", "l2": [{"code", "name", "drug_count"}]}]``
        with L1 entries sorted by l1_code and L2 entries sorted by code.
    """
    tax = _taxonomy()
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT category_code, COUNT(*) FROM drug_master "
            "WHERE category_code IS NOT NULL "
            "GROUP BY category_code"
        )
        counts: dict[str, int] = dict(cur.fetchall())

    by_l1: dict[str, dict] = {}
    for l2_code, l2 in tax.items():
        c = counts.get(l2_code, 0)
        if c == 0:
            continue
        l1 = by_l1.setdefault(
            l2.l1_code,
            {"l1_code": l2.l1_code, "l1_name": l2.l1_name, "l2": []},
        )
        l1["l2"].append({"code": l2_code, "name": l2.name, "drug_count": c})

    # Sort L2 entries by code inside each L1, and the L1 groups by l1_code.
    for l1 in by_l1.values():
        l1["l2"].sort(key=lambda x: x["code"])
    return [by_l1[k] for k in sorted(by_l1)]


def list_drugs_in_category(
    l2_code: str,
    *,
    limit_generics: int = 50,
    brands_per_generic: int = 5,
) -> dict:
    """List the "generic name → [brand names]" tree for one L2 category.

    Args:
        l2_code: L2 category code, matched against drug_master.category_code.
        limit_generics: Maximum number of generic names to include.
        brands_per_generic: Maximum number of brands listed per generic name;
            a trailing ``{"_more": "+N more brands"}`` marker reports the rest.

    Returns:
        The same JSON shape ``_corpus_tools.list_drugs_in_category``
        previously produced, so the agent prompt stays unchanged:
        ``{"l2_code", "l2_name", "generics": [{"generic", "drugs"}]}`` plus
        ``"_more_generics"`` when generics were truncated. ``l2_name`` is ""
        when the code is not in the taxonomy.
    """
    tax = _taxonomy()
    l2 = tax.get(l2_code)
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT generic_name_jp, yj_full, brand_name "
            "FROM drug_master WHERE category_code = %s "
            "ORDER BY generic_name_jp, yj_full",
            (l2_code,),
        )
        rows = cur.fetchall()

    # Group brands under their generic name, preserving the SQL ordering.
    # Rows with an empty/NULL generic name share one placeholder group.
    by_gen: OrderedDict[str, list[dict]] = OrderedDict()
    for gen, yj_full, brand in rows:
        by_gen.setdefault(gen or "(一般名不明)", []).append(
            {"brand": brand or "", "yj_full": yj_full}
        )

    payload: list[dict] = []
    for gen in list(by_gen)[:limit_generics]:
        drugs = by_gen[gen]
        shown = drugs[:brands_per_generic]  # slicing copies, so appends below are safe
        extra = len(drugs) - len(shown)
        entry = {"generic": gen, "drugs": shown}
        if extra > 0:
            entry["drugs"].append({"_more": f"+{extra} more brands"})
        payload.append(entry)

    out = {
        "l2_code": l2_code,
        "l2_name": l2.name if l2 else "",
        "generics": payload,
    }
    if len(by_gen) > limit_generics:
        out["_more_generics"] = len(by_gen) - limit_generics
    return out


# ---- Fact queries: drug_master / interaction / restriction / dosing ------


@dataclass(frozen=True)
class DrugMasterRow:
    """One drug_master row as returned by ``drug_master_get``."""

    yj_code: str
    yj_full: str
    brand_name: str
    generic_name_jp: str
    category_code: str
    category_name: str
    regulation: str | None
    manufacturer: str | None
    revision_date: str | None  # ISO date string (YYYY-MM-DD)


def drug_master_get(yj_code: str) -> DrugMasterRow | None:
    """Fetch the drug_master row whose yj_code equals *yj_code*, or None."""
    with session() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT yj_code, yj_full, brand_name, generic_name_jp, "
            "       category_code, category_name, regulation, manufacturer, "
            "       to_char(revision_date, 'YYYY-MM-DD') "
            "FROM drug_master WHERE yj_code = %s",
            (yj_code,),
        )
        row = cur.fetchone()
    if not row:
        return None
    return DrugMasterRow(*row)


@dataclass(frozen=True)
class InteractionRow:
    """One drug_interaction row (column order matches the SELECT below)."""

    id: str
    drug_a_yj: str
    drug_b_yj: str | None
    drug_b_class: str | None
    severity: str
    mechanism: str | None
    clinical_effect: str | None
    source_section: str
    source_drug_yj: str


def drug_interaction_query(
    drug_a_yj: str | None = None,
    drug_b_yj: str | None = None,
    *,
    severity: str | None = None,
    keyword: str | None = None,
    limit: int = 50,
) -> list[InteractionRow]:
    """Query drug_interaction rows; all given filters are ANDed together.

    Filters:
        drug_a_yj alone        -> every interaction of drug A (any drug B)
        drug_a_yj + drug_b_yj  -> the pair in either direction (A→B or B→A)
        drug_b_yj alone        -> rows whose drug_b_yj matches
        severity               -> exact match on severity
        keyword                -> case-insensitive substring match (LIKE
                                  wildcards taken literally) on drug_b_class,
                                  mechanism or clinical_effect

    Returns:
        At most ``limit`` rows ordered by severity, then drug_b_class
        (NULLs last). Returns ``[]`` without querying when no filter is given.
    """
    where: list[str] = []
    params: list = []
    if drug_a_yj and drug_b_yj:
        where.append("((drug_a_yj=%s AND drug_b_yj=%s) OR "
                     "(drug_a_yj=%s AND drug_b_yj=%s))")
        params += [drug_a_yj, drug_b_yj, drug_b_yj, drug_a_yj]
    elif drug_a_yj:
        where.append("drug_a_yj = %s")
        params.append(drug_a_yj)
    elif drug_b_yj:
        where.append("drug_b_yj = %s")
        params.append(drug_b_yj)
    if severity:
        where.append("severity = %s")
        params.append(severity)
    if keyword:
        where.append("(drug_b_class ILIKE %s OR mechanism ILIKE %s "
                     " OR clinical_effect ILIKE %s)")
        kw = f"%{_escape_like(keyword)}%"
        params += [kw, kw, kw]
    if not where:
        return []

    sql = (
        "SELECT id, drug_a_yj, drug_b_yj, drug_b_class, severity, "
        "       mechanism, clinical_effect, source_section, source_drug_yj "
        "FROM drug_interaction WHERE " + " AND ".join(where) +
        " ORDER BY severity, drug_b_class NULLS LAST LIMIT %s"
    )
    params.append(limit)
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        return [InteractionRow(*r) for r in cur.fetchall()]


@dataclass(frozen=True)
class RestrictionRow:
    """One drug_restriction row (column order matches the SELECT below)."""

    id: str
    drug_yj: str
    condition_type: str
    condition_text: str
    condition_params: dict
    severity: str
    source_section: str


def drug_restriction_query(
    drug_yj: str | None = None,
    *,
    condition_type: str | None = None,
    severity: str | None = None,
    keyword: str | None = None,
    limit: int = 50,
) -> list[RestrictionRow]:
    """Query drug_restriction rows; all given filters are ANDed together.

    ``drug_yj``, ``condition_type`` and ``severity`` are exact matches;
    ``keyword`` is a case-insensitive substring match on condition_text
    (LIKE wildcards taken literally).

    Returns:
        At most ``limit`` rows ordered by severity, then condition_type.
        Returns ``[]`` without querying when no filter is given.
    """
    where: list[str] = []
    params: list = []
    if drug_yj:
        where.append("drug_yj = %s")
        params.append(drug_yj)
    if condition_type:
        where.append("condition_type = %s")
        params.append(condition_type)
    if severity:
        where.append("severity = %s")
        params.append(severity)
    if keyword:
        where.append("condition_text ILIKE %s")
        params.append(f"%{_escape_like(keyword)}%")
    if not where:
        return []

    sql = (
        "SELECT id, drug_yj, condition_type, condition_text, condition_params, "
        "       severity, source_section "
        "FROM drug_restriction WHERE " + " AND ".join(where) +
        " ORDER BY severity, condition_type LIMIT %s"
    )
    params.append(limit)
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        return [RestrictionRow(*r) for r in cur.fetchall()]


@dataclass(frozen=True)
class DosingRow:
    """One drug_dosing row (column order matches the SELECT below)."""

    id: str
    drug_yj: str
    indication_code: str | None
    patient_segment: str
    segment_params: dict
    dose_amount: float | None
    dose_unit: str | None
    frequency: str | None
    duration: str | None
    adjustment_text: str
    source_section: str


def drug_dosing_query(
    drug_yj: str,
    *,
    patient_segment: str | None = None,
    limit: int = 30,
) -> list[DosingRow]:
    """Query drug_dosing rows for *drug_yj*, optionally for one segment.

    Returns:
        At most ``limit`` rows ordered by patient_segment, then
        indication_code (NULLs last).
    """
    where = ["drug_yj = %s"]
    params: list = [drug_yj]
    if patient_segment:
        where.append("patient_segment = %s")
        params.append(patient_segment)

    sql = (
        "SELECT id, drug_yj, indication_code, patient_segment, segment_params, "
        "       dose_amount, dose_unit, frequency, duration, adjustment_text, "
        "       source_section "
        "FROM drug_dosing WHERE " + " AND ".join(where) +
        " ORDER BY patient_segment, indication_code NULLS LAST LIMIT %s"
    )
    params.append(limit)
    with session() as conn, conn.cursor() as cur:
        cur.execute(sql, params)
        return [DosingRow(*r) for r in cur.fetchall()]