"""Execute custom SQL on a knowledge.db with auto-pagination.

Usage: python3 query_db.py <db_path> <sql> [--offset N]

Auto-fixes common LLM SQL issues before execution:
  - Fullwidth punctuation → ASCII
  - Unquoted identifiers → double-quoted (matched against actual DB columns/tables)

Output: TSV format with header + status line at the end.

Status line (one of three):
  [RESULT: N/N rows returned | COMPLETE]
    → All data returned in this call. Proceed to answer.

  [RESULT: K/total returned | this batch: rows X-Y (offset A-B) | PARTIAL — call again with --offset=M]
    → Output size limit reached. Re-invoke with SAME <db_path> and <SQL>, adding `--offset M`.
      Keep calling until you see COMPLETE.

  [RESULT: 0 rows | EMPTY]
    → Query matched no rows.

Key rules:
  - --offset is a COMMAND-LINE argument, NOT a SQL clause.
  - Keep the SQL string identical across pagination calls. Do NOT add LIMIT/OFFSET to the SQL
    for output control (pagination is automatic). You MAY use SQL LIMIT when the question
    genuinely requires it (e.g. "top 10 by X").

Example:
  python3 query_db.py ./dataset/abc/xyz/knowledge.db "SELECT * FROM sheet_001 WHERE col LIKE '%x%'"
  python3 query_db.py ./dataset/abc/xyz/knowledge.db "SELECT * FROM sheet_001 WHERE col LIKE '%x%'" --offset=11
"""
|
||
import json
import os
import re
import sqlite3
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from _session import get_session_dir
|
||
|
||
# Character budget for one invocation's TSV output. main() stops emitting rows
# once the running total would exceed it (after at least one row), and also
# derives the per-cell truncation length from it.
MAX_OUTPUT_CHARS = 3000
|
||
|
||
|
||
# One-pass translation table: fullwidth / typographic punctuation -> ASCII.
_FULLWIDTH_PUNCT_TABLE = str.maketrans({
    "\uff0c": ",",   # fullwidth comma
    "\u3000": " ",   # ideographic (fullwidth) space
    "\uff08": "(",   # fullwidth left parenthesis
    "\uff09": ")",   # fullwidth right parenthesis
    "\u2018": "'",   # left single curly quote
    "\u2019": "'",   # right single curly quote
    "\u201c": '"',   # left double curly quote
    "\u201d": '"',   # right double curly quote
})


def _fix_fullwidth(sql):
    """Return *sql* with fullwidth/curly punctuation mapped to ASCII.

    Covers the fullwidth comma, ideographic space, fullwidth parentheses and
    the four curly quote characters. The mapping is applied to the whole
    string, including any quoted sections.
    """
    return sql.translate(_FULLWIDTH_PUNCT_TABLE)
|
||
|
||
|
||
# Fullwidth ASCII variants U+FF01..U+FF5E sit exactly 0xFEE0 above their
# halfwidth counterparts U+0021..U+007E (digits, letters AND punctuation).
_FULLWIDTH_TO_HALFWIDTH = {cp: cp - 0xFEE0 for cp in range(0xFF01, 0xFF5F)}


def _normalize_width(s):
    """Fold fullwidth ASCII variants in *s* to halfwidth, for comparison only.

    Digits and letters are folded as before. Fullwidth punctuation (for
    example the fullwidth parentheses) is now folded as well: _fix_fullwidth
    rewrites those characters to ASCII inside the SQL text, so a column stored
    with fullwidth brackets could otherwise never compare equal to the
    identifier the SQL ends up containing.

    The ideographic space (U+3000) is outside this range and is left for the
    caller to handle. All other characters are returned unchanged.
    """
    return s.translate(_FULLWIDTH_TO_HALFWIDTH)
|
||
|
||
|
||
def _fix_quoted_identifiers(sql, conn):
    """Repair double-quoted names that differ from a DB name only by character width.

    SQLite falls back to treating an unknown "identifier" as a string literal,
    which yields silently wrong results instead of an error. Every
    double-quoted span in *sql* that is not an exact table/column name is
    compared, after width normalization and trimming, against the schema read
    from *conn*; on a match it is replaced by the exact stored name.

    Returns *sql* unchanged if the schema cannot be read.
    """
    def canonical(text):
        # Comparison key: halfwidth form, ideographic spaces as plain spaces, trimmed.
        return _normalize_width(text).replace('\u3000', ' ').strip()

    known = set()
    try:
        table_rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        ).fetchall()
        for table_name in [rec[0] for rec in table_rows]:
            known.add(table_name)
            column_rows = conn.execute(f'PRAGMA table_info("{table_name}")').fetchall()
            known.update(col[1] for col in column_rows)
    except Exception:
        return sql

    # canonical form -> exact stored name
    by_canonical = {canonical(name): name for name in known}

    def swap(found):
        inner = found.group(1)
        if inner in known:
            # Already the exact name; emit it untouched.
            return f'"{inner}"'
        key = canonical(inner)
        if key in by_canonical:
            return f'"{by_canonical[key]}"'
        return found.group(0)

    return re.sub(r'"([^"]+)"', swap, sql)
|
||
|
||
|
||
# Names that must never be auto-quoted even if a table/column shares them.
_SQL_KEYWORDS = frozenset({
    'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'ORDER',
    'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'AS', 'ON', 'JOIN', 'LEFT',
    'RIGHT', 'INNER', 'OUTER', 'UNION', 'ALL', 'DISTINCT', 'COUNT', 'SUM',
    'AVG', 'MIN', 'MAX', 'DESC', 'ASC', 'NULL', 'IS', 'BETWEEN', 'EXISTS',
    'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'CREATE', 'TABLE', 'INSERT',
    'INTO', 'VALUES', 'UPDATE', 'SET', 'DELETE', 'DROP', 'ALTER', 'INDEX',
    'TEMP', 'TEMPORARY', 'IF', 'REPLACE', 'SUBSTR', 'COLLATE', 'NOCASE',
    'CAST', 'INTEGER', 'TEXT', 'REAL',
})

# A bracket of either width in the DB name matches a bracket of either width
# in the SQL. Written with escapes so the fullwidth forms survive any tooling
# that normalizes fullwidth characters in source text.
_BRACKET_CLASSES = {
    "(": "[(\uff08]",
    ")": "[)\uff09]",
    "\uff08": "[(\uff08]",
    "\uff09": "[)\uff09]",
}

# Splits SQL into alternating [code, quoted, code, ...] pieces.
_QUOTED_SPAN = re.compile(
    r"('(?:[^']|'')*'"      # single-quoted string literal ('' escapes a quote)
    r'|"(?:[^"]|"")*")'     # double-quoted identifier ("" escapes a quote)
)


def _flexible_identifier_regex(ident):
    """Compile a regex matching *ident* as a standalone, unquoted token."""
    parts = []
    last = len(ident) - 1
    for i, ch in enumerate(ident):
        parts.append(_BRACKET_CLASSES.get(ch) or re.escape(ch))
        # Allow optional whitespace wherever the name switches between
        # CJK-range and non-CJK characters (same 0x2E80 threshold as before).
        if i < last and (ord(ch) > 0x2E80) != (ord(ident[i + 1]) > 0x2E80):
            parts.append(r"\s*")
    # The lookarounds reject matches glued to a quote or to another word
    # character, so a column "NO" no longer rewrites the keyword NOT.
    return re.compile(r'(?<![\w"])' + "".join(parts) + r'(?![\w"])')


def _fix_identifiers(sql, conn):
    """Auto-quote unquoted table/column identifiers using the actual DB schema.

    Handles these LLM issues:
    1. Unquoted identifiers: `SELECT 営業担当 FROM sheet_001`
       → `SELECT "営業担当" FROM "sheet_001"`
    2. Spaces inserted at a CJK/ASCII boundary: `貴社ご注文 NO` → `"貴社ご注文NO"`
       (the original note attributes this to the qwen tokenizer)
    3. Halfwidth/fullwidth bracket interchange between the SQL and the DB name.

    Only text outside quoted spans is rewritten, so string literals such as
    `LIKE '%name%'` keep their value, and only whole tokens are matched.
    Names shorter than two characters, names equal to an SQL keyword, and
    names that already appear quoted somewhere in *sql* are skipped.

    Args:
        sql: SQL text to repair.
        conn: open sqlite3 connection used to read table and column names.

    Returns:
        The repaired SQL, or *sql* unchanged if the schema cannot be read.
    """
    identifiers = set()
    try:
        table_rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        ).fetchall()
        for table in [r[0] for r in table_rows]:
            identifiers.add(table)
            escaped = table.replace('"', '""')
            for col_info in conn.execute(f'PRAGMA table_info("{escaped}")').fetchall():
                identifiers.add(col_info[1])  # column name
    except sqlite3.Error:
        return sql

    # Longest first so a long name is quoted before any shorter name it
    # contains; ties are broken alphabetically to keep runs reproducible.
    for ident in sorted(identifiers, key=lambda name: (-len(name), name)):
        if len(ident) < 2 or ident.upper() in _SQL_KEYWORDS:
            continue
        quoted = '"' + ident.replace('"', '""') + '"'
        if quoted in sql:
            continue

        regex = _flexible_identifier_regex(ident)
        pieces = _QUOTED_SPAN.split(sql)
        # Even indexes are code; odd indexes are quoted spans left untouched.
        for i in range(0, len(pieces), 2):
            # Callable replacement: the name is inserted verbatim, never
            # parsed as a regex replacement template.
            pieces[i] = regex.sub(lambda _m, q=quoted: q, pieces[i])
        sql = "".join(pieces)

    return sql
|
||
|
||
|
||
def _load_file_ref(db_path):
    """Look up the file_refs.txt entry that belongs to *db_path*.

    The file id is the path component directly before the first
    "knowledge.db" component. The session's file_refs.txt is then scanned for
    a line of the form `F<n>=<id>(<filename>)` with that id.

    Returns:
        (f_code, file_id, filename). ("", "", "") when no file id can be
        derived from the path; ("", file_id, "") when file_refs.txt is
        missing or has no entry for the id.
    """
    segments = db_path.replace("\\", "/").split("/")
    file_id = next(
        (
            segments[pos - 1]
            for pos, seg in enumerate(segments)
            if pos >= 1 and seg == "knowledge.db"
        ),
        "",
    )
    if not file_id:
        return ("", "", "")

    refs_path = os.path.join(get_session_dir(), "file_refs.txt")
    if not os.path.isfile(refs_path):
        return ("", file_id, "")

    entry_re = re.compile(r"^(F\d+)=([0-9a-f-]+)\((.+?)\)\s*$")
    with open(refs_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            hit = entry_re.match(raw.strip())
            if hit is not None and hit.group(2) == file_id:
                return (hit.group(1), file_id, hit.group(3))
    return ("", file_id, "")
|
||
|
||
|
||
def _sheet_number_for_sql(db_path, sql):
    """Return the sheet number of the first table named after FROM, or None.

    The number is parsed from the `S<digits>` part of that table's first
    `__src` value. Any database error yields None, which the caller treats as
    "all sheets of this file".
    """
    m = re.search(r'FROM\s+"?(\w+)"?', sql, re.IGNORECASE)
    if not m:
        return None
    table_name = m.group(1)
    try:
        conn = sqlite3.connect(db_path)
        try:
            row = conn.execute(f'SELECT __src FROM "{table_name}" LIMIT 1').fetchone()
        finally:
            # Close even when the probe fails (e.g. the table has no __src column).
            conn.close()
    except sqlite3.Error:
        return None
    if row and row[0]:
        sm = re.search(r'S(\d+)', str(row[0]))
        if sm:
            return int(sm.group(1))
    return None


def _purge_query_citations(db_path, sql):
    """Drop stale `source == "query"` entries for this (file, sheet) from citations.jsonl.

    Per the original note, these entries come from query.py and must not
    survive into merge_citations.py output once query_db.py has run. Entries
    are matched on the file id derived from *db_path* and, when it can be
    determined from *sql*, the sheet number; entries for other sheets and
    entries with any other source are preserved.

    Blank lines and lines that are not valid JSON are dropped, as before.
    JSON values that are not objects are kept rather than crashing the purge.
    The file is rewritten only when its content would actually change.
    """
    _, file_id, _ = _load_file_ref(db_path)
    if not file_id:
        return
    citations_path = os.path.join(get_session_dir(), "citations.jsonl")
    if not os.path.isfile(citations_path):
        return

    # None means the sheet could not be determined: purge across all sheets.
    sheet_num = _sheet_number_for_sql(db_path, sql)

    with open(citations_path, "r", encoding="utf-8") as f:
        original_lines = f.readlines()

    keep = []
    for raw in original_lines:
        line = raw.strip()
        if not line:
            continue
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            continue
        if (isinstance(entry, dict)
                and entry.get("file") == file_id
                and entry.get("source") == "query"
                and (sheet_num is None or entry.get("sheet") == sheet_num)):
            continue  # discard this (file, sheet) query.py entry
        keep.append(line + "\n")

    if keep != original_lines:
        with open(citations_path, "w", encoding="utf-8") as f:
            f.writelines(keep)
|
||
|
||
|
||
def _parse_offset(args):
    """Parse `--offset N` or `--offset=N` from the remaining CLI arguments.

    Unrelated arguments are ignored. If the option appears more than once the
    last occurrence wins, and a negative value is clamped to 0.

    A trailing `--offset` with no value is reported as an error instead of
    being silently ignored, which would restart pagination at row 0.

    Returns:
        (offset, ok, error_msg). On failure: offset is 0, ok is False and
        error_msg describes the problem.
    """
    offset = 0
    i = 0
    while i < len(args):
        arg = args[i]
        if arg == "--offset":
            if i + 1 >= len(args):
                return 0, False, "missing value for --offset"
            raw, shown, step = args[i + 1], args[i + 1], 2
        elif arg.startswith("--offset="):
            raw, shown, step = arg.split("=", 1)[1], arg, 1
        else:
            i += 1
            continue
        try:
            offset = int(raw)
        except ValueError:
            return 0, False, f"invalid --offset value: {shown}"
        i += step
    return max(offset, 0), True, ""
|
||
|
||
|
||
# Extracts the F/S/R triple from a __src cell. Per the original note the cell
# may be wrapped (e.g. `__src="F0S1R5"`), hence an unanchored search.
_SRC_PAT = re.compile(r"F(\d+)S(\d+)R(\d+)")


def _status_line(offset, shown_rows, total):
    """Build the trailing [RESULT: ...] line for a batch that emitted rows."""
    next_offset = offset + shown_rows
    batch = (
        f"this batch: rows {offset + 1}-{offset + shown_rows} "
        f"(offset {offset}-{offset + shown_rows - 1})"
    )
    if next_offset < total:
        # More data remaining
        return (
            f"[RESULT: {next_offset}/{total} returned | {batch} "
            f"| PARTIAL — call again with --offset={next_offset}]"
        )
    if offset == 0:
        # Single call returned everything
        return f"[RESULT: {total}/{total} rows returned | COMPLETE]"
    # Final batch of a paginated sequence
    return f"[RESULT: {total}/{total} returned | {batch} | COMPLETE]"


def _append_citations(src_groups, file_id, filename):
    """Append one citations.jsonl record per (file_code, sheet) group."""
    citations_path = os.path.join(get_session_dir(), "citations.jsonl")
    with open(citations_path, "a", encoding="utf-8") as cf:
        for file_code, sheet_num in sorted(src_groups):
            cf.write(json.dumps({
                "file": file_id if file_id else file_code,
                "filename": filename,
                "sheet": sheet_num, "rows": sorted(src_groups[(file_code, sheet_num)]),
                "source": "query_db",
            }, ensure_ascii=False) + "\n")


def _emit_results(rows, sql, offset, db_path):
    """Print one batch of *rows* as TSV plus status line and record its citations."""
    total = len(rows)

    if total == 0:
        print("[RESULT: 0 rows | EMPTY]")
        print(f"SQL: {sql}")
        return

    if offset >= total:
        print(f"[RESULT: 0 rows | offset {offset} exceeds total {total} | call again with --offset=0]")
        return

    columns = list(rows[0].keys())
    has_src = "__src" in columns
    output_columns = [c for c in columns if c != "__src"]  # __src is never printed
    n_cols = max(len(output_columns), 1)

    # (F{n}, file id, filename) for this db, used when recording citations.
    f_code, file_id, filename = _load_file_ref(db_path)

    # Per-cell budget = MAX_OUTPUT_CHARS / n_cols / 2, clamped to [100, 500],
    # so one wide row cannot consume the whole budget when there are many columns.
    cell_max = max(100, min(500, MAX_OUTPUT_CHARS // n_cols // 2))

    lines = ["\t".join(output_columns), "-" * 40]
    total_chars = sum(len(text) + 1 for text in lines)
    shown_rows = 0   # rows output in THIS batch
    src_groups = {}  # (file_code, sheet_num) -> set of row numbers, emitted rows only

    for idx in range(offset, total):
        row = rows[idx]
        cells = []
        for c in output_columns:
            value = row[c]
            text = "" if value is None else str(value)
            if len(text) > cell_max:
                text = text[:cell_max] + "..."
            cells.append(text)
        line = "\t".join(cells)
        line_chars = len(line) + 1

        # Always emit at least one row so pagination makes progress even if
        # that single row exceeds the budget.
        if shown_rows > 0 and total_chars + line_chars > MAX_OUTPUT_CHARS:
            break

        lines.append(line)
        total_chars += line_chars
        shown_rows += 1

        if not has_src or row["__src"] is None:
            continue
        src_val = str(row["__src"])
        if f_code and "F0S" in src_val:
            # Swap the placeholder file code F0 for this file's session code.
            src_val = src_val.replace("F0S", f"{f_code}S", 1)
        m = _SRC_PAT.search(src_val)
        if m:
            key = (f"F{m.group(1)}", int(m.group(2)))
            src_groups.setdefault(key, set()).add(int(m.group(3)))

    # Citations go to the session file only; nothing about them is printed.
    if src_groups:
        _append_citations(src_groups, file_id, filename)

    lines.append("")
    lines.append(_status_line(offset, shown_rows, total))
    print("\n".join(lines))


def main():
    """CLI entry point: run the SQL from argv against the db and print one batch.

    Reads `<db_path> <sql> [--offset N]` from sys.argv, purges stale query
    citations, applies the SQL auto-fixes, executes the query and prints the
    TSV batch with its status line. Errors are reported on stdout and the
    function returns None in every case.
    """
    if len(sys.argv) < 3:
        print("Usage: python3 query_db.py <db_path> <sql> [--offset N]")
        return

    db_path = sys.argv[1]
    sql = sys.argv[2]

    offset, ok, err = _parse_offset(sys.argv[3:])
    if not ok:
        print(f"[RESULT: 0 rows | {err}]")
        return

    # Purge stale query citations for this (file, sheet) first: it must happen
    # even if this query returns 0 rows or errors out.
    _purge_query_citations(db_path, sql)

    # Step 1: Fix fullwidth punctuation
    sql = _fix_fullwidth(sql)

    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row

        # Step 2: Auto-quote identifiers using actual DB schema
        sql = _fix_identifiers(sql, conn)

        # Step 3: Fix quoted identifiers with fullwidth/halfwidth mismatch
        sql = _fix_quoted_identifiers(sql, conn)

        try:
            rows = conn.execute(sql).fetchall()
        except Exception as e:
            print(f"SQL ERROR: {e}")
            print(f"SQL (after auto-fix): {sql}")
            return

        _emit_results(rows, sql, offset, db_path)
    finally:
        # Close on every path, including unexpected exceptions.
        conn.close()
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|