"""
GALXEE AI CfAM — Containment for Agentic Middleware
Python SDK v1.0.0

Project: GALXEE AI CfAM (First Agentic AI Containment Middleware)
Applicant: GALXEE AI CfAM, Denver
Website: galxee.ai

USAGE — Add these lines to your AI Agent:

    from containai_middleware import AgentMiddleware

    agent = AgentMiddleware(
        policy="strict",
        allowed_domains=["api.openai.com"],
        max_requests_per_minute=60
    )

    @agent.secure
    async def run_agent(task: str):
        return await my_ai_agent.execute(task)

LICENSE: Apache 2.0
"""

from __future__ import annotations

import asyncio
import hashlib
import hmac
import inspect
import json
import logging
import re
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Set
from urllib.parse import urlparse

# ─── Logging ──────────────────────────────────────────────────────────────────
# Module-wide logger. AuditLogger.record emits one INFO line per containment
# decision and AgentMiddleware logs its initialisation here. No handlers are
# configured in this module, so visibility depends on the host application's
# logging configuration.
logger = logging.getLogger("containai")


# ─── Enums ────────────────────────────────────────────────────────────────────
class PolicyMode(str, Enum):
    """How the containment layer treats domains that are on neither list.

    The ``str`` mixin lets ``PolicyMode("strict")`` parse user input and makes
    members compare equal to their string values.

    NOTE(review): in this module the only enforcement decision that reads the
    mode is ``DomainValidator.is_allowed``, for domains matching neither the
    allowlist nor the blocklist. The blocklist still wins in AUDIT mode. The
    action-type, rate-limit, method, size and payload checks in
    ``ContainmentLayer.evaluate`` block regardless of mode. Confirm that the
    "allow all" description of AUDIT is meant to be limited to domains.
    """
    STRICT = "strict"       # Reject domains that are not on the allowlist
    PERMISSIVE = "permissive"  # Allow domains that are not on the blocklist
    AUDIT = "audit"         # Intended: allow all but log every action (see NOTE)


class ActionType(str, Enum):
    """Category of side effect an agent proposes to perform.

    ``ContainmentLayer.evaluate`` rejects any action whose type is not in
    ``ContainmentPolicy.allowed_action_types``. That set defaults to
    ``{HTTP_REQUEST}``. ``AgentMiddleware.check_request`` always evaluates
    HTTP_REQUEST. The remaining members are presumably for callers that use
    ``ContainmentLayer`` directly; TODO confirm intended usage.
    """
    HTTP_REQUEST = "http_request"
    DATABASE_QUERY = "database_query"
    FILE_SYSTEM = "file_system"
    SUBPROCESS = "subprocess"
    NETWORK_SOCKET = "network_socket"
    ENVIRONMENT_ACCESS = "environment_access"


class VerdictStatus(str, Enum):
    """Outcome recorded on every ContainmentVerdict and AuditEvent.

    The values are upper-case strings and are emitted verbatim in audit log
    entries (``AuditEvent.to_dict``).
    """
    ALLOWED = "ALLOWED"                    # every check passed
    BLOCKED = "BLOCKED"                    # domain rejected by allow/blocklist
    RATE_LIMITED = "RATE_LIMITED"          # sliding-window limit exhausted
    POLICY_VIOLATION = "POLICY_VIOLATION"  # action type, method, size or payload content
    # NOTE(review): not produced by ContainmentLayer.evaluate in this module.
    SANDBOX_VIOLATION = "SANDBOX_VIOLATION"


# ─── Data classes ─────────────────────────────────────────────────────────────
@dataclass
class ContainmentPolicy:
    """
    Cryptographically signed policy manifest.
    Defines what an agent is permitted to do.

    ``sign`` computes an HMAC-SHA256 over every field that affects
    enforcement: mode, domain lists, methods, limits, flags and action types.
    It also covers ``policy_id`` and ``created_at``. Changing any of them
    after signing makes ``verify`` return False.

    ``mode`` and ``allowed_action_types`` are normalised in ``__post_init__``.
    Strings such as ``"strict"`` / ``"http_request"`` become enum members, and
    unknown values raise ``ValueError`` at construction time.
    """
    mode: PolicyMode = PolicyMode.STRICT
    allowed_domains: List[str] = field(default_factory=list)
    blocked_domains: List[str] = field(default_factory=list)
    allowed_methods: List[str] = field(default_factory=lambda: ["GET", "POST"])
    max_requests_per_minute: int = 60
    max_payload_bytes: int = 1_048_576  # 1 MB
    sandbox: bool = True
    ephemeral_context: bool = True       # Wipe context between sessions
    audit_log: bool = True
    allowed_action_types: Set[ActionType] = field(
        default_factory=lambda: {ActionType.HTTP_REQUEST}
    )
    policy_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    created_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def __post_init__(self) -> None:
        # An unrecognised mode string (e.g. "Strict") used to be stored as-is.
        # DomainValidator.is_allowed then fell through to its final branch and
        # ALLOWED unlisted domains (fail-open). Converting here fails fast.
        self.mode = PolicyMode(self.mode)
        self.allowed_action_types = {
            ActionType(action) for action in self.allowed_action_types
        }

    def _manifest(self) -> str:
        """Return the canonical JSON (sorted keys, sorted lists) that gets signed."""
        return json.dumps(
            {
                "policy_id": self.policy_id,
                "mode": PolicyMode(self.mode).value,
                "allowed_domains": sorted(self.allowed_domains),
                "blocked_domains": sorted(self.blocked_domains),
                "allowed_methods": sorted(self.allowed_methods),
                "max_requests_per_minute": self.max_requests_per_minute,
                "max_payload_bytes": self.max_payload_bytes,
                "sandbox": self.sandbox,
                "ephemeral_context": self.ephemeral_context,
                "audit_log": self.audit_log,
                "allowed_action_types": sorted(
                    ActionType(action).value for action in self.allowed_action_types
                ),
                "created_at": self.created_at,
            },
            sort_keys=True,
        )

    def sign(self, secret_key: str | bytes) -> str:
        """Produce an HMAC-SHA256 signature over the policy manifest.

        Args:
            secret_key: HMAC key; ``str`` keys are UTF-8 encoded.

        Returns:
            The hex-encoded signature.
        """
        key = secret_key.encode() if isinstance(secret_key, str) else bytes(secret_key)
        return hmac.new(key, self._manifest().encode(), hashlib.sha256).hexdigest()

    def verify(self, secret_key: str | bytes, signature: str | bytes) -> bool:
        """Verify the policy has not been tampered with.

        Compares in constant time. Returns False (rather than raising) for
        signatures of an unsupported type or containing non-ASCII text.
        """
        expected = self.sign(secret_key).encode("ascii")
        if isinstance(signature, str):
            # compare_digest rejects non-ASCII str with TypeError; compare bytes instead.
            provided = signature.encode("utf-8", errors="replace")
        elif isinstance(signature, (bytes, bytearray)):
            provided = bytes(signature)
        else:
            return False
        return hmac.compare_digest(expected, provided)


@dataclass(frozen=True)
class AuditEvent:
    """Immutable audit log entry for every agent action.

    The dataclass is frozen so the immutability promise is enforced.
    Instances are handed to callers via ``ContainmentVerdict.event`` and are
    re-hashed by ``AuditLogger.verify_chain``. A mutated event would make the
    chain report tampering. Assigning a field raises
    ``dataclasses.FrozenInstanceError``.
    """
    event_id: str               # random UUID4 string per event
    session_id: str             # ContainmentLayer session at decision time
    agent_id: str
    action_type: ActionType
    target: str                 # URL / host the action was aimed at
    method: Optional[str]       # HTTP method, or None
    verdict: VerdictStatus
    reason: str                 # human-readable explanation of the verdict
    latency_ms: float           # time spent evaluating, in milliseconds
    payload_bytes: int
    timestamp: str              # ISO-8601 UTC timestamp
    policy_id: str

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict (enums as values, latency rounded to 3 dp)."""
        return {
            "event_id": self.event_id,
            "session_id": self.session_id,
            "agent_id": self.agent_id,
            "action_type": self.action_type.value,
            "target": self.target,
            "method": self.method,
            "verdict": self.verdict.value,
            "reason": self.reason,
            "latency_ms": round(self.latency_ms, 3),
            "payload_bytes": self.payload_bytes,
            "timestamp": self.timestamp,
            "policy_id": self.policy_id,
        }


@dataclass
class ContainmentVerdict:
    """Result of the containment layer evaluation.

    Callers must not perform the action unless ``allowed`` is True.
    """
    status: VerdictStatus  # which outcome was reached (ALLOWED when every check passed)
    reason: str  # human-readable explanation; the same text is stored on the audit event
    allowed: bool  # ContainmentLayer sets this True only together with status ALLOWED
    event: Optional[AuditEvent] = None  # the audit record written for this decision


# ─── Rate limiter ─────────────────────────────────────────────────────────────
class SlidingWindowRateLimiter:
    """
    Sliding-window rate limiter for asyncio code.

    Admits at most ``max_requests`` successful ``check()`` calls within any
    ``window_seconds`` span, measured with ``time.monotonic()``.

    Concurrency: state is guarded by an ``asyncio.Lock``, which serialises
    coroutines on a single event loop. It is NOT thread-safe. Do not share an
    instance across threads or event loops.
    """

    def __init__(self, max_requests: int, window_seconds: int = 60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Admission times in ascending order. They are appended under the lock
        # and the monotonic clock never goes backwards, so expired entries
        # always sit at the left end.
        self._timestamps: deque[float] = deque()
        self._lock = asyncio.Lock()

    async def check(self) -> tuple[bool, int]:
        """Returns (allowed, requests_remaining).

        A successful check consumes one slot. A rejected check consumes
        nothing and reports 0 remaining.
        """
        async with self._lock:
            now = time.monotonic()
            cutoff = now - self.window_seconds
            # Evict only the expired prefix (O(expired)) instead of rebuilding
            # the whole list on every call.
            while self._timestamps and self._timestamps[0] <= cutoff:
                self._timestamps.popleft()
            if len(self._timestamps) >= self.max_requests:
                return False, 0
            self._timestamps.append(now)
            return True, self.max_requests - len(self._timestamps)


# ─── Domain validator ─────────────────────────────────────────────────────────
class DomainValidator:
    """Validates request targets against the policy allowlist/blocklist."""

    # A target that already carries a URL scheme ("https://", "ftp://", ...).
    # Anchored at the start, so a "://" inside a path or query string
    # (e.g. "evil.com/?next=https://x") is not mistaken for a scheme.
    _SCHEME_RE = re.compile(r"^[A-Za-z][A-Za-z0-9+.\-]*://")

    @staticmethod
    def extract_domain(target: str) -> str:
        """
        Return the lower-cased host name that *target* would connect to.

        Uses ``urlparse(...).hostname`` instead of slicing ``netloc`` so that
        user-info, ports and IPv6 brackets are handled. With netloc slicing,
        "https://api.openai.com:x@evil.com/" was read as "api.openai.com"
        although the request goes to "evil.com".

        A trailing root dot ("evil.com.") is stripped so it cannot dodge the
        blocklist. Scheme-less targets are treated as https. Returns "" when
        no host is present, and ``target.lower()`` if the URL cannot be
        parsed.
        """
        candidate = target.strip()
        if candidate.startswith("//"):
            url = f"https:{candidate}"
        elif DomainValidator._SCHEME_RE.match(candidate):
            url = candidate
        else:
            url = f"https://{candidate}"
        try:
            host = urlparse(url).hostname
        except ValueError:  # e.g. unbalanced IPv6 brackets
            return target.lower()
        return (host or "").rstrip(".")

    @staticmethod
    def matches(domain: str, pattern: str) -> bool:
        """Supports exact match and wildcard prefix (*.example.com).

        A wildcard also matches the apex ("example.com"). Host names are
        case-insensitive, so both sides are lower-cased and stripped of a
        trailing root dot before comparing.
        """
        domain = domain.lower().rstrip(".")
        pattern = pattern.strip().lower().rstrip(".")
        if pattern.startswith("*."):
            suffix = pattern[2:]
            return domain == suffix or domain.endswith(f".{suffix}")
        return domain == pattern

    def is_allowed(
        self,
        target: str,
        allowed: List[str],
        blocked: List[str],
        mode: PolicyMode,
    ) -> tuple[bool, str]:
        """
        Decide whether *target* may be contacted.

        The blocklist is checked first and always wins, then the allowlist.
        A domain on neither list is rejected in STRICT mode and allowed for
        any other mode (PERMISSIVE, AUDIT, or anything else).

        Returns:
            (allowed, reason), where reason is a human-readable explanation.
        """
        domain = self.extract_domain(target)

        # Blocklist always wins
        for pattern in blocked:
            if self.matches(domain, pattern):
                return False, f"Domain '{domain}' is on the blocklist"

        # Allowlist check
        for pattern in allowed:
            if self.matches(domain, pattern):
                return True, f"Domain '{domain}' is allowlisted"

        if mode == PolicyMode.STRICT:
            return False, f"Domain '{domain}' is not on the allowlist (strict mode)"
        elif mode == PolicyMode.PERMISSIVE:
            return True, f"Domain '{domain}' allowed (permissive mode)"
        else:  # AUDIT
            return True, f"Domain '{domain}' allowed (audit mode — logging only)"


# ─── Payload inspector ────────────────────────────────────────────────────────
class PayloadInspector:
    """Scans request payloads for policy violations.

    This is a heuristic: a case-insensitive regex scan over a text rendering
    of the payload. No DOTALL flag is set, so ``.*`` patterns do not span
    newlines.
    """

    # Patterns that indicate potential prompt injection or exfiltration
    SUSPICIOUS_PATTERNS = [
        r"ignore previous instructions",
        r"system prompt",
        r"jailbreak",
        r"bypass.*filter",
        r"exfiltrate",
        r"send.*credentials",
        r"<script",
        r"eval\s*\(",
        r"__import__",
        r"os\.system",
        r"subprocess\.call",
    ]

    def __init__(self):
        self._compiled = [
            re.compile(p, re.IGNORECASE) for p in self.SUSPICIOUS_PATTERNS
        ]

    @staticmethod
    def _as_text(payload: Any) -> str:
        """Render *payload* as scannable text without raising.

        Strings are used as-is and binary data is decoded as UTF-8 (lossy).
        Everything else is rendered via ``json.dumps``. If that fails (bytes
        inside a dict, non-str keys, circular references, ...), ``str()`` is
        used instead of crashing the containment check.
        """
        if isinstance(payload, str):
            return payload
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return bytes(payload).decode("utf-8", errors="replace")
        try:
            return json.dumps(payload)
        except (TypeError, ValueError):
            return str(payload)

    def inspect(self, payload: Any) -> tuple[bool, str]:
        """Returns (clean, reason). clean=False means suspicious content found."""
        text = self._as_text(payload)
        for pattern in self._compiled:
            if pattern.search(text):
                return False, f"Suspicious pattern detected: {pattern.pattern}"
        return True, "Payload clean"


# ─── Audit logger ─────────────────────────────────────────────────────────────
class AuditLogger:
    """
    Hash-chained, in-memory audit trail with an optional JSON-lines file sink.

    Every recorded event is folded into a running SHA-256 chain, starting from
    the seed ``"genesis"``: ``sha256(f"{previous}:{event_json}")`` where
    ``event_json`` is the event's dict serialised with sorted keys. Changing
    any stored event afterwards makes ``verify_chain`` return False.
    """

    def __init__(self, log_file: Optional[str] = None):
        self._log_file = log_file
        self._events: List[AuditEvent] = []
        self._chain_hash = "genesis"

    @staticmethod
    def _link(previous: str, event: AuditEvent) -> str:
        """Return the chain hash obtained by appending *event* after *previous*."""
        canonical = json.dumps(event.to_dict(), sort_keys=True)
        return hashlib.sha256(f"{previous}:{canonical}".encode()).hexdigest()

    def record(self, event: AuditEvent) -> str:
        """Append *event*, advance the chain, log it, and return the new chain hash.

        When a log file is configured, the event dict plus its ``chain_hash``
        is appended to it as one JSON line.
        """
        self._chain_hash = self._link(self._chain_hash, event)
        self._events.append(event)

        verdict_label = event.verdict.value
        logger.info(
            "[%s] %s %s → %s (%s)",
            verdict_label,
            event.method or event.action_type.value,
            event.target,
            verdict_label,
            event.reason,
        )

        if self._log_file:
            entry = dict(event.to_dict(), chain_hash=self._chain_hash)
            with open(self._log_file, "a") as sink:
                sink.write(json.dumps(entry) + "\n")

        return self._chain_hash

    def get_events(self) -> List[Dict[str, Any]]:
        """Return every recorded event as a plain dict, oldest first."""
        return [stored.to_dict() for stored in self._events]

    def verify_chain(self) -> bool:
        """Recompute the chain from scratch and compare it with the stored head."""
        running = "genesis"
        for stored in self._events:
            running = self._link(running, stored)
        return running == self._chain_hash


# ─── Containment Layer ────────────────────────────────────────────────────────
class ContainmentLayer:
    """
    Core containment engine.
    Evaluates every proposed agent action before execution.

    ``evaluate`` runs its checks in a fixed order and stops at the first
    failure: action type → rate limit → domain → HTTP method → payload
    size → payload content. Every decision, allowed or denied, is recorded
    in a hash-chained ``AuditLogger``.
    """

    def __init__(
        self,
        policy: ContainmentPolicy,
        agent_id: str,
        log_file: Optional[str] = None,
    ):
        """
        Args:
            policy: The policy to enforce.
            agent_id: Identifier stamped on every audit event.
            log_file: Optional path handed to ``AuditLogger``. When set, each
                audit event is also appended to it as a JSON line.
        """
        self.policy = policy
        self.agent_id = agent_id
        self._rate_limiter = SlidingWindowRateLimiter(policy.max_requests_per_minute)
        self._domain_validator = DomainValidator()
        self._payload_inspector = PayloadInspector()
        self._audit_logger = AuditLogger(log_file=log_file)
        self._session_id = str(uuid.uuid4())

    @staticmethod
    def _measure_payload(payload: Any) -> int:
        """Return the encoded size of *payload* in bytes, without raising."""
        if isinstance(payload, (bytes, bytearray)):
            return len(payload)
        if isinstance(payload, memoryview):
            return payload.nbytes
        if isinstance(payload, str):
            return len(payload.encode())
        try:
            return len(json.dumps(payload).encode())
        except (TypeError, ValueError):
            # Not JSON-serialisable (or circular): measure its str() form instead.
            return len(str(payload).encode())

    async def evaluate(
        self,
        action_type: ActionType,
        target: str,
        method: Optional[str] = None,
        payload: Any = None,
        payload_bytes: int = 0,
    ) -> ContainmentVerdict:
        """
        Evaluate a proposed action against the containment policy.
        Returns a ContainmentVerdict — NEVER execute the action if verdict.allowed is False.

        Args:
            action_type: Category of the action. It must be in
                ``policy.allowed_action_types``.
            target: URL or host the action is aimed at.
            method: HTTP method, if any. It is compared case-insensitively.
            payload: Request body to size-check and scan, if any.
            payload_bytes: Body size in bytes. When 0 (the default) and
                ``payload`` is given, the size is measured from ``payload``.
                Leaving it out therefore cannot bypass the size limit.
        """
        start = time.monotonic()

        if payload is not None and not payload_bytes:
            payload_bytes = self._measure_payload(payload)

        # 1. Action type check
        if action_type not in self.policy.allowed_action_types:
            return await self._deny(
                action_type, target, method, payload_bytes, start,
                VerdictStatus.POLICY_VIOLATION,
                f"Action type '{action_type.value}' is not permitted by policy",
            )

        # 2. Rate limit check. A passing check consumes a slot, so requests
        #    later rejected by checks 3-6 still count against the limit.
        within_rate, _ = await self._rate_limiter.check()
        if not within_rate:
            return await self._deny(
                action_type, target, method, payload_bytes, start,
                VerdictStatus.RATE_LIMITED,
                f"Rate limit exceeded: {self.policy.max_requests_per_minute} req/min",
            )

        # 3. Domain / target validation
        domain_ok, domain_reason = self._domain_validator.is_allowed(
            target,
            self.policy.allowed_domains,
            self.policy.blocked_domains,
            self.policy.mode,
        )
        if not domain_ok:
            return await self._deny(
                action_type, target, method, payload_bytes, start,
                VerdictStatus.BLOCKED,
                domain_reason,
            )

        # 4. HTTP method check (case-insensitive; skipped when no method given)
        if method:
            permitted_methods = {m.upper() for m in self.policy.allowed_methods}
            if method.upper() not in permitted_methods:
                return await self._deny(
                    action_type, target, method, payload_bytes, start,
                    VerdictStatus.POLICY_VIOLATION,
                    f"HTTP method '{method}' is not permitted by policy",
                )

        # 5. Payload size check
        if payload_bytes > self.policy.max_payload_bytes:
            return await self._deny(
                action_type, target, method, payload_bytes, start,
                VerdictStatus.POLICY_VIOLATION,
                f"Payload size {payload_bytes} bytes exceeds limit {self.policy.max_payload_bytes}",
            )

        # 6. Payload inspection (prompt injection / exfiltration detection)
        if payload is not None:
            clean, inspect_reason = self._payload_inspector.inspect(payload)
            if not clean:
                return await self._deny(
                    action_type, target, method, payload_bytes, start,
                    VerdictStatus.POLICY_VIOLATION,
                    inspect_reason,
                )

        # ✓ All checks passed — allow
        return self._record(
            action_type, target, method, payload_bytes, start,
            VerdictStatus.ALLOWED, domain_reason,
            allowed=True,
        )

    async def _deny(
        self,
        action_type: ActionType,
        target: str,
        method: Optional[str],
        payload_bytes: int,
        start: float,
        status: VerdictStatus,
        reason: str,
    ) -> ContainmentVerdict:
        """Record a denial with *status*/*reason* and return a non-allowed verdict."""
        return self._record(
            action_type, target, method, payload_bytes, start, status, reason,
            allowed=False,
        )

    def _record(
        self,
        action_type: ActionType,
        target: str,
        method: Optional[str],
        payload_bytes: int,
        start: float,
        status: VerdictStatus,
        reason: str,
        *,
        allowed: bool,
    ) -> ContainmentVerdict:
        """Write the audit event for one decision and wrap it in a verdict."""
        latency = (time.monotonic() - start) * 1000
        event = AuditEvent(
            event_id=str(uuid.uuid4()),
            session_id=self._session_id,
            agent_id=self.agent_id,
            action_type=action_type,
            target=target,
            method=method,
            verdict=status,
            reason=reason,
            latency_ms=latency,
            payload_bytes=payload_bytes,
            timestamp=datetime.now(timezone.utc).isoformat(),
            policy_id=self.policy.policy_id,
        )
        self._audit_logger.record(event)
        return ContainmentVerdict(
            status=status,
            reason=reason,
            allowed=allowed,
            event=event,
        )

    def get_audit_log(self) -> List[Dict[str, Any]]:
        """Return all recorded audit events as dicts, oldest first."""
        return self._audit_logger.get_events()

    def verify_audit_integrity(self) -> bool:
        """Return True if the audit hash chain is intact."""
        return self._audit_logger.verify_chain()

    def new_session(self):
        """Start a new ephemeral session (wipes context if ephemeral_context=True).

        Always rotates the session id. When ``policy.ephemeral_context`` is
        set, it also replaces the rate limiter with a fresh one.

        NOTE(review): AgentMiddleware.secure calls this before (and, when
        ephemeral, after) every wrapped invocation. Resetting the limiter here
        therefore enforces max_requests_per_minute per invocation rather than
        per minute across invocations. Confirm this is intended.
        """
        self._session_id = str(uuid.uuid4())
        if self.policy.ephemeral_context:
            self._rate_limiter = SlidingWindowRateLimiter(
                self.policy.max_requests_per_minute
            )


# ─── Main middleware decorator ────────────────────────────────────────────────
class AgentMiddleware:
    """
    GALXEE AI CfAM — Containment for Agentic Middleware

    Drop-in security wrapper for any AI agent function.

    QUICK START:
        from containai_middleware import AgentMiddleware

        agent = AgentMiddleware(
            policy="strict",
            allowed_domains=["api.openai.com"],
            max_requests_per_minute=60
        )

        @agent.secure
        async def run_agent(task: str):
            return await my_ai_agent.execute(task)
    """

    def __init__(
        self,
        policy: str | ContainmentPolicy = "strict",
        allowed_domains: Optional[List[str]] = None,
        blocked_domains: Optional[List[str]] = None,
        allowed_methods: Optional[List[str]] = None,
        max_requests_per_minute: int = 60,
        max_payload_bytes: int = 1_048_576,
        sandbox: bool = True,
        audit_log: bool = True,
        agent_id: Optional[str] = None,
        log_file: Optional[str] = None,
    ):
        """
        Args:
            policy: Either a ready-made ContainmentPolicy or a mode name.
                Mode names ("strict", "permissive", "audit") are matched
                case-insensitively, and a ContainmentPolicy is built from them
                using the keyword arguments below. When a ContainmentPolicy is
                passed, the policy keyword arguments are ignored.
            allowed_domains / blocked_domains: Exact hosts or "*.suffix"
                wildcards. They are copied, so later changes to the caller's
                lists do not alter the policy.
            allowed_methods: Permitted HTTP methods. ``None`` means
                ["GET", "POST"]; an explicit empty list permits none.
            max_requests_per_minute, max_payload_bytes, sandbox, audit_log:
                Copied into the built ContainmentPolicy.
            agent_id: Identifier written on audit events. A random
                "agent-xxxxxxxx" id is used if omitted.
            log_file: When given, audit events are also appended to this file
                as JSON lines.

        Raises:
            ValueError: If ``policy`` is a string that is not a known mode.
        """
        if isinstance(policy, str):
            policy = ContainmentPolicy(
                mode=PolicyMode(policy.strip().lower()),
                allowed_domains=list(allowed_domains or []),
                blocked_domains=list(blocked_domains or []),
                # `is None` rather than `or`: an explicit [] must mean "no
                # methods allowed", not silently fall back to the default.
                allowed_methods=(
                    list(allowed_methods)
                    if allowed_methods is not None
                    else ["GET", "POST"]
                ),
                max_requests_per_minute=max_requests_per_minute,
                max_payload_bytes=max_payload_bytes,
                sandbox=sandbox,
                audit_log=audit_log,
            )

        self.policy = policy
        self.agent_id = agent_id or f"agent-{uuid.uuid4().hex[:8]}"
        self._containment = ContainmentLayer(policy, self.agent_id)
        if log_file:
            # The log_file argument was previously dropped: the containment
            # layer builds its AuditLogger without a file sink, so swap in one
            # that also appends to log_file.
            self._containment._audit_logger = AuditLogger(log_file=log_file)

        logger.info(
            "ContainAI middleware initialized | agent=%s | policy=%s | mode=%s",
            self.agent_id,
            policy.policy_id,
            PolicyMode(policy.mode).value,
        )

    def secure(self, func: Callable) -> Callable:
        """
        Decorator that wraps an agent function with containment.

        A new session is started before each call. When
        ``policy.ephemeral_context`` is set, another one is started after the
        call, even if it raised. The returned wrapper is always a coroutine
        function and must be awaited. Plain (synchronous) callables are
        accepted too: their result is returned, and awaited only if it is
        awaitable.

        Usage:
            @agent.secure
            async def run_agent(task: str):
                ...
        """
        @wraps(func)
        async def wrapper(*args, **kwargs):
            self._containment.new_session()
            try:
                result = func(*args, **kwargs)
                # A sync function used to run to completion and then crash on
                # `await <non-awaitable>`; only await what can be awaited.
                if inspect.isawaitable(result):
                    result = await result
                return result
            finally:
                if self.policy.ephemeral_context:
                    # Wipe ephemeral context after execution
                    self._containment.new_session()

        return wrapper

    @staticmethod
    def _payload_size(payload: Any) -> int:
        """Return the encoded size of *payload* in bytes, without raising."""
        if isinstance(payload, (bytes, bytearray)):
            return len(payload)
        if isinstance(payload, memoryview):
            return payload.nbytes
        if isinstance(payload, str):
            return len(payload.encode())
        try:
            return len(json.dumps(payload).encode())
        except (TypeError, ValueError):
            # Not JSON-serialisable (or circular): measure its str() form.
            return len(str(payload).encode())

    async def check_request(
        self,
        url: str,
        method: str = "GET",
        payload: Any = None,
        payload_bytes: Optional[int] = None,
    ) -> ContainmentVerdict:
        """
        Manually check a request before making it.
        Use this inside your agent when you need fine-grained control.

        If ``payload_bytes`` is omitted, it is measured from ``payload``. Bytes
        are counted as-is, str as UTF-8, and other values via their JSON (or,
        failing that, str()) form.

        Example:
            verdict = await agent.check_request("https://api.openai.com/v1/chat", "POST", body)
            if not verdict.allowed:
                raise PermissionError(verdict.reason)
        """
        if payload_bytes is None and payload is not None:
            payload_bytes = self._payload_size(payload)

        return await self._containment.evaluate(
            action_type=ActionType.HTTP_REQUEST,
            target=url,
            method=method,
            payload=payload,
            payload_bytes=payload_bytes or 0,
        )

    def get_audit_log(self) -> List[Dict[str, Any]]:
        """Return all audit events for this agent session."""
        return self._containment.get_audit_log()

    def verify_audit_integrity(self) -> bool:
        """Verify the tamper-evident audit chain is intact."""
        return self._containment.verify_audit_integrity()

    @property
    def containment(self) -> ContainmentLayer:
        """Direct access to the containment layer for advanced use."""
        return self._containment
