"""
SLA Monitor - Local SLAshield enforcement
"""
import time
import logging
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any

from .nvml_sampler import GPUTelemetry

logger = logging.getLogger(__name__)


@dataclass
class SLAConfig:
    """Tunable limits that drive ``SLAMonitor`` decisions."""

    # Largest tolerated performance drop, in percent of the baseline score;
    # a measured loss strictly above this value is reported as a breach.
    max_performance_loss_percent: float = 5.0

    # Seconds after a profile is applied during which SLA checks are skipped.
    grace_period_seconds: int = 60

    # Number of most recent samples averaged for each SLA evaluation
    # (the monitor buffers up to twice this many).
    measurement_window_samples: int = 10

    # When True, a breach is tagged 'revert' and ``should_revert`` may return
    # True; when False, a breach is tagged 'alert' only.
    auto_revert: bool = True

    # Optional FPS floor. NOTE(review): not consulted by the SLAMonitor code
    # visible in this file -- confirm whether another consumer reads it.
    min_fps_threshold: Optional[float] = None


@dataclass
class PerformanceBaseline:
    """Averaged reference metrics that later SLA checks are compared against."""

    # Wall-clock time (``time.time()``) at which the baseline was captured.
    timestamp: float

    # Mean power draw over the baseline samples, in watts.
    avg_power_w: float

    # Mean GPU utilization over the baseline samples; the monitor divides it
    # by 100, i.e. it is treated as a percentage.
    avg_utilization: float

    # Mean GPU clock over the baseline samples, in MHz.
    avg_clock_mhz: float

    # Optional throughput figure. NOTE(review): ``SLAMonitor.set_baseline``
    # as visible in this file never fills it in.
    avg_throughput: Optional[float] = None

    # Identifier of the profile that was active while the baseline was taken.
    profile_id: str = "default"


@dataclass
class SLABreachEvent:
    """One recorded SLA violation."""

    # Wall-clock time (``time.time()``) at which the breach was detected.
    timestamp: float

    # Breach category: 'performance', 'throughput' or 'thermal'.
    breach_type: str

    # Configured limit that was exceeded.
    threshold: float

    # Value that was actually measured when the limit was exceeded.
    measured_value: float

    # Profile that was active at the time of the breach.
    active_profile: str

    # Reaction attached to the breach: 'revert', 'alert' or 'none'.
    action_taken: str


class SLAMonitor:
    """
    Local SLA monitor that tracks performance against a baseline and reports
    when the active profile should be reverted.

    Performance is approximated as ``clock_mhz * utilization``; the score of
    the most recent samples is compared with the score of the baseline set
    through :meth:`set_baseline`. A loss above
    ``config.max_performance_loss_percent`` is recorded as a breach.
    """

    # Cap on retained breach records. ``get_breach_history`` only exposes the
    # most recent ``_BREACH_REPORT_SIZE`` entries and ``_breach_count`` keeps
    # the running total, so older records are dropped to keep a long-running
    # monitor from growing without bound.
    _MAX_BREACH_HISTORY = 100
    _BREACH_REPORT_SIZE = 10

    def __init__(self, config: SLAConfig):
        """
        Initialize SLA Monitor.

        Args:
            config: SLA configuration
        """
        self.config = config
        self._baseline: Optional[PerformanceBaseline] = None
        self._samples: List[GPUTelemetry] = []
        self._profile_applied_at: Optional[float] = None
        self._current_profile: Optional[str] = None
        self._breach_count = 0
        self._breach_history: List[SLABreachEvent] = []

    def _window_size(self) -> int:
        """Return the evaluation window length, never less than one sample.

        A configured window of 0 (or a negative value) would turn the
        ``[-window:]`` slices used below into "keep everything" and allow an
        empty window to be averaged, so the value is clamped here.
        """
        return max(1, self.config.measurement_window_samples)

    @staticmethod
    def _averages(samples: List[GPUTelemetry]) -> tuple:
        """Return ``(power_w, utilization, clock_mhz)`` means of *samples*.

        The caller must pass a non-empty list.
        """
        count = len(samples)
        power = sum(s.power_draw_w for s in samples) / count
        utilization = sum(s.gpu_utilization for s in samples) / count
        clock = sum(s.gpu_clock_mhz for s in samples) / count
        return power, utilization, clock

    def set_baseline(self, samples: List[GPUTelemetry], profile_id: str = "default") -> None:
        """
        Set performance baseline from samples.

        An empty ``samples`` list leaves any existing baseline untouched and
        logs a warning.

        Args:
            samples: List of telemetry samples
            profile_id: Profile ID for this baseline
        """
        if not samples:
            logger.warning("Baseline not set: no telemetry samples provided")
            return

        avg_power, avg_util, avg_clock = self._averages(samples)

        self._baseline = PerformanceBaseline(
            timestamp=time.time(),
            avg_power_w=avg_power,
            avg_utilization=avg_util,
            avg_clock_mhz=avg_clock,
            profile_id=profile_id,
        )

        logger.info(
            "Baseline set: %.1fW, %.1f%% util, %.0f MHz (profile: %s)",
            avg_power,
            avg_util,
            avg_clock,
            profile_id,
        )

    def on_profile_applied(self, profile_id: str) -> None:
        """
        Called when a new profile is applied.

        Starts the grace period and discards samples collected under the
        previous profile.

        Args:
            profile_id: Applied profile ID
        """
        self._profile_applied_at = time.time()
        self._current_profile = profile_id
        self._samples.clear()
        logger.info("Profile applied: %s, starting grace period", profile_id)

    def add_sample(self, sample: GPUTelemetry) -> None:
        """
        Add telemetry sample for SLA monitoring.

        Args:
            sample: GPU telemetry sample
        """
        self._samples.append(sample)

        # Keep only recent samples: at most twice the evaluation window.
        max_samples = self._window_size() * 2
        if len(self._samples) > max_samples:
            del self._samples[:-max_samples]

    def check_sla(self) -> Dict[str, Any]:
        """
        Check SLA compliance.

        Evaluation is skipped (and the result stays compliant) while the
        grace period is running, when no baseline is set, or until a full
        window of samples has been collected.

        Note:
            Every call that finds a breach increments the breach counter and
            appends a breach record, so calling this method (or
            :meth:`should_revert`, which calls it) repeatedly during one
            ongoing breach records it repeatedly.

        Returns:
            Dictionary with SLA status and any breach info
        """
        result: Dict[str, Any] = {
            'compliant': True,
            'in_grace_period': False,
            'performance_percent': 100.0,
            'performance_loss_percent': 0.0,
            'power_savings_percent': 0.0,
            'breach': None,
            'message': 'SLA OK',
        }

        # Check if in grace period
        if self._profile_applied_at is not None:
            elapsed = time.time() - self._profile_applied_at
            if elapsed < self.config.grace_period_seconds:
                remaining = self.config.grace_period_seconds - elapsed
                result['in_grace_period'] = True
                result['message'] = f"Grace period: {remaining:.0f}s remaining"
                return result

        # Need baseline and samples to evaluate
        baseline = self._baseline
        if not baseline:
            result['message'] = "No baseline set"
            return result

        window = self._window_size()
        if len(self._samples) < window:
            result['message'] = f"Collecting samples ({len(self._samples)}/{window})"
            return result

        # Calculate current performance over the most recent window
        current_power, current_util, current_clock = self._averages(
            self._samples[-window:]
        )

        # Calculate performance score (simplified: clock * utilization)
        baseline_score = baseline.avg_clock_mhz * (baseline.avg_utilization / 100)
        current_score = current_clock * (current_util / 100)

        # A zero baseline score or zero baseline power cannot be used as a
        # divisor; fall back to "no loss" / "no savings" in those cases.
        if baseline_score > 0:
            performance_percent = current_score / baseline_score * 100
        else:
            performance_percent = 100.0
        performance_loss = max(0.0, 100 - performance_percent)

        if baseline.avg_power_w > 0:
            power_savings = (
                (baseline.avg_power_w - current_power) / baseline.avg_power_w * 100
            )
        else:
            power_savings = 0.0

        result['performance_percent'] = round(performance_percent, 1)
        result['performance_loss_percent'] = round(performance_loss, 1)
        result['power_savings_percent'] = round(power_savings, 1)

        threshold = self.config.max_performance_loss_percent

        # Check SLA threshold
        if performance_loss <= threshold:
            result['message'] = (
                f"SLA OK: {performance_loss:.1f}% perf loss, "
                f"{power_savings:.1f}% power saved"
            )
            return result

        breach = SLABreachEvent(
            timestamp=time.time(),
            breach_type='performance',
            threshold=threshold,
            measured_value=performance_loss,
            active_profile=self._current_profile or 'unknown',
            action_taken='revert' if self.config.auto_revert else 'alert',
        )

        self._breach_count += 1
        self._breach_history.append(breach)
        if len(self._breach_history) > self._MAX_BREACH_HISTORY:
            del self._breach_history[:-self._MAX_BREACH_HISTORY]

        result['compliant'] = False
        result['breach'] = {
            'type': breach.breach_type,
            'threshold': breach.threshold,
            'measured': breach.measured_value,
            'action': breach.action_taken,
        }
        result['message'] = (
            f"SLA BREACH: {performance_loss:.1f}% perf loss "
            f"exceeds {threshold}% threshold"
        )

        # Pass the message as an argument so its literal '%' characters are
        # never interpreted as format directives.
        logger.warning("%s", result['message'])

        return result

    def should_revert(self) -> bool:
        """
        Check if profile should be reverted due to SLA breach.

        This runs :meth:`check_sla`, so a detected breach is recorded as a
        side effect even when ``auto_revert`` is disabled.

        Returns:
            True if should revert to previous profile
        """
        status = self.check_sla()
        return bool(not status['compliant'] and self.config.auto_revert)

    def get_status(self) -> Dict[str, Any]:
        """Get SLA monitor status"""
        return {
            'hasBaseline': self._baseline is not None,
            'currentProfile': self._current_profile,
            'sampleCount': len(self._samples),
            'breachCount': self._breach_count,
            'config': {
                'maxPerformanceLoss': self.config.max_performance_loss_percent,
                'gracePeriod': self.config.grace_period_seconds,
                'autoRevert': self.config.auto_revert,
            },
        }

    def get_breach_history(self) -> List[Dict[str, Any]]:
        """Get recent breach history (at most the last 10 breaches, oldest first)"""
        return [
            {
                'timestamp': b.timestamp,
                'type': b.breach_type,
                'threshold': b.threshold,
                'measured': b.measured_value,
                'profile': b.active_profile,
                'action': b.action_taken,
            }
            for b in self._breach_history[-self._BREACH_REPORT_SIZE:]
        ]
