qwen_agent/skills/support/kfs-answer/scripts/query.py
2026-05-06 19:39:53 +08:00

518 lines
19 KiB
Python

"""Budget-aware auto query for knowledge files.
Usage: python3 query.py <file_id1:sheet_id1>,... <question> <kw1> <kw2> ...
Keywords are separate positional arguments (not comma-separated).
For db-type sheets: keyword SQL with budget control (COUNT → sample → select columns → LIMIT).
For markdown-type sheets: keyword section matching within budget.
Output: TSV (or markdown section) on stdout. The `__src` column is consumed
internally and stripped from visible output; the citations derived from it are
appended as JSON lines (file, filename, sheet, rows, source) to `citations.jsonl`
in the session directory instead of being printed as a `[CITATIONS]` block.
datasets directory: ./datasets/ (gbase-agent-service) or ./dataset/ (catalog-agent), auto-detected at runtime.
dataset_ids are discovered automatically from subdirectories under datasets directory.
"""
import os
import re
import sqlite3
import sys
import yaml
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from _session import get_session_dir
# The project root is three directory levels above the folder holding this script
# (scripts/ → kfs-answer/ → skills/ → project root).
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..", "..")
)
# Prefer <root>/datasets when it exists; otherwise fall back to <root>/dataset.
_ds = os.path.join(_PROJECT_ROOT, "datasets")
if os.path.isdir(_ds):
    DATASETS_DIR = _ds
else:
    DATASETS_DIR = os.path.join(_PROJECT_ROOT, "dataset")
def _discover_datasets():
    """Return the sorted sub-directory names of DATASETS_DIR (one per dataset_id).

    Returns an empty list when DATASETS_DIR is not an existing directory.
    """
    if not os.path.isdir(DATASETS_DIR):
        return []
    dataset_ids = []
    for entry in sorted(os.listdir(DATASETS_DIR)):
        # Plain files sitting next to the dataset folders are ignored.
        if os.path.isdir(os.path.join(DATASETS_DIR, entry)):
            dataset_ids.append(entry)
    return dataset_ids
def load_file_ref_map():
    """Read file_refs.txt from the session dir into {file_id: (F{n}, filename)}.

    Each usable line looks like ``F1=<file_id>(<filename>)``; lines that do not
    match are skipped. Returns an empty dict when the file does not exist.
    """
    refs_path = os.path.join(get_session_dir(), "file_refs.txt")
    if not os.path.isfile(refs_path):
        return {}
    line_re = re.compile(r"^(F\d+)=([0-9a-f-]+)\((.+?)\)\s*$")
    with open(refs_path, "r", encoding="utf-8") as fh:
        hits = (line_re.match(raw.strip()) for raw in fh)
        # A file_id that appears more than once keeps its last entry.
        return {hit.group(2): (hit.group(1), hit.group(3)) for hit in hits if hit}
# Row-level source marker. The cell value is stored as `__src="F0S1R5"`
# (xls-agent-parse wraps it), so this pattern is deliberately unanchored and is
# used with re.search() to extract the (file, sheet, row) numbers.
SRC_ROW_PAT = re.compile(r"F(\d+)S(\d+)R(\d+)")
# Sheet-level source marker. It comes from extract_sheet_src(), which returns
# the bare value (e.g. `F0S1`), so this pattern stays anchored at both ends.
SRC_SHEET_PAT = re.compile(r"^F(\d+)S(\d+)$")
def _format_citation(file_id, sheet_num, filename, row_nums=None):
    """Build one ``<CITATION ... />`` tag.

    Args:
        file_id: Identifier written verbatim into the ``file`` attribute.
        sheet_num: Value of the ``sheet`` attribute.
        filename: Value of the ``filename`` attribute; the attribute is omitted
            when this is empty.
        row_nums: Iterable of row numbers for a row-level citation, or ``None``
            for a sheet-level citation (no ``rows`` attribute at all).

    Returns:
        The tag as a string.
    """
    fn_attr = f' filename="{filename}"' if filename else ""
    if row_nums is None:
        return f'<CITATION file="{file_id}"{fn_attr} sheet="{sheet_num}" />'
    rows_str = "[" + ", ".join(str(r) for r in row_nums) + "]"
    # The row-level branch used to reference an undefined `file_code` name and
    # raised NameError; both branches now use the `file_id` parameter.
    return f'<CITATION file="{file_id}"{fn_attr} sheet="{sheet_num}" rows="{rows_str}" />'
def replace_f0(text, f_code):
    """Swap the placeholder file code F0 for *f_code* (e.g. F1) in __src values.

    Every ``F0S`` occurrence in *text* becomes ``{f_code}S``. The text is
    returned unchanged when *f_code* is empty or already ``"F0"``.
    """
    if f_code and f_code != "F0":
        return text.replace('F0S', f'{f_code}S')
    return text
def extract_sheet_src(body, sheet_id):
    """Return the __src value of the ``<!-- sheet_xxx __src="F0S1" -->`` marker.

    Only the marker belonging to *sheet_id* is considered. Returns an empty
    string when *body* has no such marker.
    """
    marker_re = re.compile(rf'<!--\s*{re.escape(sheet_id)}\s+__src="([^"]*)"')
    found = marker_re.search(body)
    if found is None:
        return ""
    return found.group(1)
def calc_budget(entry_count):
    """Return the per-entry character budget for a query with *entry_count* entries.

    One entry (or fewer) gets 3000 characters, two entries get 1800 each, and
    three or more get 1200 each.
    """
    if entry_count <= 1:
        return 3000
    return 1800 if entry_count == 2 else 1200
def find_file_dir(dataset_ids, file_id):
    """Return the first existing ``DATASETS_DIR/<dataset_id>/<file_id>`` directory.

    Datasets are probed in the order given; returns None when no dataset
    contains a directory named *file_id*.
    """
    candidates = (
        os.path.join(DATASETS_DIR, ds_id, file_id) for ds_id in dataset_ids
    )
    return next((path for path in candidates if os.path.isdir(path)), None)
def load_knowledge_meta(file_dir):
    """Load ``knowledge.md`` from *file_dir* and split it into (meta, body).

    The file must start with a ``---`` delimited YAML front-matter block.

    Returns:
        ``(meta, body)`` where *meta* is the parsed front-matter mapping and
        *body* is the remaining text, stripped. ``(None, None)`` is returned
        when the file is missing, has no complete front-matter block, or the
        front matter is not valid YAML / not a mapping. The caller reports
        that case as an invalid knowledge.md.
    """
    km_path = os.path.join(file_dir, "knowledge.md")
    if not os.path.isfile(km_path):
        return None, None
    with open(km_path, "r", encoding="utf-8") as f:
        content = f.read()
    if not content.startswith("---"):
        return None, None
    parts = content.split("---", 2)
    if len(parts) < 3:
        return None, None
    try:
        meta = yaml.safe_load(parts[1])
    except yaml.YAMLError:
        # Malformed front matter is an invalid file, not a reason to crash.
        return None, None
    if not isinstance(meta, dict):
        # Empty or scalar/list front matter cannot be queried with .get().
        return None, None
    return meta, parts[2].strip()
# Suffix for every LIKE term: the parameter is a pattern built by _like_literal,
# which uses a backslash as the escape character.
_LIKE_OP = " LIKE ? ESCAPE '\\'"


def _quote_ident(name):
    """Quote *name* as an SQLite identifier, doubling any embedded double quote."""
    return '"' + str(name).replace('"', '""') + '"'


def _like_literal(keyword):
    """Return a LIKE pattern that matches *keyword* as a literal substring."""
    escaped = (
        str(keyword)
        .replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )
    return f"%{escaped}%"


def query_db_sheet(db_path, table_name, columns, keywords, budget):
    """Budget-aware SQLite query: COUNT → sample → select columns → LIMIT.

    Args:
        db_path: Path of the SQLite database file.
        table_name: Table to query.
        columns: Column metadata from knowledge.md. Unused: the actual column
            list is read from the database via PRAGMA table_info.
        keywords: Search keywords; falsy entries are ignored. With no usable
            keyword every row of the table is eligible.
        budget: Approximate character budget for the rendered result.

    Returns:
        ``{"error": msg}`` on failure; ``{"rows": [], "total": 0, ...}`` when
        nothing matches; otherwise a dict with keys ``table``, ``columns``,
        ``all_columns``, ``rows``, ``count``, ``total``, ``fields_reduced``,
        ``rows_limited``, ``keyword`` and ``db_path``.

    Keywords are matched as literal substrings (``%`` and ``_`` are escaped),
    and table/column names are quoted with embedded quotes doubled, so names
    containing ``"`` no longer break the generated SQL.
    """
    if not os.path.isfile(db_path):
        return {"error": f"DB not found: {db_path}"}

    table_sql = _quote_ident(table_name)
    valid_kws = [kw for kw in keywords if kw]

    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row

        # Get actual table columns from DB
        try:
            cursor = conn.execute(f"PRAGMA table_info({table_sql})")
            db_cols = [r["name"] for r in cursor.fetchall()]
        except sqlite3.Error:
            return {"error": f"Table not found: {table_name}"}
        if not db_cols:
            return {"error": f"Table has no columns: {table_name}"}

        # "keyword appears in any column": one LIKE term per column, OR-ed.
        any_col_like = " OR ".join(_quote_ident(c) + _LIKE_OP for c in db_cols)

        # Pick the MOST SELECTIVE keyword: the one with the smallest COUNT > 0.
        where_clause = None
        where_params = None
        used_kw = None
        best_count = float("inf")
        for kw in valid_kws:
            params = [_like_literal(kw)] * len(db_cols)
            try:
                cnt = conn.execute(
                    f"SELECT COUNT(*) FROM {table_sql} WHERE {any_col_like}", params
                ).fetchone()[0]
            except sqlite3.Error:
                continue
            if 0 < cnt < best_count:
                best_count = cnt
                used_kw = kw
                where_clause, where_params = any_col_like, params

        # Fallback: all keywords OR-ed together (if no single keyword matched).
        if where_clause is None and valid_kws:
            wc = " OR ".join([any_col_like] * len(valid_kws))
            params = []
            for kw in valid_kws:
                params.extend([_like_literal(kw)] * len(db_cols))
            try:
                cnt = conn.execute(
                    f"SELECT COUNT(*) FROM {table_sql} WHERE {wc}", params
                ).fetchone()[0]
                if cnt > 0:
                    where_clause, where_params = wc, params
            except sqlite3.Error:
                pass

        if where_clause is None:
            if valid_kws:
                return {"rows": [], "total": 0, "note": "No keyword match"}
            # No keywords at all → every row is eligible.
            where_clause, where_params = "1=1", []

        # COUNT total
        try:
            total_rows = conn.execute(
                f"SELECT COUNT(*) FROM {table_sql} WHERE {where_clause}", where_params
            ).fetchone()[0]
        except sqlite3.Error:
            return {"error": "COUNT failed"}
        if total_rows == 0:
            return {"rows": [], "total": 0}

        # Sample 5 rows to estimate column widths
        try:
            cursor = conn.execute(
                f"SELECT * FROM {table_sql} WHERE {where_clause} LIMIT 5", where_params
            )
            sample_rows = [dict(r) for r in cursor.fetchall()]
        except sqlite3.Error:
            return {"error": "Sample query failed"}

        sample_size = max(len(sample_rows), 1)
        col_avg_chars = {
            col: sum(len(str(row.get(col, "") or "")) for row in sample_rows) / sample_size
            for col in db_cols
        }

        def row_width(cols):
            # Estimated characters per row: average cell widths plus one separator each.
            return sum(col_avg_chars.get(c, 10) for c in cols) + len(cols)

        # Budget check: decide columns and LIMIT
        header_overhead = len("\t".join(db_cols)) + 50
        avg_row_chars = sum(col_avg_chars.values()) + len(db_cols)
        estimated_total = header_overhead + total_rows * avg_row_chars

        if estimated_total <= budget:
            select_cols = db_cols
            limit = total_rows
        else:
            # __src column always retained (source marker for citation).
            # Keyword-hit columns come next, then others while budget allows.
            must_cols = [c for c in db_cols if c == "__src"]
            keyword_cols = []
            other_cols = []
            for col in db_cols:
                if col == "__src":
                    continue
                is_kw_col = any(
                    kw.lower() in str(row.get(col, "") or "").lower()
                    for kw in valid_kws
                    for row in sample_rows
                )
                (keyword_cols if is_kw_col else other_cols).append(col)

            select_cols = must_cols + keyword_cols
            for col in other_cols:
                # Admit an extra column only if five rows of it still fit.
                if header_overhead + 5 * row_width(select_cols + [col]) <= budget:
                    select_cols.append(col)
            if not select_cols:
                select_cols = db_cols

            rw = row_width(select_cols)
            available = budget - header_overhead
            # Always return at least one row, even when the header alone
            # exceeds the budget.
            limit = max(1, int(available / max(rw, 1)))
            limit = min(limit, total_rows)

        # Execute
        cols_str = ", ".join(_quote_ident(c) for c in select_cols)
        sql = f"SELECT {cols_str} FROM {table_sql} WHERE {where_clause} LIMIT {limit}"
        try:
            cursor = conn.execute(sql, where_params)
            rows = [dict(r) for r in cursor.fetchall()]
        except sqlite3.Error:
            return {"error": "Query execution failed"}
    finally:
        # Single exit point for the connection, whatever path returned.
        conn.close()

    return {
        "table": table_name,
        "columns": select_cols,
        "all_columns": db_cols,
        "rows": rows,
        "count": len(rows),
        "total": total_rows,
        "fields_reduced": len(select_cols) < len(db_cols),
        "rows_limited": limit < total_rows,
        "keyword": used_kw,
        "db_path": db_path,
    }
def query_markdown_sheet(body, sheet_id, keywords, budget):
    """Return the part of a markdown sheet that fits within *budget* characters.

    *body* is split on ``<!-- sheet_xxx ... -->`` markers and the section
    following the marker for *sheet_id* is used; a body without any marker is
    used as a whole. A section that fits the budget is returned in full.
    Otherwise only lines containing a keyword (case-insensitive), plus three
    lines of context on each side, are kept, with ``---`` before each
    non-contiguous run. When no keyword matches, the head of the section is
    returned, cut at a line break where possible.
    """
    chunks = re.split(r"<!--\s*sheet_\w+(?:\s+[^>]*)?\s*-->", body)
    sheet_ids = re.findall(r"<!--\s*(sheet_\w+)(?:\s+[^>]*)?\s*-->", body)

    # chunks[i + 1] is the text that follows the i-th marker.
    section = next(
        (chunk.strip() for marker, chunk in zip(sheet_ids, chunks[1:]) if marker == sheet_id),
        "",
    )
    if not section and not sheet_ids and len(chunks) == 1:
        section = body.strip()
    if not section:
        return {"content": "", "note": "No content found"}
    if len(section) <= budget:
        return {"content": section, "full": True}

    # Keyword-based line matching with context
    lines = section.split("\n")
    radius = 3
    needles = [kw.lower() for kw in keywords if kw]
    hits = set()
    for pos, text in enumerate(lines):
        lowered = text.lower()
        if any(needle in lowered for needle in needles):
            hits.update(range(max(0, pos - radius), min(len(lines), pos + radius + 1)))

    if not hits:
        head = section[:budget]
        cut = head.rfind("\n")
        # Trim back to a line break only if that keeps most of the budget.
        if cut > budget * 0.7:
            head = head[:cut]
        return {"content": head, "note": f"[No keyword match. First {len(head)} chars of {len(section)}]"}

    picked = []
    used = 0
    last = -2
    for pos in sorted(hits):
        text = lines[pos]
        cost = len(text) + 1
        if used + cost > budget:
            break
        if pos > last + 1:
            # Gap since the previously emitted line: start a new run.
            picked.append("---")
            used += 4
        picked.append(text)
        used += cost
        last = pos
    return {"content": "\n".join(picked), "matched_lines": len(picked)}
def _append_citations(fid, filename, sheet_rows):
    """Append one JSON line per (sheet_num, row_nums) pair to the session's citations.jsonl."""
    import json  # function-scope import: json is not among the file's top-level imports

    citations_path = os.path.join(get_session_dir(), "citations.jsonl")
    with open(citations_path, "a", encoding="utf-8") as cf:
        for sheet_num, row_nums in sheet_rows:
            cf.write(json.dumps({
                "file": fid, "filename": filename,
                "sheet": sheet_num, "rows": row_nums,
                "source": "query",
            }, ensure_ascii=False) + "\n")


def _emit_db_sheet(fid, sid, sheet, file_dir, keywords, budget, f_ref_map):
    """Query one db-type sheet, print it as TSV and record its row-level citations."""
    db_path = os.path.join(file_dir, "knowledge.db")
    db_table = sheet.get("db_table", sid)
    print(f"db: {db_path}, table: {db_table}")
    result = query_db_sheet(db_path, db_table, sheet.get("columns", []),
                            keywords, budget)
    if "error" in result:
        print(f" ERROR: {result['error']}")
        return
    if not result.get("rows"):
        print(f" No matching rows (total: {result.get('total', 0)})")
        return

    budget_info = []
    if result.get("fields_reduced"):
        budget_info.append(f"fields: {len(result['columns'])}/{len(result['all_columns'])}")
    if result.get("rows_limited"):
        budget_info.append(f"rows: {result['count']}/{result['total']}")
    budget_str = f" [BUDGET: {', '.join(budget_info)}]" if budget_info else ""

    # Hide __src from both the COLUMNS report and the TSV display: it is
    # consumed into citation records below.
    display_cols = [c for c in result["columns"] if c != "__src"]
    biz_all_cols = [c for c in result["all_columns"] if c != "__src"]
    omitted = [c for c in biz_all_cols if c not in display_cols]
    print(f" TABLE: {result['table']} ({result['count']}/{result['total']} rows){budget_str}")
    print(f" COLUMNS: {', '.join(display_cols)}")
    if omitted:
        print(f" OMITTED: {len(omitted)} columns")
    if result.get("keyword"):
        print(f" KEYWORD: {result['keyword']}")
    # Warn if keyword filtering returned suspiciously few rows
    if result.get("keyword") and result["count"] <= 3 and result["total"] >= 10:
        print(f" ⚠ NOTE: keyword \"{result['keyword']}\" matched only {result['count']}/{result['total']} rows. Results may be incomplete — consider removing this keyword.")
    print("-" * 40)

    # TSV output: __src is stripped and collected into src_groups instead.
    f_code, filename = f_ref_map.get(fid, ("", ""))
    has_src = "__src" in result["columns"]
    print("\t".join(display_cols))
    src_groups = {}  # (file_code, sheet_num) -> set of row_nums
    for row in result["rows"]:
        raw_src = row.get("__src") if has_src else None
        if raw_src:
            src_val = str(raw_src)
            if f_code:
                src_val = replace_f0(src_val, f_code)
            m = SRC_ROW_PAT.search(src_val)
            if m:
                key = (f"F{m.group(1)}", int(m.group(2)))
                src_groups.setdefault(key, set()).add(int(m.group(3)))
        vals = []
        for c in display_cols:
            v = row.get(c)
            s = "" if v is None else str(v)
            if len(s) > 200:
                # Keep the TSV compact: long cells are cut at 200 characters.
                s = s[:200] + "..."
            vals.append(s)
        print("\t".join(vals))

    if src_groups:
        _append_citations(
            fid, filename,
            [(key[1], sorted(src_groups[key])) for key in sorted(src_groups)],
        )


def _emit_markdown_sheet(fid, sid, body, keywords, budget, f_ref_map):
    """Print the budgeted excerpt of one markdown sheet and record its sheet-level citation."""
    if not body:
        print(" ERROR: no body content")
        return
    src_tag = extract_sheet_src(body, sid)
    f_code, filename = f_ref_map.get(fid, ("", ""))
    if src_tag and f_code:
        src_tag = replace_f0(src_tag, f_code)
    result = query_markdown_sheet(body, sid, keywords, budget)
    if result.get("note"):
        print(f" {result['note']}")
    print("-" * 40)
    print(result.get("content", ""))
    # Write sheet-level citation to session file (no stdout).
    m = SRC_SHEET_PAT.match(src_tag) if src_tag else None
    if m:
        _append_citations(fid, filename, [(int(m.group(2)), [])])


def main():
    """Command-line entry point.

    Usage: python3 query.py <file_id1:sheet_id1>,... <question> <kw1> <kw2> ...

    argv[1] is a comma-separated list of ``file_id`` or ``file_id:sheet_id``
    entries, argv[2] is the question (accepted but not used by this script),
    and all remaining arguments are keywords. Results are printed to stdout
    and citations are appended to citations.jsonl in the session directory.
    Exits with status 2 when argv[1] is missing.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 query.py <file_id1:sheet_id1>,... <question> <kw1> <kw2> ...",
              file=sys.stderr)
        sys.exit(2)

    # Auto-discover datasets from ./dataset/ or ./datasets/ subdirectories
    dataset_ids = _discover_datasets()
    raw_entries = [e.strip() for e in sys.argv[1].split(",") if e.strip()]
    keywords = sys.argv[3:]  # remaining positional args are keywords

    entries = []
    for entry in raw_entries:
        if ":" in entry:
            fid, sid = entry.split(":", 1)
            entries.append((fid.strip(), sid.strip()))
        else:
            entries.append((entry.strip(), None))

    per_entry_budget = calc_budget(len(entries))
    print(f"[Budget: {per_entry_budget} chars/entry, {len(entries)} entries]")

    # Load F0→F{n} mapping from search.py: file_id → (f_code, filename)
    f_ref_map = load_file_ref_map()
    rule = "=" * 60

    for fid, target_sheet_id in entries:
        file_dir = find_file_dir(dataset_ids, fid)
        if not file_dir:
            print(f"\n{rule}")
            print(f"file_id: {fid}")
            print(" ERROR: not found")
            continue
        meta, body = load_knowledge_meta(file_dir)
        if not meta:
            print(f"\n{rule}")
            print(f"file_id: {fid}")
            print(" ERROR: knowledge.md invalid")
            continue

        source_name = meta.get("source_name", "unknown")
        # `sheets:` may be present but null in the front matter; treat as empty.
        sheets_meta = {s["id"]: s for s in (meta.get("sheets") or [])}
        if target_sheet_id and target_sheet_id in sheets_meta:
            target_sheets = [(target_sheet_id, sheets_meta[target_sheet_id])]
        elif target_sheet_id:
            print(f"\n{rule}")
            print(f"file_id: {fid}, sheet: {target_sheet_id}")
            print(" ERROR: sheet not found")
            continue
        else:
            # No sheet requested → query every sheet of the file.
            target_sheets = list(sheets_meta.items())

        for sid, sheet in target_sheets:
            stype = sheet.get("type", "unknown")
            sname = sheet.get("name", "?")
            print(f"\n{rule}")
            print(f"file_id: {fid} / {sid}: {sname} [{stype}]")
            print(f"source: {source_name}")
            if stype == "db":
                _emit_db_sheet(fid, sid, sheet, file_dir, keywords,
                               per_entry_budget, f_ref_map)
            elif stype == "markdown":
                _emit_markdown_sheet(fid, sid, body, keywords,
                                     per_entry_budget, f_ref_map)
            else:
                print(f" ERROR: unknown type '{stype}'")

    print(f"\n{rule}")
    print(f"Done. Queried {len(entries)} entries.")
# Run the query only when executed as a script, not when imported.
if __name__ == "__main__":
    main()