Files
fidelity-ai-workspace/scripts/mattermost/mattermost_context.py

350 lines
11 KiB
Python

#!/usr/bin/env python3
import json
import logging
import os
import ssl
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List
from urllib import error, parse, request
# Directory containing this script; the default .env and output paths are built from it.
SCRIPT_DIR = Path(__file__).resolve().parent
# Fallbacks used by configure() when the matching environment variables are unset.
DEFAULT_WINDOW_HOURS = 24
DEFAULT_MAX_MESSAGES = 200
# Upper bound on the per_page value sent to the channel posts endpoint.
MAX_PER_PAGE = 200
DEFAULT_OUTPUT_FILE = str(SCRIPT_DIR / "generated" / "mattermost_context.jsonl")
# Passed as the `timeout` argument of urlopen() for every API request.
REQUEST_TIMEOUT = 15
LOGGER = logging.getLogger("mattermost_context")
# Per-process cache of user records keyed by user id; filled by get_user_info().
USER_CACHE: Dict[str, Dict[str, Any]] = {}
# The settings below are placeholders; configure() overwrites all of them at startup.
MATTERMOST_URL = ""
CHANNEL_IDS: List[str] = []
WINDOW_HOURS = DEFAULT_WINDOW_HOURS
MAX_MESSAGES = DEFAULT_MAX_MESSAGES
# Posts whose create_at (epoch milliseconds) is below this value are skipped.
CUTOFF_TIMESTAMP_MS = 0
OUTPUT_FILE = DEFAULT_OUTPUT_FILE
REQUEST_HEADERS: Dict[str, str] = {}
SSL_CONTEXT: ssl.SSLContext | None = None
class MattermostAPIError(RuntimeError):
    """Raised when a Mattermost API request fails or its response cannot be used."""
def parse_bool_env(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is not set. Otherwise returns True
    only when the trimmed, lower-cased value is one of "1", "true", "yes"
    or "on"; any other value (including an empty string) yields False.
    """
    truthy_values = {"1", "true", "yes", "on"}
    configured = os.environ.get(name)
    return default if configured is None else configured.strip().lower() in truthy_values
def load_dotenv_file(path: Path | None = None) -> None:
dotenv_path = path or (SCRIPT_DIR / ".env")
if not dotenv_path.exists():
return
with dotenv_path.open("r", encoding="utf-8") as file_handle:
for raw_line in file_handle:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
if line.startswith("export "):
line = line[len("export ") :].strip()
if "=" not in line:
continue
key, value = line.split("=", 1)
key = key.strip()
value = value.strip()
if not key:
continue
if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
value = value[1:-1]
os.environ.setdefault(key, value)
def require_env(name: str) -> str:
    """Return the trimmed value of environment variable *name*.

    Raises:
        ValueError: if the variable is unset or contains only whitespace.
    """
    trimmed = os.environ.get(name, "").strip()
    if trimmed:
        return trimmed
    raise ValueError(f"Missing required environment variable: {name}")
def parse_channel_ids(raw_value: str) -> List[str]:
    """Split a comma- and/or newline-separated string into channel ids.

    Surrounding whitespace is trimmed and empty entries are dropped. Repeated
    ids are collapsed to their first occurrence so that a channel listed twice
    is not fetched (and its messages emitted) twice.

    Raises:
        ValueError: if no channel id remains after cleaning.
    """
    normalized = raw_value.replace("\n", ",")
    cleaned = (item.strip() for item in normalized.split(","))
    # dict.fromkeys keeps insertion order, giving an order-preserving de-dup.
    channel_ids = list(dict.fromkeys(item for item in cleaned if item))
    if not channel_ids:
        raise ValueError("CHANNEL_IDS must contain at least one channel id.")
    return channel_ids
def build_ssl_context() -> ssl.SSLContext:
    """Build the TLS context used for API requests from environment settings.

    * ``MATTERMOST_SKIP_TLS_VERIFY`` truthy: certificate and hostname
      verification are disabled (a warning is logged). This takes precedence
      over ``MATTERMOST_CA_BUNDLE``.
    * ``MATTERMOST_CA_BUNDLE`` set: a default context trusting that CA file.
    * otherwise: the platform default context.
    """
    if parse_bool_env("MATTERMOST_SKIP_TLS_VERIFY", default=False):
        LOGGER.warning("TLS certificate verification is disabled via MATTERMOST_SKIP_TLS_VERIFY.")
        # Public-API equivalent of the private ssl._create_unverified_context().
        insecure_context = ssl.create_default_context()
        # check_hostname has to be cleared first; setting CERT_NONE while it is
        # still enabled raises ValueError.
        insecure_context.check_hostname = False
        insecure_context.verify_mode = ssl.CERT_NONE
        return insecure_context
    ca_bundle = os.getenv("MATTERMOST_CA_BUNDLE", "").strip()
    if ca_bundle:
        LOGGER.info("Using custom CA bundle from MATTERMOST_CA_BUNDLE: %s", ca_bundle)
        return ssl.create_default_context(cafile=ca_bundle)
    return ssl.create_default_context()
def _read_positive_int_env(name: str, default: int) -> int:
    """Return environment variable *name* as a positive int, or *default* when unset or blank."""
    raw_value = os.getenv(name, "").strip()
    if not raw_value:
        return default
    try:
        parsed = int(raw_value)
    except ValueError:
        # Name the offending variable instead of surfacing int()'s generic message.
        raise ValueError(f"{name} must be an integer, got {raw_value!r}.") from None
    if parsed <= 0:
        raise ValueError(f"{name} must be greater than 0.")
    return parsed


def configure() -> None:
    """Populate the module-level settings from the environment.

    Loads the dotenv file first, then reads ``MATTERMOST_URL``,
    ``MATTERMOST_TOKEN`` and ``CHANNEL_IDS`` (required) plus
    ``MESSAGE_WINDOW_HOURS``, ``MAX_MESSAGES`` and ``MATTERMOST_OUTPUT_FILE``
    (optional; blank values fall back to the defaults). Also computes the
    cutoff timestamp, the request headers and the TLS context.

    Raises:
        ValueError: if a required variable is missing, or a numeric setting is
            not an integer greater than 0.
    """
    global MATTERMOST_URL, CHANNEL_IDS, WINDOW_HOURS, MAX_MESSAGES, CUTOFF_TIMESTAMP_MS, OUTPUT_FILE
    global REQUEST_HEADERS, SSL_CONTEXT
    load_dotenv_file()
    # Trailing slashes are removed because API paths are appended with a leading "/".
    MATTERMOST_URL = require_env("MATTERMOST_URL").rstrip("/")
    token = require_env("MATTERMOST_TOKEN")
    CHANNEL_IDS = parse_channel_ids(require_env("CHANNEL_IDS"))
    WINDOW_HOURS = _read_positive_int_env("MESSAGE_WINDOW_HOURS", DEFAULT_WINDOW_HOURS)
    MAX_MESSAGES = _read_positive_int_env("MAX_MESSAGES", DEFAULT_MAX_MESSAGES)
    OUTPUT_FILE = os.getenv("MATTERMOST_OUTPUT_FILE", DEFAULT_OUTPUT_FILE).strip() or DEFAULT_OUTPUT_FILE
    # The cutoff is stored in epoch milliseconds so it can be compared with
    # each post's create_at value directly.
    cutoff = datetime.now().astimezone() - timedelta(hours=WINDOW_HOURS)
    CUTOFF_TIMESTAMP_MS = int(cutoff.timestamp() * 1000)
    REQUEST_HEADERS = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    SSL_CONTEXT = build_ssl_context()
def api_get_json(api_path: str, params: Dict[str, Any] | None = None) -> Dict[str, Any]:
    """Send a GET request to the configured Mattermost server and decode the JSON reply.

    Args:
        api_path: Path appended to ``MATTERMOST_URL`` (e.g. ``/api/v4/...``).
        params: Optional query parameters, URL-encoded onto the path.

    Returns:
        The decoded JSON document.

    Raises:
        MattermostAPIError: on an HTTP error status, on any transport failure
            (unreachable host, timeout, dropped or truncated connection), or
            when the body cannot be decoded as JSON.
    """
    # Imported locally so this function does not rely on a module-level import
    # of http.client being present.
    from http.client import HTTPException

    query = f"?{parse.urlencode(params)}" if params else ""
    url = f"{MATTERMOST_URL}{api_path}{query}"
    req = request.Request(url, headers=REQUEST_HEADERS, method="GET")
    try:
        with request.urlopen(req, timeout=REQUEST_TIMEOUT, context=SSL_CONTEXT) as response:
            charset = response.headers.get_content_charset() or "utf-8"
            raw_body = response.read()
    except error.HTTPError as exc:
        # Reading the error body is best-effort; fall back to the status reason.
        try:
            body = exc.read().decode("utf-8", errors="replace")
        except Exception:
            body = ""
        raise MattermostAPIError(f"HTTP {exc.code} for {api_path}: {body or exc.reason}") from exc
    except (OSError, HTTPException) as exc:
        # URLError is an OSError; catching OSError also covers timeouts and
        # resets raised while reading the body, which are not wrapped in
        # URLError. HTTPException covers truncated or malformed responses.
        reason = getattr(exc, "reason", None) or exc
        raise MattermostAPIError(f"Request failed for {api_path}: {reason}") from exc
    try:
        return json.loads(raw_body.decode(charset))
    except (LookupError, ValueError) as exc:
        # LookupError: unknown charset name. ValueError covers both
        # UnicodeDecodeError and json.JSONDecodeError.
        raise MattermostAPIError(f"Invalid JSON returned by {api_path}") from exc
def get_channel_posts(channel_id: str) -> List[Dict[str, Any]]:
    """Collect up to ``MAX_MESSAGES`` posts of *channel_id* created at or after the cutoff.

    Pages through ``/api/v4/channels/{channel_id}/posts`` and stops when the
    limit is reached, a page is empty or shorter than requested, or a post
    older than ``CUTOFF_TIMESTAMP_MS`` has been seen.

    NOTE(review): stopping at the first page that contains an old post assumes
    the server lists ``order`` newest-first - confirm against the Mattermost
    API reference.

    Returns:
        The raw post dicts in the order they were encountered.

    Raises:
        MattermostAPIError: propagated from ``api_get_json``.
    """
    collected: List[Dict[str, Any]] = []
    # Offset-based paging can repeat a post on the next page when new posts
    # arrive between two requests; track ids so each post is kept once.
    seen_post_ids: set[str] = set()
    page = 0
    per_page = min(MAX_PER_PAGE, MAX_MESSAGES)
    while len(collected) < MAX_MESSAGES:
        payload = api_get_json(
            f"/api/v4/channels/{channel_id}/posts",
            {"page": page, "per_page": per_page},
        )
        # "or" fallbacks also cover keys that are present but null.
        order = payload.get("order") or []
        posts_by_id = payload.get("posts") or {}
        if not order:
            break
        reached_cutoff = False
        for post_id in order:
            post = posts_by_id.get(post_id)
            if not post or post_id in seen_post_ids:
                continue
            seen_post_ids.add(post_id)
            created_at = int(post.get("create_at") or 0)
            if created_at < CUTOFF_TIMESTAMP_MS:
                # Keep scanning this page, but do not request further pages.
                reached_cutoff = True
                continue
            collected.append(post)
            if len(collected) >= MAX_MESSAGES:
                break
        if reached_cutoff or len(order) < per_page:
            break
        page += 1
    LOGGER.info("Fetched %s raw posts from channel %s", len(collected), channel_id)
    return collected
def get_user_info(user_id: str) -> Dict[str, Any]:
    """Return the user record for *user_id*, consulting ``USER_CACHE`` first.

    An empty id yields a placeholder record with username "unknown". When the
    API lookup fails, the error is logged and a fallback record that uses the
    id as username is cached and returned, so the lookup is not retried.
    """
    if not user_id:
        return {"id": "", "username": "unknown"}
    try:
        return USER_CACHE[user_id]
    except KeyError:
        pass
    try:
        record = api_get_json(f"/api/v4/users/{user_id}")
    except MattermostAPIError as exc:
        LOGGER.error("Could not fetch user %s: %s", user_id, exc)
        record = {"id": user_id, "username": user_id}
    USER_CACHE[user_id] = record
    return record
def build_user_map(messages: List[Dict[str, Any]]) -> Dict[str, str]:
    """Map every distinct, non-empty ``user_id`` found in *messages* to a display name.

    The name is the first non-empty value among the user's ``username``,
    ``nickname`` and ``first_name``; the id itself is used when all are empty.
    Ids are resolved in sorted order.
    """
    name_fields = ("username", "nickname", "first_name")
    distinct_ids = {entry["user_id"] for entry in messages if entry.get("user_id")}
    names: Dict[str, str] = {}
    for uid in sorted(distinct_ids):
        info = get_user_info(uid)
        names[uid] = next((info.get(field) for field in name_fields if info.get(field)), uid)
    return names
def is_system_message(post: Dict[str, Any]) -> bool:
    """Return True when the post's ``type`` (whitespace-trimmed) starts with "system_"."""
    kind = post.get("type")
    if not kind:
        return False
    return kind.strip().startswith("system_")
def extract_messages() -> List[Dict[str, Any]]:
    """Fetch, filter and normalise the posts of every configured channel.

    System posts and posts whose message is empty after trimming are dropped.
    The remaining records are sorted by ``create_at`` ascending, cut down to
    the last ``MAX_MESSAGES`` entries, and given a ``username`` field.
    """
    prepared: List[Dict[str, Any]] = []
    for channel_id in CHANNEL_IDS:
        for post in get_channel_posts(channel_id):
            if is_system_message(post):
                continue
            text = (post.get("message") or "").strip()
            if not text:
                continue
            entry = {
                "channel_id": channel_id,
                "channel_ref": channel_id,
                "post_id": post.get("id", ""),
                "user_id": post.get("user_id", ""),
                "create_at": int(post.get("create_at", 0)),
                "message": text.replace("\r\n", "\n"),
                "root_id": post.get("root_id", ""),
                "reply_count": int(post.get("reply_count", 0)),
            }
            prepared.append(entry)

    prepared.sort(key=lambda item: item["create_at"])
    # After the ascending sort the tail of the list holds the newest messages.
    if len(prepared) > MAX_MESSAGES:
        prepared = prepared[-MAX_MESSAGES:]

    names_by_id = build_user_map(prepared)
    for entry in prepared:
        author_id = entry["user_id"]
        entry["username"] = names_by_id.get(author_id, author_id or "unknown")

    LOGGER.info("Prepared %s messages after filtering", len(prepared))
    return prepared
def format_messages(messages: List[Dict[str, Any]]) -> str:
    """Serialise *messages* as JSON Lines: one JSON object per message, joined by newlines.

    Each record is classified as "thread_reply" (it has a ``root_id``),
    "thread_root" (no ``root_id`` but a positive ``reply_count``) or
    "channel_post". ``reply_to`` and ``reply_count`` are only emitted when
    they apply. ``create_at`` is read as epoch milliseconds and written as an
    ISO-8601 timestamp in the local time zone. Returns "" for an empty list.
    """

    def _to_record(entry: Dict[str, Any]) -> Dict[str, Any]:
        # Build the output object for a single message, keeping key order stable.
        moment = datetime.fromtimestamp(entry["create_at"] / 1000).astimezone()
        author = entry.get("username", "unknown")
        post_id = entry.get("post_id", "")
        root_id = entry.get("root_id", "")
        thread_id = root_id or post_id or "unknown"
        replies = int(entry.get("reply_count", 0))
        kind = "thread_reply" if root_id else ("thread_root" if replies > 0 else "channel_post")
        channel_ref = entry.get("channel_ref", entry.get("channel_id", "unknown"))
        record: Dict[str, Any] = {
            "source": "mattermost",
            "channel": channel_ref,
            "channel_id": entry.get("channel_id", ""),
            "post_id": post_id,
            "thread_id": thread_id,
            "root_id": root_id or None,
            "type": kind,
            "timestamp": moment.isoformat(),
            "username": author,
            "message": entry["message"],
        }
        if root_id:
            record["reply_to"] = root_id
        if replies > 0:
            record["reply_count"] = replies
        return record

    return "\n".join(
        json.dumps(_to_record(entry), ensure_ascii=False, sort_keys=False)
        for entry in messages
    )
def save_to_file(text: str) -> None:
    """Write *text* as UTF-8 to ``OUTPUT_FILE``, creating parent directories as needed.

    Non-empty text that does not already end with a newline gets one appended;
    empty text is written unchanged.
    """
    destination = Path(OUTPUT_FILE).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    needs_newline = bool(text) and not text.endswith("\n")
    destination.write_text(f"{text}\n" if needs_newline else text, encoding="utf-8")
def main() -> int:
    """Run the export: configure, fetch, format, print and save.

    Returns:
        0 on success, 1 when any step fails. Configuration and API errors are
        logged as a single line; any other exception is logged with traceback.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    try:
        configure()
        output = format_messages(extract_messages())
        print(output)
        save_to_file(output)
        LOGGER.info("Saved context to %s", OUTPUT_FILE)
    except (ValueError, MattermostAPIError) as exc:
        # Expected failure modes: report the message only, without a traceback.
        LOGGER.error("%s", exc)
        return 1
    except Exception as exc:
        LOGGER.exception("Unexpected error: %s", exc)
        return 1
    else:
        return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())