Files
fidelity-ai-workspace/scripts/mattermost/mattermost_context.py

590 lines
20 KiB
Python

#!/usr/bin/env python3
import json
import logging
import os
import re
import ssl
import sys
from argparse import ArgumentParser, Namespace
from datetime import date, datetime, time, timedelta
from pathlib import Path
from typing import Any, Dict, List
from urllib import error, parse, request
SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_WINDOW_HOURS = 24
DEFAULT_MAX_MESSAGES = 200
MAX_PER_PAGE = 200
DEFAULT_OUTPUT_FILE = str(SCRIPT_DIR / "generated" / "mattermost_context.jsonl")
REQUEST_TIMEOUT = 15
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
LOGGER = logging.getLogger("mattermost_context")
USER_CACHE: Dict[str, Dict[str, Any]] = {}
CHANNEL_CACHE: Dict[str, Dict[str, Any]] = {}
TEAM_CACHE: List[Dict[str, Any]] | None = None
MATTERMOST_URL = ""
CHANNEL_SPECS: List[Dict[str, str]] = []
WINDOW_HOURS = DEFAULT_WINDOW_HOURS
MAX_MESSAGES = DEFAULT_MAX_MESSAGES
CUTOFF_TIMESTAMP_MS = 0
RANGE_START_TIMESTAMP_MS = 0
RANGE_END_TIMESTAMP_MS = 0
OUTPUT_FILE = DEFAULT_OUTPUT_FILE
REQUEST_HEADERS: Dict[str, str] = {}
SSL_CONTEXT: ssl.SSLContext | None = None
MATTERMOST_TEAM_NAME = ""
MATTERMOST_TEAM_ID = ""
class MattermostAPIError(RuntimeError):
    """Raised when a Mattermost API request fails, returns an unexpected payload,
    or a channel reference cannot be resolved unambiguously."""
def parse_args(argv: List[str] | None = None) -> Namespace:
    """Parse the command-line options for the extractor.

    The script-local ``.env`` is loaded first. This lets the environment-backed
    default for ``--max-lookback-days`` (MATTERMOST_MAX_LOOKBACK_DAYS) honour a
    value defined there, as every other setting does. load_dotenv_file() never
    overrides variables that are already set, so the later load in configure()
    changes nothing.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        The parsed ``Namespace``.

    Raises:
        ValueError: If MATTERMOST_MAX_LOOKBACK_DAYS is set but is not an integer.
    """
    load_dotenv_file()
    # Blank is treated like unset so an empty ".env" entry falls back to 7.
    raw_lookback = (os.getenv("MATTERMOST_MAX_LOOKBACK_DAYS") or "").strip() or "7"
    try:
        default_lookback = int(raw_lookback)
    except ValueError:
        raise ValueError(
            f"MATTERMOST_MAX_LOOKBACK_DAYS must be an integer, got {raw_lookback!r}."
        ) from None

    parser = ArgumentParser(description="Extract Mattermost messages as JSONL context.")
    parser.add_argument(
        "--previous-workday",
        action="store_true",
        help="Fetch the latest prior calendar day with Mattermost activity instead of a fixed recent window.",
    )
    parser.add_argument(
        "--today",
        default=date.today().isoformat(),
        help="Reference date in YYYY-MM-DD format. Defaults to today.",
    )
    parser.add_argument(
        "--max-lookback-days",
        type=int,
        default=default_lookback,
        help="Maximum days to search backward with --previous-workday.",
    )
    parser.add_argument(
        "--window-hours",
        type=int,
        default=0,
        help="Override MESSAGE_WINDOW_HOURS for normal recent-window mode.",
    )
    parser.add_argument(
        "--output-file",
        default="",
        help="Override MATTERMOST_OUTPUT_FILE.",
    )
    return parser.parse_args(argv)
def parse_iso_date(raw_value: str) -> date:
    """Convert a strict ``YYYY-MM-DD`` string into a ``date``.

    Raises:
        ValueError: If the text does not have that shape (checked with DATE_RE),
            or if ``strptime`` rejects it (e.g. a day that does not exist).
    """
    if DATE_RE.match(raw_value) is None:
        raise ValueError(f"Invalid date '{raw_value}'. Use YYYY-MM-DD.")
    parsed = datetime.strptime(raw_value, "%Y-%m-%d")
    return parsed.date()
def parse_bool_env(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    An unset variable yields *default*. Otherwise the value is true only for
    1/true/yes/on (case- and surrounding-whitespace-insensitive). Anything
    else, including an empty string, is false.
    """
    truthy_values = ("1", "true", "yes", "on")
    raw_value = os.getenv(name)
    return default if raw_value is None else raw_value.strip().lower() in truthy_values
def load_dotenv_file(path: Path | None = None) -> None:
dotenv_path = path or (SCRIPT_DIR / ".env")
if not dotenv_path.exists():
return
with dotenv_path.open("r", encoding="utf-8") as file_handle:
for raw_line in file_handle:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
if line.startswith("export "):
line = line[len("export ") :].strip()
if "=" not in line:
continue
key, value = line.split("=", 1)
key = key.strip()
value = value.strip()
if not key:
continue
if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
value = value[1:-1]
os.environ.setdefault(key, value)
def require_env(name: str) -> str:
    """Return the whitespace-trimmed value of environment variable *name*.

    Raises:
        ValueError: If the variable is unset or contains only whitespace.
    """
    value = (os.getenv(name) or "").strip()
    if value:
        return value
    raise ValueError(f"Missing required environment variable: {name}")
def parse_channel_ids(raw_value: str) -> List[str]:
    """Split a comma- and/or newline-separated list into trimmed, non-empty items.

    Raises:
        ValueError: If no non-empty item remains.
    """
    channel_ids: List[str] = []
    for piece in raw_value.replace("\n", ",").split(","):
        cleaned = piece.strip()
        if cleaned:
            channel_ids.append(cleaned)
    if channel_ids:
        return channel_ids
    raise ValueError("CHANNEL_IDS must contain at least one channel id.")
def looks_like_channel_id(value: str) -> bool:
    """Heuristic: treat any 26-character alphanumeric string as a channel id."""
    if len(value) != 26:
        return False
    return value.isalnum()
def parse_channel_specs() -> List[Dict[str, str]]:
    """Build the list of channels to read from CHANNELS, CHANNEL_NAMES and CHANNEL_IDS.

    Each variable accepts comma- or newline-separated entries, and entries are
    combined in that variable order. Every entry is classified with
    looks_like_channel_id(): 26-character alphanumeric values are used as
    channel ids, and anything else is treated as a channel name with a leading
    "#" or "~" mention prefix removed. Repeated specs are dropped (the first
    occurrence wins), so a channel listed twice is only fetched once.

    Returns:
        A list of ``{"kind": "id" | "name", "value": ...}`` dicts.

    Raises:
        ValueError: If a configured variable has no usable entries, or if no
            variable yields any channel at all.
    """
    entries: List[str] = []
    for env_name in ("CHANNELS", "CHANNEL_NAMES", "CHANNEL_IDS"):
        raw_value = os.getenv(env_name, "").strip()
        if not raw_value:
            continue
        try:
            entries.extend(parse_channel_ids(raw_value))
        except ValueError:
            # parse_channel_ids always names CHANNEL_IDS in its message; report
            # the variable that actually held the unusable value.
            raise ValueError(f"{env_name} must contain at least one channel name or id.") from None

    specs: List[Dict[str, str]] = []
    seen: set[tuple[str, str]] = set()
    for entry in entries:
        if looks_like_channel_id(entry):
            kind, value = "id", entry
        else:
            kind, value = "name", entry.lstrip("#~")
        # Skip bare prefixes such as "#" (empty name) and exact duplicates.
        if not value or (kind, value) in seen:
            continue
        seen.add((kind, value))
        specs.append({"kind": kind, "value": value})

    if not specs:
        raise ValueError("Configure at least one of CHANNELS, CHANNEL_NAMES, or CHANNEL_IDS.")
    return specs
def build_ssl_context() -> ssl.SSLContext:
    """Create the TLS context used for every Mattermost request.

    Settings are applied in this order of precedence:

    1. MATTERMOST_SKIP_TLS_VERIFY is truthy: certificate and hostname
       verification are both disabled, and a warning is logged. Any
       MATTERMOST_CA_BUNDLE is ignored, with a warning.
    2. MATTERMOST_CA_BUNDLE is set: only that CA file is trusted (or that
       directory of hash-named certificates). ``~`` is expanded.
    3. Otherwise the platform defaults from ``ssl.create_default_context()`` are used.

    Raises:
        ValueError: If MATTERMOST_CA_BUNDLE names a path that does not exist.
    """
    ca_bundle = os.getenv("MATTERMOST_CA_BUNDLE", "").strip()
    skip_tls_verify = parse_bool_env("MATTERMOST_SKIP_TLS_VERIFY", default=False)
    if skip_tls_verify:
        LOGGER.warning("TLS certificate verification is disabled via MATTERMOST_SKIP_TLS_VERIFY.")
        if ca_bundle:
            LOGGER.warning("MATTERMOST_CA_BUNDLE is ignored because MATTERMOST_SKIP_TLS_VERIFY is set.")
        # Public-API equivalent of the private ssl._create_unverified_context().
        context = ssl.create_default_context()
        # check_hostname must be cleared before verify_mode may drop to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        return context
    if ca_bundle:
        bundle_path = Path(ca_bundle).expanduser()
        if not bundle_path.exists():
            # Report a configuration error instead of an unexpected FileNotFoundError.
            raise ValueError(f"MATTERMOST_CA_BUNDLE does not exist: {ca_bundle}")
        LOGGER.info("Using custom CA bundle from MATTERMOST_CA_BUNDLE: %s", ca_bundle)
        if bundle_path.is_dir():
            # A directory must contain OpenSSL hash-named certificates (c_rehash).
            return ssl.create_default_context(capath=str(bundle_path))
        return ssl.create_default_context(cafile=str(bundle_path))
    return ssl.create_default_context()
def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* as an int. Unset or blank returns *default*.

    Raises:
        ValueError: Naming *name*, when the value is not an integer.
    """
    raw_value = (os.getenv(name) or "").strip()
    if not raw_value:
        return default
    try:
        return int(raw_value)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw_value!r}.") from None


def configure(args: Namespace) -> None:
    """Populate the module-level settings from ``.env``, the environment and CLI args.

    Sets the following globals:

    - MATTERMOST_URL (trailing "/" removed)
    - CHANNEL_SPECS, WINDOW_HOURS, MAX_MESSAGES, OUTPUT_FILE
    - MATTERMOST_TEAM_NAME / MATTERMOST_TEAM_ID
    - CUTOFF_TIMESTAMP_MS (now minus WINDOW_HOURS)
    - REQUEST_HEADERS (bearer token) and SSL_CONTEXT

    It also clears any explicit fetch range. A non-zero ``--window-hours`` and a
    non-empty ``--output-file`` take precedence over their environment counterparts.

    Raises:
        ValueError: For missing required variables, non-integer or non-positive
            numeric settings, and any ValueError raised by the channel or TLS
            configuration helpers.
    """
    global MATTERMOST_URL, CHANNEL_SPECS, WINDOW_HOURS, MAX_MESSAGES, CUTOFF_TIMESTAMP_MS, OUTPUT_FILE
    global RANGE_START_TIMESTAMP_MS, RANGE_END_TIMESTAMP_MS, REQUEST_HEADERS, SSL_CONTEXT
    global MATTERMOST_TEAM_NAME, MATTERMOST_TEAM_ID
    load_dotenv_file()
    MATTERMOST_URL = require_env("MATTERMOST_URL").rstrip("/")
    token = require_env("MATTERMOST_TOKEN")
    CHANNEL_SPECS = parse_channel_specs()
    WINDOW_HOURS = args.window_hours or _env_int("MESSAGE_WINDOW_HOURS", DEFAULT_WINDOW_HOURS)
    MAX_MESSAGES = _env_int("MAX_MESSAGES", DEFAULT_MAX_MESSAGES)
    OUTPUT_FILE = (
        args.output_file
        or os.getenv("MATTERMOST_OUTPUT_FILE", DEFAULT_OUTPUT_FILE).strip()
        or DEFAULT_OUTPUT_FILE
    )
    MATTERMOST_TEAM_NAME = os.getenv("MATTERMOST_TEAM_NAME", "").strip()
    MATTERMOST_TEAM_ID = os.getenv("MATTERMOST_TEAM_ID", "").strip()
    if WINDOW_HOURS <= 0:
        raise ValueError("MESSAGE_WINDOW_HOURS (or --window-hours) must be greater than 0.")
    if MAX_MESSAGES <= 0:
        raise ValueError("MAX_MESSAGES must be greater than 0.")
    cutoff = datetime.now().astimezone() - timedelta(hours=WINDOW_HOURS)
    CUTOFF_TIMESTAMP_MS = int(cutoff.timestamp() * 1000)
    # Recent-window mode: no explicit range until set_fetch_range() is called.
    RANGE_START_TIMESTAMP_MS = 0
    RANGE_END_TIMESTAMP_MS = 0
    REQUEST_HEADERS = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    SSL_CONTEXT = build_ssl_context()
def api_get_json(api_path: str, params: Dict[str, Any] | None = None) -> Dict[str, Any]:
query = f"?{parse.urlencode(params)}" if params else ""
url = f"{MATTERMOST_URL}{api_path}{query}"
req = request.Request(url, headers=REQUEST_HEADERS, method="GET")
try:
with request.urlopen(req, timeout=REQUEST_TIMEOUT, context=SSL_CONTEXT) as response:
charset = response.headers.get_content_charset() or "utf-8"
payload = response.read().decode(charset)
except error.HTTPError as exc:
body = ""
try:
body = exc.read().decode("utf-8", errors="replace")
except Exception:
body = ""
raise MattermostAPIError(f"HTTP {exc.code} for {api_path}: {body or exc.reason}") from exc
except error.URLError as exc:
reason = exc.reason if hasattr(exc, "reason") else str(exc)
raise MattermostAPIError(f"Request failed for {api_path}: {reason}") from exc
try:
return json.loads(payload)
except json.JSONDecodeError as exc:
raise MattermostAPIError(f"Invalid JSON returned by {api_path}") from exc
def get_channel_posts(channel_id: str) -> List[Dict[str, Any]]:
    """Page through a channel's posts and return those inside the active time window.

    The window is ``[RANGE_START_TIMESTAMP_MS or CUTOFF_TIMESTAMP_MS,
    RANGE_END_TIMESTAMP_MS)``, where an end of 0 means "no upper bound". At
    most MAX_MESSAGES posts are returned. The paging logic assumes the
    endpoint's ``order`` list is newest-first, so it stops after the first
    page that reaches posts older than the window start.

    System posts (see is_system_message) are skipped here rather than only
    later. That way they do not use up the MAX_MESSAGES budget and crowd out
    real messages. Posts are also de-duplicated by id, because offset-based
    paging can return the same post twice when new posts arrive between
    page requests.
    """
    collected: List[Dict[str, Any]] = []
    seen_ids: set[str] = set()
    page = 0
    start_timestamp_ms = RANGE_START_TIMESTAMP_MS or CUTOFF_TIMESTAMP_MS
    end_timestamp_ms = RANGE_END_TIMESTAMP_MS
    # With an upper bound, many newer posts may have to be skipped before the
    # window is reached, so request full pages instead of MAX_MESSAGES-sized ones.
    per_page = MAX_PER_PAGE if end_timestamp_ms else min(MAX_PER_PAGE, MAX_MESSAGES)
    while len(collected) < MAX_MESSAGES:
        payload = api_get_json(
            f"/api/v4/channels/{channel_id}/posts",
            {"page": page, "per_page": per_page},
        )
        order = payload.get("order") or []
        posts_by_id = payload.get("posts") or {}
        if not order:
            break
        reached_cutoff = False
        for post_id in order:
            post = posts_by_id.get(post_id)
            if not post or post_id in seen_ids:
                continue
            seen_ids.add(post_id)
            created_at = int(post.get("create_at", 0))
            if created_at < start_timestamp_ms:
                reached_cutoff = True
                continue
            if end_timestamp_ms and created_at >= end_timestamp_ms:
                continue
            if is_system_message(post):
                continue
            collected.append(post)
            if len(collected) >= MAX_MESSAGES:
                break
        if reached_cutoff or len(order) < per_page:
            break
        page += 1
    LOGGER.info("Fetched %s non-system posts from channel %s", len(collected), channel_id)
    return collected
def get_channel_by_id(channel_id: str) -> Dict[str, Any]:
    """Return channel metadata for *channel_id*.

    Results are memoised in CHANNEL_CACHE under the key ``id:<channel_id>``.
    """
    cache_key = f"id:{channel_id}"
    if cache_key not in CHANNEL_CACHE:
        CHANNEL_CACHE[cache_key] = api_get_json(f"/api/v4/channels/{channel_id}")
    return CHANNEL_CACHE[cache_key]
def get_channel_by_name(channel_name: str) -> Dict[str, Any]:
    """Resolve a channel by its URL name, caching it under both name and id keys.

    When MATTERMOST_TEAM_ID or MATTERMOST_TEAM_NAME is configured, the lookup is
    scoped to that team. Otherwise every team returned for the current user is
    searched.
    """
    name_key = f"name:{channel_name}"
    if name_key in CHANNEL_CACHE:
        return CHANNEL_CACHE[name_key]
    team_scoped = bool(MATTERMOST_TEAM_ID or MATTERMOST_TEAM_NAME)
    lookup = get_channel_by_name_for_team if team_scoped else find_channel_across_user_teams
    channel_data: Dict[str, Any] = lookup(channel_name)
    CHANNEL_CACHE[name_key] = channel_data
    CHANNEL_CACHE[f"id:{channel_data.get('id', '')}"] = channel_data
    return channel_data
def get_channel_by_name_for_team(channel_name: str) -> Dict[str, Any]:
    """Look up *channel_name* inside the configured team.

    MATTERMOST_TEAM_ID takes precedence over MATTERMOST_TEAM_NAME. Every
    interpolated path segment, including the team id, is percent-encoded. A
    stray "/" or "?" in a configured value therefore cannot change which
    endpoint is called.
    """
    encoded_channel = parse.quote(channel_name, safe="")
    if MATTERMOST_TEAM_ID:
        team_segment = parse.quote(MATTERMOST_TEAM_ID, safe="")
    else:
        team_segment = f"name/{parse.quote(MATTERMOST_TEAM_NAME, safe='')}"
    return api_get_json(f"/api/v4/teams/{team_segment}/channels/name/{encoded_channel}")
def get_user_teams() -> List[Dict[str, Any]]:
    """Return the current user's teams from ``/api/v4/users/me/teams``.

    The result is memoised in TEAM_CACHE.

    Raises:
        MattermostAPIError: If the endpoint does not return a JSON list.
    """
    global TEAM_CACHE
    if TEAM_CACHE is None:
        teams = api_get_json("/api/v4/users/me/teams")
        if not isinstance(teams, list):
            raise MattermostAPIError("Unexpected response while listing user teams.")
        TEAM_CACHE = teams
    return TEAM_CACHE
def get_user_channels_for_team(team_id: str) -> List[Dict[str, Any]]:
    """Return the current user's channels in *team_id*.

    The data comes from ``/api/v4/users/me/teams/<team_id>/channels``.

    Raises:
        MattermostAPIError: If the endpoint does not return a JSON list.
    """
    channels = api_get_json(f"/api/v4/users/me/teams/{team_id}/channels")
    if isinstance(channels, list):
        return channels
    raise MattermostAPIError(f"Unexpected response while listing channels for team {team_id}.")
def find_channel_across_user_teams(channel_name: str) -> Dict[str, Any]:
    """Find the single channel named *channel_name* among the current user's teams.

    Every team from get_user_teams() is searched. Teams whose channel listing
    fails are logged and skipped. The returned dict is a copy of the channel
    annotated with ``_resolved_team_name`` and ``_resolved_team_id``.

    A channel id seen under several team listings counts as one match, so it
    is not reported as ambiguous. NOTE(review): presumably this happens for
    direct/group channels, which are not team-scoped — confirm.

    Raises:
        MattermostAPIError: If no channel matches, or if distinct channels with
            that name exist in more than one team.
    """
    matches: List[Dict[str, Any]] = []
    seen_channel_ids: set[str] = set()
    for team in get_user_teams():
        team_id = team.get("id", "")
        team_name = team.get("name", "")
        if not team_id:
            continue
        try:
            channels = get_user_channels_for_team(team_id)
        except MattermostAPIError as exc:
            LOGGER.warning("Could not list channels for team %s: %s", team_name or team_id, exc)
            continue
        for channel in channels:
            if channel.get("name") != channel_name:
                continue
            channel_id = channel.get("id", "")
            if channel_id:
                if channel_id in seen_channel_ids:
                    continue
                seen_channel_ids.add(channel_id)
            resolved = dict(channel)
            resolved["_resolved_team_name"] = team_name
            resolved["_resolved_team_id"] = team_id
            matches.append(resolved)
    if not matches:
        raise MattermostAPIError(
            f"Unable to find channel named '{channel_name}' in the current user's accessible teams."
        )
    if len(matches) > 1:
        # Fall back to the team id when the team name is empty, not just missing.
        teams = ", ".join(
            sorted(
                {
                    match.get("_resolved_team_name") or match.get("_resolved_team_id") or "unknown"
                    for match in matches
                }
            )
        )
        raise MattermostAPIError(
            f"Channel name '{channel_name}' is ambiguous across teams: {teams}. Set MATTERMOST_TEAM_NAME or MATTERMOST_TEAM_ID."
        )
    return matches[0]
def resolve_channels() -> List[Dict[str, Any]]:
    """Resolve every entry in CHANNEL_SPECS to channel metadata.

    The result follows CHANNEL_SPECS order. A channel reached through more than
    one spec (e.g. once by name and once by id) is returned only once. This
    keeps its posts from being fetched and emitted twice.
    """
    resolved: List[Dict[str, Any]] = []
    seen_ids: set[str] = set()
    for spec in CHANNEL_SPECS:
        if spec["kind"] == "id":
            channel_data = get_channel_by_id(spec["value"])
        else:
            channel_data = get_channel_by_name(spec["value"])
        channel_id = channel_data.get("id", "")
        if channel_id:
            if channel_id in seen_ids:
                LOGGER.info("Skipping duplicate channel reference %s", spec["value"])
                continue
            seen_ids.add(channel_id)
        resolved.append(channel_data)
    return resolved
def get_user_info(user_id: str) -> Dict[str, Any]:
    """Return the profile for *user_id*, memoised in USER_CACHE.

    An empty id yields a fresh, uncached placeholder
    ``{"id": "", "username": "unknown"}``. If the API call fails, the error is
    logged and a fallback that uses the id as the username is cached, so the
    lookup is not retried.
    """
    if not user_id:
        return {"id": "", "username": "unknown"}
    if user_id not in USER_CACHE:
        try:
            USER_CACHE[user_id] = api_get_json(f"/api/v4/users/{user_id}")
        except MattermostAPIError as exc:
            LOGGER.error("Could not fetch user %s: %s", user_id, exc)
            USER_CACHE[user_id] = {"id": user_id, "username": user_id}
    return USER_CACHE[user_id]
def build_user_map(messages: List[Dict[str, Any]]) -> Dict[str, str]:
    """Map each distinct, non-empty ``user_id`` in *messages* to a display name.

    Users are looked up in sorted id order. The name is the first truthy value
    among username, nickname and first_name, falling back to the raw id.
    """
    distinct_ids = sorted({msg["user_id"] for msg in messages if msg.get("user_id")})
    name_fields = ("username", "nickname", "first_name")
    name_map: Dict[str, str] = {}
    for uid in distinct_ids:
        profile = get_user_info(uid)
        name_map[uid] = next(
            (profile.get(field) for field in name_fields if profile.get(field)),
            uid,
        )
    return name_map
def is_system_message(post: Dict[str, Any]) -> bool:
    """Return True when the post's ``type`` (whitespace-trimmed) starts with "system_"."""
    post_type = post.get("type") or ""
    return post_type.strip().startswith("system_")
def extract_messages(resolved_channels: List[Dict[str, Any]] | None = None) -> List[Dict[str, Any]]:
    """Collect, filter and label messages from the given (or all configured) channels.

    Args:
        resolved_channels: Channel metadata dicts as returned by
            resolve_channels(). ``None`` (the default) resolves CHANNEL_SPECS
            now. An explicit empty list means "no channels" and yields ``[]``.

    Returns:
        Message dicts sorted oldest-first and trimmed to the newest MAX_MESSAGES
        across all channels, each with a ``username`` key added. System posts
        and posts with blank text are dropped, and CRLF line endings become LF.
    """
    # `is None` rather than `or`: an explicit [] must not trigger a fresh resolve.
    channels = resolve_channels() if resolved_channels is None else resolved_channels
    all_messages: List[Dict[str, Any]] = []
    for channel in channels:
        channel_id = channel.get("id", "")
        channel_name = channel.get("name", "") or channel_id
        channel_display_name = channel.get("display_name", "") or channel_name
        for post in get_channel_posts(channel_id):
            if is_system_message(post):
                continue
            message = (post.get("message") or "").strip()
            if not message:
                continue
            all_messages.append(
                {
                    "channel_id": channel_id,
                    "channel_ref": channel_name,
                    "channel_name": channel_name,
                    "channel_display_name": channel_display_name,
                    "post_id": post.get("id", ""),
                    "user_id": post.get("user_id", ""),
                    "create_at": int(post.get("create_at", 0)),
                    "message": message.replace("\r\n", "\n"),
                    # `or` guards against explicit nulls in the payload.
                    "root_id": post.get("root_id") or "",
                    "reply_count": int(post.get("reply_count") or 0),
                }
            )
    all_messages.sort(key=lambda item: item["create_at"])
    if len(all_messages) > MAX_MESSAGES:
        # Keep the newest messages across all channels.
        all_messages = all_messages[-MAX_MESSAGES:]
    user_map = build_user_map(all_messages)
    for message in all_messages:
        message["username"] = user_map.get(message["user_id"], message["user_id"] or "unknown")
    LOGGER.info("Prepared %s messages after filtering", len(all_messages))
    return all_messages
def day_range_ms(day: date) -> tuple[int, int]:
    """Return the ``[start, end)`` epoch-millisecond bounds of *day* in local time.

    The end is local midnight of the following day, computed on its own rather
    than as start + 24h. ``astimezone()`` yields a fixed UTC offset, so adding
    a day would be an hour off on days with a DST transition (23 or 25 hours long).
    """
    start = datetime.combine(day, time.min).astimezone()
    end = datetime.combine(day + timedelta(days=1), time.min).astimezone()
    return int(start.timestamp() * 1000), int(end.timestamp() * 1000)
def set_fetch_range(start_ms: int, end_ms: int) -> None:
    """Set the explicit epoch-ms window later read by get_channel_posts()."""
    global RANGE_START_TIMESTAMP_MS, RANGE_END_TIMESTAMP_MS
    RANGE_START_TIMESTAMP_MS, RANGE_END_TIMESTAMP_MS = start_ms, end_ms
def extract_previous_workday_messages(args: Namespace) -> tuple[List[Dict[str, Any]], date | None, int]:
    """Find the most recent earlier calendar day that has messages.

    Starting with the day before ``args.today``, each local calendar day is
    tried in turn, for up to ``args.max_lookback_days`` days. The first day
    with at least one message is selected.

    Returns:
        A tuple ``(messages, selected_day, skipped_days)``, where
        ``skipped_days`` is the number of empty days passed over before
        ``selected_day``. When nothing is found, returns
        ``([], None, max_lookback_days)``.

    Raises:
        ValueError: For an invalid ``--today`` or a non-positive
            ``--max-lookback-days``. Both are checked before any network call.
    """
    today = parse_iso_date(args.today)
    max_lookback_days = args.max_lookback_days
    # Validate cheap inputs before resolve_channels() talks to the server.
    if max_lookback_days <= 0:
        raise ValueError("--max-lookback-days must be greater than 0.")
    resolved_channels = resolve_channels()
    for skipped_days in range(max_lookback_days):
        candidate_day = today - timedelta(days=skipped_days + 1)
        start_ms, end_ms = day_range_ms(candidate_day)
        set_fetch_range(start_ms, end_ms)
        messages = extract_messages(resolved_channels)
        if messages:
            return messages, candidate_day, skipped_days
        LOGGER.info("No messages found for %s; expanding lookback.", candidate_day.isoformat())
    return [], None, max_lookback_days
def format_messages(messages: List[Dict[str, Any]]) -> str:
    """Render messages as JSON Lines: one object per message, with no trailing newline.

    A record's ``type`` is "thread_reply" when it has a root_id,
    "thread_root" when it has replies, and "channel_post" otherwise.
    ``reply_to`` and ``reply_count`` are only included when they apply.
    Timestamps are ISO-8601 strings in the local timezone.
    """

    def to_record(item: Dict[str, Any]) -> Dict[str, Any]:
        local_time = datetime.fromtimestamp(item["create_at"] / 1000).astimezone()
        author = item.get("username", "unknown")
        post_id = item.get("post_id", "")
        root_id = item.get("root_id", "")
        thread_id = root_id or post_id or "unknown"
        replies = int(item.get("reply_count", 0))
        kind = "thread_reply" if root_id else ("thread_root" if replies > 0 else "channel_post")
        channel_ref = item.get("channel_ref", item.get("channel_id", "unknown"))
        record: Dict[str, Any] = {
            "source": "mattermost",
            "channel": channel_ref,
            "channel_id": item.get("channel_id", ""),
            "channel_display_name": item.get("channel_display_name", channel_ref),
            "post_id": post_id,
            "thread_id": thread_id,
            "root_id": root_id or None,
            "type": kind,
            "timestamp": local_time.isoformat(),
            "username": author,
            "message": item["message"],
        }
        if root_id:
            record["reply_to"] = root_id
        if replies > 0:
            record["reply_count"] = replies
        return record

    return "\n".join(
        json.dumps(to_record(item), ensure_ascii=False, sort_keys=False) for item in messages
    )
def save_to_file(text: str) -> None:
    """Write *text* to OUTPUT_FILE, adding a trailing newline when it is non-empty.

    The parent directory is created if needed. The content goes to a sibling
    temporary file that is then moved into place with ``os.replace()``. On
    POSIX this rename is atomic, so a concurrent reader never sees a
    half-written file, and an interrupted run leaves the previous output
    intact. Empty *text* produces an empty file.
    """
    output_path = Path(OUTPUT_FILE).expanduser().resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if text and not text.endswith("\n"):
        text = f"{text}\n"
    temp_path = output_path.with_name(f".{output_path.name}.{os.getpid()}.tmp")
    try:
        temp_path.write_text(text, encoding="utf-8")
        os.replace(temp_path, output_path)
    finally:
        # Only still present if the write or the rename failed.
        temp_path.unlink(missing_ok=True)
def main() -> int:
    """CLI entry point: configure, extract messages, then print and save them as JSONL.

    Returns:
        0 on success, 1 on failure. Configuration (ValueError) and Mattermost
        API errors are logged without a traceback. Any other exception is
        logged together with its traceback.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    try:
        args = parse_args()
        configure(args)
        if not args.previous_workday:
            messages = extract_messages()
        else:
            messages, selected_day, skipped_days = extract_previous_workday_messages(args)
            if selected_day is None:
                LOGGER.info("No previous workday messages found within %s day(s).", skipped_days)
            else:
                LOGGER.info(
                    "Selected previous workday %s after skipping %s inactive calendar day(s).",
                    selected_day.isoformat(),
                    skipped_days,
                )
        output = format_messages(messages)
        print(output)
        save_to_file(output)
        LOGGER.info("Saved context to %s", OUTPUT_FILE)
    except (ValueError, MattermostAPIError) as exc:
        LOGGER.error("%s", exc)
        return 1
    except Exception as exc:
        LOGGER.exception("Unexpected error: %s", exc)
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value (0 success, 1 failure) as the process exit status.
    exit_code = main()
    sys.exit(exit_code)