# qwen_agent/skills/competitor-news-intel/scripts/search_provider.py
# Snapshot metadata: 2026-04-16 10:23:54 +08:00 — 107 lines, 3.9 KiB, Python.

import os
import re
from datetime import datetime, timedelta
from typing import Any
import requests
class SearchProviderError(Exception):
    """Base class for all search-provider failures.

    Each subclass carries a stable, machine-readable ``code`` tag so callers
    can branch on the failure kind without parsing message text.
    """

    code = "provider_error"


class MissingAPIKeyError(SearchProviderError):
    """The required API-key environment variable is not set."""

    code = "missing_api_key"


class UnsupportedProviderError(SearchProviderError):
    """A search provider other than the supported one was configured."""

    code = "unsupported_provider"


class UpstreamHTTPError(SearchProviderError):
    """The upstream search service failed at the HTTP/transport level."""

    code = "upstream_http_error"


class InvalidSearchInputError(SearchProviderError):
    """Caller-supplied search input (e.g. freshness) failed validation."""

    code = "invalid_input"
def _build_search_filter(freshness: str | None) -> dict[str, Any]:
if not freshness:
return {}
current_time = datetime.now()
end_date = (current_time + timedelta(days=1)).strftime("%Y-%m-%d")
if freshness == "pd":
start_date = (current_time - timedelta(days=1)).strftime("%Y-%m-%d")
elif freshness == "pw":
start_date = (current_time - timedelta(days=6)).strftime("%Y-%m-%d")
elif freshness == "pm":
start_date = (current_time - timedelta(days=30)).strftime("%Y-%m-%d")
elif freshness == "py":
start_date = (current_time - timedelta(days=364)).strftime("%Y-%m-%d")
elif re.fullmatch(r"\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}", freshness):
start_date, end_date = freshness.split("to")
else:
raise InvalidSearchInputError("freshness must be pd, pw, pm, py, or YYYY-MM-DDtoYYYY-MM-DD")
return {"range": {"page_time": {"gte": start_date, "lt": end_date}}}
def baidu_search(query: str, count: int = 10, freshness: str | None = None) -> list[dict[str, Any]]:
    """Run one Baidu Qianfan AI web search and return its reference hits.

    Args:
        query: Search query text.
        count: Desired number of web results; clamped into 1..50 in the
            request payload.
        freshness: Optional time-filter token (see ``_build_search_filter``).

    Returns:
        The raw ``references`` list from the API response (``[]`` if absent).

    Raises:
        MissingAPIKeyError: when ``BAIDU_API_KEY`` is not set.
        UpstreamHTTPError: on transport failure, non-2xx status, a non-JSON
            response body, or an error ``code`` in the response payload.
        InvalidSearchInputError: for an invalid ``freshness`` value.
    """
    api_key = os.getenv("BAIDU_API_KEY")
    if not api_key:
        raise MissingAPIKeyError("BAIDU_API_KEY must be set")
    try:
        response = requests.post(
            "https://qianfan.baidubce.com/v2/ai_search/web_search",
            headers={
                "Authorization": f"Bearer {api_key}",
                "X-Appbuilder-From": "openclaw",
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"content": query, "role": "user"}],
                "search_source": "baidu_search_v2",
                "resource_type_filter": [{"type": "web", "top_k": max(1, min(count, 50))}],
                "search_filter": _build_search_filter(freshness),
            },
            timeout=30,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        # HTTPError is a RequestException subclass wrapped identically, so a
        # single handler covers transport failures and non-2xx statuses
        # (the original had two duplicate except branches).
        raise UpstreamHTTPError(str(exc)) from exc
    try:
        payload = response.json()
    except ValueError as exc:
        # Keep non-JSON upstream bodies inside the module's error hierarchy
        # instead of leaking a raw decode error to callers.
        raise UpstreamHTTPError("baidu search returned a non-JSON response") from exc
    # A top-level "code" field marks an application-level error envelope.
    if "code" in payload:
        raise UpstreamHTTPError(payload.get("message", "baidu search failed"))
    return payload.get("references", [])
def collect_events(payload: dict[str, Any]) -> list[dict[str, Any]]:
    """Search news for every configured competitor and normalize the hits.

    ``payload["data"]`` is read for ``competitors``, optional ``search``
    settings (``provider``/``count``/``freshness``/``keywords_extra``),
    optional ``event_categories``, and a ``time_range`` fallback used when
    the search config carries no freshness.

    Returns:
        One flat event dict per search hit, tagged with the competitor name.

    Raises:
        UnsupportedProviderError: when the configured provider is not "baidu".
    """
    data = payload.get("data", {})
    search_cfg = data.get("search", {})

    provider = search_cfg.get("provider", "baidu")
    if provider != "baidu":
        raise UnsupportedProviderError(f"provider {provider} is not supported")

    freshness = search_cfg.get("freshness") or data.get("time_range")
    count = int(search_cfg.get("count", 10))
    extra_keywords = search_cfg.get("keywords_extra", [])
    categories = data.get("event_categories", [])

    events: list[dict[str, Any]] = []
    for name in data.get("competitors", []):
        # Query = competitor name + any extra keywords + event categories,
        # skipping falsy entries.
        terms = [name, *extra_keywords, *categories]
        query = " ".join(str(term) for term in terms if term)
        for hit in baidu_search(query=query, count=count, freshness=freshness):
            events.append(
                {
                    "competitor": name,
                    "date": hit.get("page_time") or "",
                    "title": hit.get("title") or "",
                    "summary": hit.get("abstract") or hit.get("content") or "",
                    "source_url": hit.get("url") or "",
                    "source_name": hit.get("site_name") or "baidu",
                }
            )
    return events