#!/usr/bin/env python3 import json import logging import os import ssl import sys from datetime import datetime, timedelta from pathlib import Path from typing import Any, Dict, List from urllib import error, parse, request SCRIPT_DIR = Path(__file__).resolve().parent DEFAULT_WINDOW_HOURS = 24 DEFAULT_MAX_MESSAGES = 200 MAX_PER_PAGE = 200 DEFAULT_OUTPUT_FILE = str(SCRIPT_DIR / "generated" / "mattermost_context.jsonl") REQUEST_TIMEOUT = 15 LOGGER = logging.getLogger("mattermost_context") USER_CACHE: Dict[str, Dict[str, Any]] = {} MATTERMOST_URL = "" CHANNEL_IDS: List[str] = [] WINDOW_HOURS = DEFAULT_WINDOW_HOURS MAX_MESSAGES = DEFAULT_MAX_MESSAGES CUTOFF_TIMESTAMP_MS = 0 OUTPUT_FILE = DEFAULT_OUTPUT_FILE REQUEST_HEADERS: Dict[str, str] = {} SSL_CONTEXT: ssl.SSLContext | None = None class MattermostAPIError(RuntimeError): pass def parse_bool_env(name: str, default: bool = False) -> bool: raw_value = os.getenv(name) if raw_value is None: return default return raw_value.strip().lower() in {"1", "true", "yes", "on"} def load_dotenv_file(path: Path | None = None) -> None: dotenv_path = path or (SCRIPT_DIR / ".env") if not dotenv_path.exists(): return with dotenv_path.open("r", encoding="utf-8") as file_handle: for raw_line in file_handle: line = raw_line.strip() if not line or line.startswith("#"): continue if line.startswith("export "): line = line[len("export ") :].strip() if "=" not in line: continue key, value = line.split("=", 1) key = key.strip() value = value.strip() if not key: continue if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}: value = value[1:-1] os.environ.setdefault(key, value) def require_env(name: str) -> str: value = os.getenv(name, "").strip() if not value: raise ValueError(f"Missing required environment variable: {name}") return value def parse_channel_ids(raw_value: str) -> List[str]: normalized = raw_value.replace("\n", ",") channel_ids = [item.strip() for item in normalized.split(",") if item.strip()] if not channel_ids: raise ValueError("CHANNEL_IDS must contain at least one channel id.") return channel_ids def build_ssl_context() -> ssl.SSLContext: ca_bundle = os.getenv("MATTERMOST_CA_BUNDLE", "").strip() skip_tls_verify = parse_bool_env("MATTERMOST_SKIP_TLS_VERIFY", default=False) if skip_tls_verify: LOGGER.warning("TLS certificate verification is disabled via MATTERMOST_SKIP_TLS_VERIFY.") return ssl._create_unverified_context() if ca_bundle: LOGGER.info("Using custom CA bundle from MATTERMOST_CA_BUNDLE: %s", ca_bundle) return ssl.create_default_context(cafile=ca_bundle) return ssl.create_default_context() def configure() -> None: global MATTERMOST_URL, CHANNEL_IDS, WINDOW_HOURS, MAX_MESSAGES, CUTOFF_TIMESTAMP_MS, OUTPUT_FILE global REQUEST_HEADERS, SSL_CONTEXT load_dotenv_file() MATTERMOST_URL = require_env("MATTERMOST_URL").rstrip("/") token = require_env("MATTERMOST_TOKEN") CHANNEL_IDS = parse_channel_ids(require_env("CHANNEL_IDS")) WINDOW_HOURS = int(os.getenv("MESSAGE_WINDOW_HOURS", str(DEFAULT_WINDOW_HOURS))) MAX_MESSAGES = int(os.getenv("MAX_MESSAGES", str(DEFAULT_MAX_MESSAGES))) OUTPUT_FILE = os.getenv("MATTERMOST_OUTPUT_FILE", DEFAULT_OUTPUT_FILE).strip() or DEFAULT_OUTPUT_FILE if WINDOW_HOURS <= 0: raise ValueError("MESSAGE_WINDOW_HOURS must be greater than 0.") if MAX_MESSAGES <= 0: raise ValueError("MAX_MESSAGES must be greater than 0.") cutoff = datetime.now().astimezone() - timedelta(hours=WINDOW_HOURS) CUTOFF_TIMESTAMP_MS = int(cutoff.timestamp() * 1000) REQUEST_HEADERS = { "Authorization": f"Bearer {token}", "Accept": "application/json", "Content-Type": "application/json", } SSL_CONTEXT = build_ssl_context() def api_get_json(api_path: str, params: Dict[str, Any] | None = None) -> Dict[str, Any]: query = f"?{parse.urlencode(params)}" if params else "" url = f"{MATTERMOST_URL}{api_path}{query}" req = request.Request(url, headers=REQUEST_HEADERS, method="GET") try: with request.urlopen(req, timeout=REQUEST_TIMEOUT, context=SSL_CONTEXT) as response: charset = response.headers.get_content_charset() or "utf-8" payload = response.read().decode(charset) except error.HTTPError as exc: body = "" try: body = exc.read().decode("utf-8", errors="replace") except Exception: body = "" raise MattermostAPIError(f"HTTP {exc.code} for {api_path}: {body or exc.reason}") from exc except error.URLError as exc: reason = exc.reason if hasattr(exc, "reason") else str(exc) raise MattermostAPIError(f"Request failed for {api_path}: {reason}") from exc try: return json.loads(payload) except json.JSONDecodeError as exc: raise MattermostAPIError(f"Invalid JSON returned by {api_path}") from exc def get_channel_posts(channel_id: str) -> List[Dict[str, Any]]: collected: List[Dict[str, Any]] = [] page = 0 per_page = min(MAX_PER_PAGE, MAX_MESSAGES) while len(collected) < MAX_MESSAGES: payload = api_get_json( f"/api/v4/channels/{channel_id}/posts", {"page": page, "per_page": per_page}, ) order = payload.get("order", []) posts_by_id = payload.get("posts", {}) if not order: break reached_cutoff = False for post_id in order: post = posts_by_id.get(post_id) if not post: continue created_at = int(post.get("create_at", 0)) if created_at < CUTOFF_TIMESTAMP_MS: reached_cutoff = True continue collected.append(post) if len(collected) >= MAX_MESSAGES: break if reached_cutoff or len(order) < per_page: break page += 1 LOGGER.info("Fetched %s raw posts from channel %s", len(collected), channel_id) return collected def get_user_info(user_id: str) -> Dict[str, Any]: if not user_id: return {"id": "", "username": "unknown"} if user_id in USER_CACHE: return USER_CACHE[user_id] try: user_data = api_get_json(f"/api/v4/users/{user_id}") except MattermostAPIError as exc: LOGGER.error("Could not fetch user %s: %s", user_id, exc) fallback = {"id": user_id, "username": user_id} USER_CACHE[user_id] = fallback return fallback USER_CACHE[user_id] = user_data return user_data def build_user_map(messages: List[Dict[str, Any]]) -> Dict[str, str]: user_map: Dict[str, str] = {} user_ids = sorted({message["user_id"] for message in messages if message.get("user_id")}) for user_id in user_ids: user_info = get_user_info(user_id) username = ( user_info.get("username") or user_info.get("nickname") or user_info.get("first_name") or user_id ) user_map[user_id] = username return user_map def is_system_message(post: Dict[str, Any]) -> bool: post_type = (post.get("type") or "").strip() return post_type.startswith("system_") def extract_messages() -> List[Dict[str, Any]]: all_messages: List[Dict[str, Any]] = [] for channel_id in CHANNEL_IDS: raw_posts = get_channel_posts(channel_id) for post in raw_posts: if is_system_message(post): continue message = (post.get("message") or "").strip() if not message: continue all_messages.append( { "channel_id": channel_id, "channel_ref": channel_id, "post_id": post.get("id", ""), "user_id": post.get("user_id", ""), "create_at": int(post.get("create_at", 0)), "message": message.replace("\r\n", "\n"), "root_id": post.get("root_id", ""), "reply_count": int(post.get("reply_count", 0)), } ) all_messages.sort(key=lambda item: item["create_at"]) if len(all_messages) > MAX_MESSAGES: all_messages = all_messages[-MAX_MESSAGES:] user_map = build_user_map(all_messages) for message in all_messages: message["username"] = user_map.get(message["user_id"], message["user_id"] or "unknown") LOGGER.info("Prepared %s messages after filtering", len(all_messages)) return all_messages def format_messages(messages: List[Dict[str, Any]]) -> str: lines: List[str] = [] for message in messages: timestamp = datetime.fromtimestamp(message["create_at"] / 1000).astimezone() username = message.get("username", "unknown") post_id = message.get("post_id", "") root_id = message.get("root_id", "") thread_id = root_id or post_id or "unknown" reply_count = int(message.get("reply_count", 0)) if root_id: message_kind = "thread_reply" elif reply_count > 0: message_kind = "thread_root" else: message_kind = "channel_post" channel_ref = message.get("channel_ref", message.get("channel_id", "unknown")) record = { "source": "mattermost", "channel": channel_ref, "channel_id": message.get("channel_id", ""), "post_id": post_id, "thread_id": thread_id, "root_id": root_id or None, "type": message_kind, "timestamp": timestamp.isoformat(), "username": username, "message": message["message"], } if root_id: record["reply_to"] = root_id if reply_count > 0: record["reply_count"] = reply_count lines.append(json.dumps(record, ensure_ascii=False, sort_keys=False)) return "\n".join(lines) def save_to_file(text: str) -> None: output_path = Path(OUTPUT_FILE).expanduser().resolve() output_path.parent.mkdir(parents=True, exist_ok=True) if text and not text.endswith("\n"): text = f"{text}\n" output_path.write_text(text, encoding="utf-8") def main() -> int: logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") try: configure() messages = extract_messages() output = format_messages(messages) print(output) save_to_file(output) LOGGER.info("Saved context to %s", OUTPUT_FILE) except ValueError as exc: LOGGER.error("%s", exc) return 1 except MattermostAPIError as exc: LOGGER.error("%s", exc) return 1 except Exception as exc: LOGGER.exception("Unexpected error: %s", exc) return 1 return 0 if __name__ == "__main__": sys.exit(main())