# Baidu web-search provider helpers for competitor event collection.
import os
|
|
import re
|
|
from datetime import datetime, timedelta
|
|
from typing import Any
|
|
|
|
import requests
|
|
|
|
|
|
class SearchProviderError(Exception):
    """Base class for search-provider failures raised by this module."""

    # Stable, machine-readable error identifier; subclasses override it.
    code = "provider_error"
|
|
|
|
|
|
class MissingAPIKeyError(SearchProviderError):
    """Raised when the required provider API key is absent from the environment."""

    code = "missing_api_key"
|
|
|
|
|
|
class UnsupportedProviderError(SearchProviderError):
    """Raised when the configured search provider is not one this module supports."""

    code = "unsupported_provider"
|
|
|
|
|
|
class UpstreamHTTPError(SearchProviderError):
    """Raised when the upstream search API fails (transport, HTTP status, or error payload)."""

    code = "upstream_http_error"
|
|
|
|
|
|
class InvalidSearchInputError(SearchProviderError):
    """Raised when caller-supplied search parameters (e.g. freshness) are malformed."""

    code = "invalid_input"
|
|
|
|
|
|
def _build_search_filter(freshness: str | None) -> dict[str, Any]:
|
|
if not freshness:
|
|
return {}
|
|
current_time = datetime.now()
|
|
end_date = (current_time + timedelta(days=1)).strftime("%Y-%m-%d")
|
|
if freshness == "pd":
|
|
start_date = (current_time - timedelta(days=1)).strftime("%Y-%m-%d")
|
|
elif freshness == "pw":
|
|
start_date = (current_time - timedelta(days=6)).strftime("%Y-%m-%d")
|
|
elif freshness == "pm":
|
|
start_date = (current_time - timedelta(days=30)).strftime("%Y-%m-%d")
|
|
elif freshness == "py":
|
|
start_date = (current_time - timedelta(days=364)).strftime("%Y-%m-%d")
|
|
elif re.fullmatch(r"\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}", freshness):
|
|
start_date, end_date = freshness.split("to")
|
|
else:
|
|
raise InvalidSearchInputError("freshness must be pd, pw, pm, py, or YYYY-MM-DDtoYYYY-MM-DD")
|
|
return {"range": {"page_time": {"gte": start_date, "lt": end_date}}}
|
|
|
|
|
|
def baidu_search(query: str, count: int = 10, freshness: str | None = None) -> list[dict[str, Any]]:
    """Run a web search through Baidu Qianfan AI Search and return raw hits.

    Args:
        query: Free-text search query.
        count: Desired number of results; clamped to the API's 1..50 range.
        freshness: Optional freshness code or explicit date range, as
            accepted by ``_build_search_filter``.

    Returns:
        The upstream ``references`` list (possibly empty); each item is the
        provider's raw result dict.

    Raises:
        MissingAPIKeyError: if the ``BAIDU_API_KEY`` env var is unset or empty.
        UpstreamHTTPError: on transport failures, HTTP error statuses, or an
            error payload (one carrying a ``code`` key) from the provider.
        InvalidSearchInputError: propagated from ``_build_search_filter``.
    """
    api_key = os.getenv("BAIDU_API_KEY")
    if not api_key:
        raise MissingAPIKeyError("BAIDU_API_KEY must be set")
    try:
        response = requests.post(
            "https://qianfan.baidubce.com/v2/ai_search/web_search",
            headers={
                "Authorization": f"Bearer {api_key}",
                "X-Appbuilder-From": "openclaw",
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"content": query, "role": "user"}],
                "search_source": "baidu_search_v2",
                "resource_type_filter": [{"type": "web", "top_k": max(1, min(count, 50))}],
                "search_filter": _build_search_filter(freshness),
            },
            timeout=30,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        # requests.HTTPError is a RequestException subclass, so this single
        # handler covers both transport failures and bad HTTP statuses
        # (the original had two byte-identical handlers).
        raise UpstreamHTTPError(str(exc)) from exc
    payload = response.json()
    # Successful responses carry no "code" key; its presence signals an
    # API-level error even with a 200 status.
    if "code" in payload:
        raise UpstreamHTTPError(payload.get("message", "baidu search failed"))
    return payload.get("references", [])
|
|
|
|
|
|
def collect_events(payload: dict[str, Any]) -> list[dict[str, Any]]:
    """Search Baidu for each configured competitor and normalize hits into events.

    Args:
        payload: Request payload whose ``data`` section supplies
            ``competitors``, optional ``search`` settings (provider,
            freshness, count, extra keywords), ``time_range``, and
            ``event_categories``.

    Returns:
        A flat list of event dicts, one per search hit, each tagged with the
        competitor it was found for.

    Raises:
        UnsupportedProviderError: if the configured provider is not "baidu".
    """
    data = payload.get("data", {})
    search_cfg = data.get("search", {})

    provider = search_cfg.get("provider", "baidu")
    if provider != "baidu":
        raise UnsupportedProviderError(f"provider {provider} is not supported")

    # Search settings, with legacy "time_range" as a freshness fallback.
    freshness = search_cfg.get("freshness") or data.get("time_range")
    result_count = int(search_cfg.get("count", 10))
    extra_keywords = search_cfg.get("keywords_extra", [])
    event_categories = data.get("event_categories", [])

    collected: list[dict[str, Any]] = []
    for competitor in data.get("competitors", []):
        # Query = competitor name plus any extra keywords/categories, skipping blanks.
        term_pool = [competitor, *extra_keywords, *event_categories]
        query = " ".join(str(term) for term in term_pool if term)
        hits = baidu_search(query=query, count=result_count, freshness=freshness)
        collected.extend(
            {
                "competitor": competitor,
                "date": hit.get("page_time") or "",
                "title": hit.get("title") or "",
                "summary": hit.get("abstract") or hit.get("content") or "",
                "source_url": hit.get("url") or "",
                "source_name": hit.get("site_name") or "baidu",
            }
            for hit in hits
        )
    return collected
|