518 lines
19 KiB
Python
518 lines
19 KiB
Python
"""Budget-aware auto query for knowledge files.
|
|
|
|
Usage: python3 query.py <file_id1:sheet_id1>,... <question> <kw1> <kw2> ...
|
|
|
|
Keywords are separate positional arguments (not comma-separated).
|
|
|
|
For db-type sheets: keyword SQL with budget control (COUNT → sample → select columns → LIMIT).
|
|
For markdown-type sheets: keyword section matching within budget.
|
|
|
|
Output: TSV (db sheets) or the matched markdown section, printed to stdout.
The `__src` column is consumed internally and stripped from the visible output.
Citations are not printed: each query appends JSON lines
({"file", "filename", "sheet", "rows", "source": "query"}) to `citations.jsonl`
in the session directory — row numbers for db sheets, `rows: []` (sheet-level)
for markdown sheets.
|
|
|
|
datasets directory: <project root>/datasets/ (gbase-agent-service) or <project root>/dataset/ (catalog-agent), auto-detected at runtime; the project root is three directories above this script.
dataset_ids are discovered automatically from the subdirectories of the datasets directory.
|
|
"""
|
|
import os
|
|
import re
|
|
import sqlite3
|
|
import sys
|
|
|
|
import yaml
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
from _session import get_session_dir
|
|
|
|
# Derive project root from script location: scripts/ → kfs-answer/ → skills/ → project root
|
|
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
_ds = os.path.join(_PROJECT_ROOT, "datasets")
|
|
DATASETS_DIR = _ds if os.path.isdir(_ds) else os.path.join(_PROJECT_ROOT, "dataset")
|
|
|
|
|
|
def _discover_datasets():
    """Return the sorted names of the subdirectories of DATASETS_DIR.

    Each subdirectory name is a dataset_id. Plain files are skipped, and an
    empty list is returned when DATASETS_DIR does not exist.
    """
    if os.path.isdir(DATASETS_DIR):
        names = sorted(os.listdir(DATASETS_DIR))
        return [name for name in names
                if os.path.isdir(os.path.join(DATASETS_DIR, name))]
    return []
|
|
|
|
|
|
def load_file_ref_map():
    """Read ``file_refs.txt`` from the session directory.

    Every line shaped like ``F<n>=<file_id>(<filename>)`` becomes an entry
    ``file_id -> (f_code, filename)``. Non-matching lines are ignored, and a
    missing file yields an empty dict.
    """
    path = os.path.join(get_session_dir(), "file_refs.txt")
    refs = {}
    if os.path.isfile(path):
        pattern = re.compile(r"^(F\d+)=([0-9a-f-]+)\((.+?)\)\s*$")
        with open(path, "r", encoding="utf-8") as fh:
            for raw in fh:
                match = pattern.match(raw.strip())
                if match is not None:
                    f_code, file_id, filename = match.groups()
                    refs[file_id] = (f_code, filename)
    return refs
|
|
|
|
|
|
# Citation source markers.
# Row level: cells hold values such as `__src="F0S1R5"` (xls-agent-parse wraps
# the marker), so this pattern is unanchored and meant for re.search().
SRC_ROW_PAT = re.compile(r"F(\d+)S(\d+)R(\d+)")
# Sheet level: extract_sheet_src() returns the bare marker (e.g. `F0S1`), so
# this pattern is anchored at both ends and meant for re.match().
SRC_SHEET_PAT = re.compile(r"^F(\d+)S(\d+)$")
|
|
|
|
|
|
def _format_citation(file_id, sheet_num, filename, row_nums=None):
|
|
"""Build one CITATION tag. file_id=UUID. row_nums=None → sheet-level (no rows attr)."""
|
|
fn_attr = f' filename="{filename}"' if filename else ""
|
|
if row_nums is None:
|
|
return f'<CITATION file="{file_id}"{fn_attr} sheet="{sheet_num}" />'
|
|
rows_str = "[" + ", ".join(str(r) for r in row_nums) + "]"
|
|
return f'<CITATION file="{file_code}"{fn_attr} sheet="{sheet_num}" rows="{rows_str}" />'
|
|
|
|
|
|
def replace_f0(text, f_code):
    """Rewrite the placeholder file code ``F0`` in ``__src`` markers.

    Every ``F0S`` in *text* becomes ``<f_code>S`` (e.g. ``F0S1R5`` → ``F3S1R5``
    for ``f_code="F3"``). *text* is returned unchanged when *f_code* is empty
    or already ``"F0"``.
    """
    if f_code and f_code != "F0":
        return text.replace('F0S', f'{f_code}S')
    return text
|
|
|
|
|
|
def extract_sheet_src(body, sheet_id):
    """Return the ``__src`` value from a ``<!-- sheet_xxx __src="F0S1" -->`` marker.

    The marker must name *sheet_id* exactly, followed by whitespace. Returns
    an empty string when no such marker exists in *body*.
    """
    pattern = rf'<!--\s*{re.escape(sheet_id)}\s+__src="([^"]*)"'
    found = re.search(pattern, body)
    if found is None:
        return ""
    return found.group(1)
|
|
|
|
|
|
def calc_budget(entry_count):
    """Return the per-entry character budget for *entry_count* entries.

    One entry (or fewer) gets 3000 chars, exactly two get 1800 each, and
    anything more gets 1200 each.
    """
    if entry_count <= 1:
        return 3000
    return 1800 if entry_count == 2 else 1200
|
|
|
|
|
|
def find_file_dir(dataset_ids, file_id):
    """Locate the directory of *file_id* under DATASETS_DIR.

    Datasets are tried in the order given; the first
    ``DATASETS_DIR/<dataset_id>/<file_id>`` that is a directory is returned,
    or ``None`` when none exists.
    """
    candidates = (os.path.join(DATASETS_DIR, ds_id, file_id) for ds_id in dataset_ids)
    return next((path for path in candidates if os.path.isdir(path)), None)
|
|
|
|
|
|
def load_knowledge_meta(file_dir):
    """Parse ``knowledge.md`` in *file_dir* into its YAML front matter and body.

    The file must start with ``---``. The text up to the next ``---`` is
    parsed with ``yaml.safe_load``, and everything after it, stripped, is the
    body.

    Returns:
        ``(meta, body)`` where *meta* is a dict. Returns ``(None, None)`` when:
        the file is missing; it has no front matter; the YAML is malformed;
        or the YAML is not a mapping. Callers use ``meta.get(...)``, so a
        non-dict would crash them.
    """
    km_path = os.path.join(file_dir, "knowledge.md")
    if not os.path.isfile(km_path):
        return None, None
    # utf-8-sig drops a leading BOM, which would otherwise hide the opening
    # "---"; files without a BOM decode exactly as with plain utf-8.
    with open(km_path, "r", encoding="utf-8-sig") as f:
        content = f.read()
    if not content.startswith("---"):
        return None, None
    parts = content.split("---", 2)
    if len(parts) < 3:
        return None, None
    try:
        meta = yaml.safe_load(parts[1])
    except yaml.YAMLError:
        # Malformed front matter: report as invalid instead of aborting the
        # whole multi-entry run with a traceback.
        return None, None
    if not isinstance(meta, dict):
        return None, None
    body = parts[2].strip()
    return meta, body
|
|
|
|
|
|
# Failures that mean "this query didn't work" and are reported in the result
# dict instead of crashing the run.
_SQLITE_QUERY_ERRORS = (sqlite3.Error, sqlite3.Warning, ValueError)


def query_db_sheet(db_path, table_name, columns, keywords, budget):
    """Budget-aware SQLite query: COUNT → sample → select columns → LIMIT.

    Row filtering works as follows:
    - Rows are filtered by the single most selective keyword: the one with
      the smallest non-zero match count, where a match is a substring
      appearing in any column.
    - Keywords are matched literally; ``%``/``_`` are not wildcards.
      Matching is case-insensitive for ASCII only (SQLite LIKE).
    - If no single keyword matches, an OR over all keywords is tried.
    - With no keywords, every row qualifies.

    Columns and row count are then trimmed so the estimated TSV size fits
    *budget* characters. The ``__src`` column is always kept.

    Args:
        db_path: Path to the SQLite database file.
        table_name: Table to query; quoted safely as an identifier.
        columns: Unused. Kept for interface compatibility; the column list is
            read from the database itself.
        keywords: Search keywords; empty strings are ignored.
        budget: Approximate character budget for the printed output.

    Returns:
        One of the following dicts:
        - ``{"error": str}`` on failure.
        - ``{"rows": [], "total": 0}`` when nothing matched, with an extra
          ``"note"`` key when no keyword matched.
        - On success, a dict with keys ``table``, ``columns``,
          ``all_columns``, ``rows``, ``count``, ``total``,
          ``fields_reduced``, ``rows_limited``, ``keyword`` and ``db_path``.
    """
    if not os.path.isfile(db_path):
        return {"error": f"DB not found: {db_path}"}
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row
        return _query_db_table(conn, db_path, table_name, keywords, budget)
    finally:
        # One close on every path, including unexpected exceptions.
        conn.close()


def _quote_ident(name):
    """Return *name* as a double-quoted SQLite identifier (embedded quotes doubled)."""
    return '"' + str(name).replace('"', '""') + '"'


def _like_substring(keyword):
    """Return a LIKE pattern matching *keyword* literally as a substring.

    Backslash, percent and underscore are escaped with a backslash. The SQL
    condition using the pattern must declare a backslash ESCAPE character.
    """
    escaped = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    return f"%{escaped}%"


def _any_keyword_where(columns, keywords):
    """Build ``(where_sql, params)`` matching any of *keywords* in any of *columns*."""
    conditions = []
    params = []
    for kw in keywords:
        pattern = _like_substring(kw)
        for col in columns:
            conditions.append(f"{_quote_ident(col)} LIKE ? ESCAPE '\\'")
            params.append(pattern)
    return " OR ".join(conditions), params


def _plan_select(db_cols, sample_rows, keywords, total_rows, budget):
    """Choose ``(select_cols, limit)`` so the estimated TSV fits *budget* chars.

    Column widths are averaged over *sample_rows*. If every matching row
    with every column fits, nothing is trimmed. Otherwise the plan is built
    in order:
    - ``__src`` is kept.
    - Columns whose sampled values contain a keyword are kept.
    - Remaining columns are added while 5 rows still fit.
    - LIMIT is what the leftover budget allows, at least 1 and at most
      *total_rows*.
    """
    n_sample = max(len(sample_rows), 1)
    col_avg_chars = {}
    for col in db_cols:
        col_avg_chars[col] = sum(len(str(row.get(col, "") or "")) for row in sample_rows) / n_sample

    header_overhead = len("\t".join(db_cols)) + 50
    avg_row_chars = sum(col_avg_chars.values()) + len(db_cols)
    if header_overhead + total_rows * avg_row_chars <= budget:
        return db_cols, total_rows

    def row_width(cols):
        return sum(col_avg_chars.get(c, 10) for c in cols) + len(cols)

    lowered_kws = [kw.lower() for kw in keywords]
    # __src is the citation source marker, so it is always retained.
    select_cols = [c for c in db_cols if c == "__src"]
    other_cols = []
    for col in db_cols:
        if col == "__src":
            continue
        values = [str(row.get(col, "") or "").lower() for row in sample_rows]
        if any(kw in v for kw in lowered_kws for v in values):
            select_cols.append(col)
        else:
            other_cols.append(col)

    for col in other_cols:
        if header_overhead + 5 * row_width(select_cols + [col]) <= budget:
            select_cols.append(col)
    if not select_cols:
        select_cols = db_cols

    available = budget - header_overhead
    limit = max(1, int(available / max(row_width(select_cols), 1)))
    return select_cols, min(limit, total_rows)


def _query_db_table(conn, db_path, table_name, keywords, budget):
    """Run the query_db_sheet pipeline against an already-open connection."""
    table = _quote_ident(table_name)

    # The column list comes from the database, not from the sheet metadata.
    try:
        db_cols = [r["name"] for r in conn.execute(f"PRAGMA table_info({table})").fetchall()]
    except _SQLITE_QUERY_ERRORS:
        return {"error": f"Table not found: {table_name}"}
    if not db_cols:
        # PRAGMA table_info yields no rows for a table that does not exist.
        return {"error": f"Table has no columns: {table_name}"}

    def count(where, params):
        return conn.execute(f"SELECT COUNT(*) FROM {table} WHERE {where}", params).fetchone()[0]

    valid_kws = [kw for kw in keywords if kw]
    where_clause = where_params = used_kw = None

    if not valid_kws:
        where_clause, where_params = "1=1", []
    else:
        # Pick the MOST SELECTIVE keyword: smallest COUNT > 0.
        best = None  # (count, keyword, where_sql, params)
        for kw in valid_kws:
            wc, wp = _any_keyword_where(db_cols, [kw])
            try:
                cnt = count(wc, wp)
            except _SQLITE_QUERY_ERRORS:
                continue
            # Strict "<" keeps the earliest keyword on ties.
            if cnt > 0 and (best is None or cnt < best[0]):
                best = (cnt, kw, wc, wp)
        if best is not None:
            _, used_kw, where_clause, where_params = best
        else:
            # Fallback: OR across all keywords.
            wc, wp = _any_keyword_where(db_cols, valid_kws)
            try:
                if count(wc, wp) > 0:
                    where_clause, where_params = wc, wp
            except _SQLITE_QUERY_ERRORS:
                pass  # best effort: reported as "No keyword match" below

    if where_clause is None:
        return {"rows": [], "total": 0, "note": "No keyword match"}

    try:
        total_rows = count(where_clause, where_params)
    except _SQLITE_QUERY_ERRORS:
        return {"error": "COUNT failed"}
    if total_rows == 0:
        return {"rows": [], "total": 0}

    # Sample 5 rows to estimate column widths.
    try:
        cursor = conn.execute(f"SELECT * FROM {table} WHERE {where_clause} LIMIT 5", where_params)
        sample_rows = [dict(r) for r in cursor.fetchall()]
    except _SQLITE_QUERY_ERRORS:
        return {"error": "Sample query failed"}

    select_cols, limit = _plan_select(db_cols, sample_rows, valid_kws, total_rows, budget)

    cols_sql = ", ".join(_quote_ident(c) for c in select_cols)
    sql = f"SELECT {cols_sql} FROM {table} WHERE {where_clause} LIMIT {limit}"
    try:
        rows = [dict(r) for r in conn.execute(sql, where_params).fetchall()]
    except _SQLITE_QUERY_ERRORS:
        return {"error": "Query execution failed"}

    return {
        "table": table_name,
        "columns": select_cols,
        "all_columns": db_cols,
        "rows": rows,
        "count": len(rows),
        "total": total_rows,
        "fields_reduced": len(select_cols) < len(db_cols),
        "rows_limited": limit < total_rows,
        "keyword": used_kw,
        "db_path": db_path,
    }
|
|
|
|
|
|
def query_markdown_sheet(body, sheet_id, keywords, budget):
    """Keyword-based section matching for markdown sheets.

    The section for *sheet_id* is the text between its
    ``<!-- sheet_xxx ... -->`` marker and the next marker. A body with no
    markers at all is treated as one section.

    Args:
        body: Full knowledge.md body.
        sheet_id: Sheet marker id to extract, e.g. ``"sheet_1"``.
        keywords: Case-insensitive substrings to look for; empty strings
            are ignored.
        budget: Approximate character budget for the returned content.

    Returns:
        One of the following dicts:
        - ``{"content": "", "note": ...}`` if the section is missing or
          empty.
        - ``{"content": section, "full": True}`` if the section fits the
          budget.
        - ``{"content": prefix, "note": ...}`` if no line matches a keyword.
        - Otherwise ``{"content": ..., "matched_lines": n}``. The content is
          the matching lines, each with 3 lines of context on either side.
          Each non-contiguous group is preceded by a ``---`` line, and *n*
          counts the section lines included, not the separators.
    """
    parts = re.split(r"<!--\s*sheet_\w+(?:\s+[^>]*)?\s*-->", body)
    markers = re.findall(r"<!--\s*(sheet_\w+)(?:\s+[^>]*)?\s*-->", body)

    section = ""
    for i, marker in enumerate(markers):
        # parts[i + 1] is the text that follows markers[i].
        if marker == sheet_id and i < len(parts) - 1:
            section = parts[i + 1].strip()
            break
    if not section and len(markers) == 0 and len(parts) == 1:
        section = body.strip()

    if not section:
        return {"content": "", "note": "No content found"}

    if len(section) <= budget:
        return {"content": section, "full": True}

    # Keyword-based line matching with context.
    lines = section.split("\n")
    needles = [kw.lower() for kw in keywords if kw]
    context = 3
    matched_indices = set()
    for i, line in enumerate(lines):
        lowered = line.lower()
        if any(needle in lowered for needle in needles):
            matched_indices.update(range(max(0, i - context), min(len(lines), i + context + 1)))

    if not matched_indices:
        truncated = section[:budget]
        last_nl = truncated.rfind("\n")
        # Cut at a line boundary unless that would discard over 30% of the budget.
        if last_nl > budget * 0.7:
            truncated = truncated[:last_nl]
        return {"content": truncated, "note": f"[No keyword match. First {len(truncated)} chars of {len(section)}]"}

    result_lines = []
    content_lines = 0
    chars = 0
    prev = -2
    for idx in sorted(matched_indices):
        line = lines[idx]
        line_chars = len(line) + 1
        if chars + line_chars > budget:
            break
        if idx > prev + 1:
            # Start of a new (non-contiguous) group of lines.
            result_lines.append("---")
            chars += 4
        result_lines.append(line)
        content_lines += 1
        chars += line_chars
        prev = idx

    # Bug fix: previously len(result_lines), which also counted "---" separators.
    return {"content": "\n".join(result_lines), "matched_lines": content_lines}
|
|
|
|
|
|
def _print_entry_error(header, message):
    """Print the entry separator, *header*, and an indented ERROR line."""
    print(f"\n{'=' * 60}")
    print(header)
    print(f" ERROR: {message}")


def _append_citation_records(records):
    """Append each citation dict in *records* as one JSON line to citations.jsonl in the session dir."""
    import json

    citations_path = os.path.join(get_session_dir(), "citations.jsonl")
    with open(citations_path, "a", encoding="utf-8") as cf:
        for record in records:
            cf.write(json.dumps(record, ensure_ascii=False) + "\n")


def _report_db_sheet(file_dir, fid, sid, sheet, keywords, budget, f_code, filename):
    """Query one db-type sheet, print it as TSV, and record row-level citations."""
    db_path = os.path.join(file_dir, "knowledge.db")
    db_table = sheet.get("db_table", sid)
    print(f"db: {db_path}, table: {db_table}")

    result = query_db_sheet(db_path, db_table, sheet.get("columns", []), keywords, budget)
    if "error" in result:
        print(f" ERROR: {result['error']}")
        return
    if not result.get("rows"):
        print(f" No matching rows (total: {result.get('total', 0)})")
        return

    budget_info = []
    if result.get("fields_reduced"):
        budget_info.append(f"fields: {len(result['columns'])}/{len(result['all_columns'])}")
    if result.get("rows_limited"):
        budget_info.append(f"rows: {result['count']}/{result['total']}")
    budget_str = f" [BUDGET: {', '.join(budget_info)}]" if budget_info else ""

    # __src is consumed into citation records, so it is hidden from both the
    # COLUMNS report and the TSV output.
    display_cols = [c for c in result["columns"] if c != "__src"]
    omitted = [c for c in result["all_columns"] if c != "__src" and c not in display_cols]

    print(f" TABLE: {result['table']} ({result['count']}/{result['total']} rows){budget_str}")
    print(f" COLUMNS: {', '.join(display_cols)}")
    if omitted:
        print(f" OMITTED: {len(omitted)} columns")
    keyword = result.get("keyword")
    if keyword:
        print(f" KEYWORD: {keyword}")
        # Warn if keyword filtering returned suspiciously few rows.
        if result["count"] <= 3 and result["total"] >= 10:
            print(f" ⚠ NOTE: keyword \"{keyword}\" matched only {result['count']}/{result['total']} rows. Results may be incomplete — consider removing this keyword.")
    print("-" * 40)

    print("\t".join(display_cols))
    has_src = "__src" in result["columns"]
    src_groups = {}  # (file_code, sheet_num) -> set of row numbers
    for row in result["rows"]:
        raw_src = row.get("__src") if has_src else None
        if raw_src:
            src_val = str(raw_src)
            if f_code:
                src_val = replace_f0(src_val, f_code)
            m = SRC_ROW_PAT.search(src_val)
            if m:
                key = (f"F{m.group(1)}", int(m.group(2)))
                src_groups.setdefault(key, set()).add(int(m.group(3)))
        vals = []
        for c in display_cols:
            v = row.get(c)
            s = "" if v is None else str(v)
            vals.append(s[:200] + "..." if len(s) > 200 else s)
        print("\t".join(vals))

    if src_groups:
        _append_citation_records(
            {
                "file": fid, "filename": filename,
                "sheet": sheet_num, "rows": sorted(src_groups[(file_code, sheet_num)]),
                "source": "query",
            }
            for (file_code, sheet_num) in sorted(src_groups)
        )


def _report_markdown_sheet(body, fid, sid, keywords, budget, f_code, filename):
    """Print the budgeted section of a markdown sheet and record a sheet-level citation."""
    if not body:
        print(" ERROR: no body content")
        return
    src_tag = extract_sheet_src(body, sid)
    if src_tag and f_code:
        src_tag = replace_f0(src_tag, f_code)
    result = query_markdown_sheet(body, sid, keywords, budget)
    if result.get("note"):
        print(f" {result['note']}")
    print("-" * 40)
    print(result.get("content", ""))
    # The sheet-level citation goes to the session file only (no stdout).
    m = SRC_SHEET_PAT.match(src_tag) if src_tag else None
    if m:
        _append_citation_records([{
            "file": fid, "filename": filename,
            "sheet": int(m.group(2)), "rows": [],
            "source": "query",
        }])


def main():
    """CLI entry point: query each requested file/sheet and print budgeted results.

    Command-line arguments:
        argv[1]: Comma-separated ``file_id`` or ``file_id:sheet_id`` entries.
        argv[2]: The question. It is accepted but does not influence the
            query.
        argv[3:]: Keywords.

    Citations are appended to ``citations.jsonl`` in the session dir rather
    than printed. Exits with status 2 and a usage line on stderr when argv[1]
    is missing.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 query.py <file_id1:sheet_id1>,... <question> <kw1> <kw2> ...",
              file=sys.stderr)
        sys.exit(2)

    # Dataset ids are the subdirectories of DATASETS_DIR.
    dataset_ids = _discover_datasets()
    raw_entries = [e.strip() for e in sys.argv[1].split(",") if e.strip()]
    keywords = sys.argv[3:]  # argv[2] is the question; only keywords drive the query

    entries = []
    for entry in raw_entries:
        fid, has_sheet, sid = entry.partition(":")
        entries.append((fid.strip(), sid.strip() if has_sheet else None))

    per_entry_budget = calc_budget(len(entries))
    print(f"[Budget: {per_entry_budget} chars/entry, {len(entries)} entries]")

    # file_id -> (F{n} code, filename), produced by search.py.
    f_ref_map = load_file_ref_map()

    for fid, target_sheet_id in entries:
        file_dir = find_file_dir(dataset_ids, fid)
        if not file_dir:
            _print_entry_error(f"file_id: {fid}", "not found")
            continue

        meta, body = load_knowledge_meta(file_dir)
        # Front matter that isn't a mapping can't be read with .get(); treat it as invalid.
        if not meta or not isinstance(meta, dict):
            _print_entry_error(f"file_id: {fid}", "knowledge.md invalid")
            continue

        source_name = meta.get("source_name", "unknown")
        # Malformed sheet entries (non-mappings, missing "id") are skipped
        # instead of raising KeyError/AttributeError.
        sheets_meta = {
            s["id"]: s for s in (meta.get("sheets") or [])
            if isinstance(s, dict) and "id" in s
        }

        if target_sheet_id and target_sheet_id in sheets_meta:
            target_sheets = [(target_sheet_id, sheets_meta[target_sheet_id])]
        elif target_sheet_id:
            _print_entry_error(f"file_id: {fid}, sheet: {target_sheet_id}", "sheet not found")
            continue
        else:
            target_sheets = list(sheets_meta.items())

        f_code, filename = f_ref_map.get(fid, ("", ""))
        for sid, sheet in target_sheets:
            stype = sheet.get("type", "unknown")
            sname = sheet.get("name", "?")

            print(f"\n{'=' * 60}")
            print(f"file_id: {fid} / {sid}: {sname} [{stype}]")
            print(f"source: {source_name}")

            if stype == "db":
                _report_db_sheet(file_dir, fid, sid, sheet, keywords, per_entry_budget, f_code, filename)
            elif stype == "markdown":
                _report_markdown_sheet(body, fid, sid, keywords, per_entry_budget, f_code, filename)
            else:
                print(f" ERROR: unknown type '{stype}'")

    print(f"\n{'=' * 60}")
    print(f"Done. Queried {len(entries)} entries.")


if __name__ == "__main__":
    main()
|