"""
NVML Sampler - GPU Telemetry Collection via NVIDIA Management Library
"""
import time
import logging
from dataclasses import dataclass
from typing import Optional, List, Dict, Any

# pynvml is treated as an optional dependency: a failed import is recorded in
# NVML_AVAILABLE rather than aborting the import of this module.
try:
    import pynvml
    NVML_AVAILABLE = True
except ImportError:
    NVML_AVAILABLE = False

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclass
class GPUTelemetry:
    """One telemetry reading taken from a single GPU.

    Field order is part of the interface (positional construction), so it is
    kept exactly as declared below.
    """
    timestamp: float
    gpu_index: int

    # --- identification ---
    gpu_model: str
    gpu_uuid: str

    # --- power (the ``_w`` suffix marks watts) ---
    power_draw_w: float
    power_limit_w: float
    power_limit_default_w: float
    power_limit_min_w: float
    power_limit_max_w: float

    # --- clocks (the ``_mhz`` suffix marks MHz) ---
    gpu_clock_mhz: int
    gpu_clock_max_mhz: int
    memory_clock_mhz: int

    # --- utilization and memory ---
    gpu_utilization: int  # percent, 0-100
    memory_utilization: int  # percent, 0-100
    memory_used_mb: int
    memory_total_mb: int

    # --- thermal ---
    temperature_c: int
    temperature_throttle_c: int
    fan_speed_percent: int

    # --- throttle state ---
    is_throttling: bool
    throttle_reasons: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the sample as a plain dict with camelCase keys.

        The values are the field values themselves (no copying or
        conversion); keys appear in the order listed in the table below.
        """
        # (output key, attribute name) pairs, in output order.
        key_to_attr = (
            ('timestamp', 'timestamp'),
            ('gpuIndex', 'gpu_index'),
            ('gpuModel', 'gpu_model'),
            ('gpuUuid', 'gpu_uuid'),
            ('powerDraw', 'power_draw_w'),
            ('powerLimit', 'power_limit_w'),
            ('powerLimitDefault', 'power_limit_default_w'),
            ('powerLimitMin', 'power_limit_min_w'),
            ('powerLimitMax', 'power_limit_max_w'),
            ('gpuClock', 'gpu_clock_mhz'),
            ('gpuClockMax', 'gpu_clock_max_mhz'),
            ('memoryClock', 'memory_clock_mhz'),
            ('gpuUtilization', 'gpu_utilization'),
            ('memoryUtilization', 'memory_utilization'),
            ('memoryUsedMb', 'memory_used_mb'),
            ('memoryTotalMb', 'memory_total_mb'),
            ('temperature', 'temperature_c'),
            ('temperatureThrottle', 'temperature_throttle_c'),
            ('fanSpeed', 'fan_speed_percent'),
            ('isThrottling', 'is_throttling'),
            ('throttleReasons', 'throttle_reasons'),
        )
        return {key: getattr(self, attr) for key, attr in key_to_attr}

class NVMLSampler:
    """
    NVML-based GPU telemetry sampler for NVIDIA GPUs.
    Supports RTX 4090 and other consumer/datacenter GPUs.

    Typical use: call initialize() once, call sample()/sample_all() as often
    as needed, then call shutdown().
    """

    def __init__(self):
        self._initialized = False
        self._gpu_handles: List[Any] = []
        self._gpu_count = 0

    @staticmethod
    def _as_text(value):
        """Return *value* as ``str``, decoding it when it is ``bytes``.

        Some pynvml releases return ``bytes`` from string queries such as the
        device name; bytes would not survive JSON serialization downstream.
        """
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    def initialize(self) -> bool:
        """Initialize NVML and enumerate the available GPUs.

        Calling this on an already-initialized sampler is a no-op that
        returns True, so device handles are never duplicated.

        Returns:
            True when NVML is ready for sampling, False when pynvml is not
            installed or an NVML call failed (the error is logged).
        """
        if not NVML_AVAILABLE:
            logger.error("pynvml not installed. Run: pip install pynvml")
            return False

        if self._initialized:
            return True

        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError as e:
            logger.error(f"NVML initialization failed: {e}")
            return False

        try:
            count = pynvml.nvmlDeviceGetCount()
            # Build a fresh list so handles from an earlier session are
            # never mixed with the new ones.
            handles = []
            for i in range(count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                handles.append(handle)
                name = self._as_text(pynvml.nvmlDeviceGetName(handle))
                logger.info(f"GPU {i}: {name}")
        except pynvml.NVMLError as e:
            logger.error(f"NVML initialization failed: {e}")
            # nvmlInit() succeeded above, so balance it before giving up.
            try:
                pynvml.nvmlShutdown()
            except pynvml.NVMLError as shutdown_error:
                logger.debug(
                    f"NVML shutdown after failed init also failed: {shutdown_error}"
                )
            return False

        # Commit state only once enumeration has fully succeeded.
        self._gpu_handles = handles
        self._gpu_count = count
        self._initialized = True
        logger.info(f"NVML initialized: {self._gpu_count} GPU(s) detected")
        return True

    def shutdown(self):
        """Shut NVML down if this sampler initialized it.

        Does nothing when the sampler is not initialized. If NVML reports an
        error, it is logged and the sampler stays marked as initialized.
        """
        if not self._initialized:
            return

        try:
            pynvml.nvmlShutdown()
        except pynvml.NVMLError as e:
            logger.error(f"NVML shutdown error: {e}")
            return

        self._initialized = False
        # Drop the handles so they cannot be reused against a closed session.
        self._gpu_handles = []
        logger.info("NVML shutdown complete")

    def get_gpu_count(self) -> int:
        """Get number of GPUs detected by the last successful initialize()"""
        return self._gpu_count

    def _read_power_range(self, handle, power_default):
        """Return (min_w, max_w) power-limit bounds for *handle*.

        Falls back to 50% / 110% of *power_default* when the constraints
        query is unavailable.
        """
        try:
            constraints = pynvml.nvmlDeviceGetPowerManagementLimitConstraints(handle)
            return constraints[0] / 1000.0, constraints[1] / 1000.0  # mW to W
        except (pynvml.NVMLError, AttributeError) as e:
            logger.debug(f"Power limit constraints unavailable, estimating: {e}")
            return power_default * 0.5, power_default * 1.1

    def _read_temperature_threshold(self, handle):
        """Return the GPU max temperature threshold, or 83 when unavailable."""
        try:
            return pynvml.nvmlDeviceGetTemperatureThreshold(
                handle, pynvml.NVML_TEMPERATURE_THRESHOLD_GPU_MAX
            )
        except (pynvml.NVMLError, AttributeError) as e:
            logger.debug(f"Temperature threshold unavailable, using default: {e}")
            return 83  # Default for most consumer GPUs

    def _read_fan_speed(self, handle):
        """Return the fan speed reading, or 0 when the query is unsupported."""
        try:
            return pynvml.nvmlDeviceGetFanSpeed(handle)
        except (pynvml.NVMLError, AttributeError) as e:
            logger.debug(f"Fan speed unavailable: {e}")
            return 0  # Some GPUs don't support fan speed query

    def _read_throttle_state(self, handle):
        """Return (is_throttling, reasons) decoded from the throttle bitmask.

        "idle" is reported as a reason but does not count as throttling.
        A failed query yields whatever was decoded before the failure
        (normally ``(False, [])``).
        """
        reasons = []
        throttling = False
        try:
            mask = pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(handle)
            if mask & pynvml.nvmlClocksThrottleReasonGpuIdle:
                reasons.append("idle")
            if mask & pynvml.nvmlClocksThrottleReasonSwThermalSlowdown:
                reasons.append("thermal_sw")
                throttling = True
            if mask & pynvml.nvmlClocksThrottleReasonHwThermalSlowdown:
                reasons.append("thermal_hw")
                throttling = True
            if mask & pynvml.nvmlClocksThrottleReasonSwPowerCap:
                reasons.append("power_cap")
                throttling = True
        except (pynvml.NVMLError, AttributeError) as e:
            # AttributeError covers bindings that lack one of the constants.
            logger.debug(f"Throttle reasons unavailable: {e}")
        return throttling, reasons

    def sample(self, gpu_index: int = 0) -> Optional["GPUTelemetry"]:
        """
        Collect telemetry sample from specified GPU.

        Power, clock, utilization, memory and temperature readings are
        required: an NVML error on any of them fails the whole sample. The
        power-limit range, temperature threshold, fan speed and throttle
        reasons are optional and fall back to defaults when unavailable.

        Args:
            gpu_index: GPU index (0-based); negative values are rejected

        Returns:
            GPUTelemetry object or None on error
        """
        if not self._initialized:
            logger.error("NVML not initialized")
            return None

        # Reject negatives too: Python would otherwise index from the end of
        # the handle list and report the wrong GPU.
        if not 0 <= gpu_index < self._gpu_count:
            logger.error(f"Invalid GPU index: {gpu_index}")
            return None

        try:
            handle = self._gpu_handles[gpu_index]

            # Get GPU identification
            name = self._as_text(pynvml.nvmlDeviceGetName(handle))
            uuid = self._as_text(pynvml.nvmlDeviceGetUUID(handle))

            # Get power metrics
            power_draw = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # mW to W
            power_limit = pynvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000.0
            power_default = pynvml.nvmlDeviceGetPowerManagementDefaultLimit(handle) / 1000.0
            power_min, power_max = self._read_power_range(handle, power_default)

            # Get clock speeds
            gpu_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)
            gpu_clock_max = pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM)
            mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM)

            # Get utilization
            utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)

            # Get memory info (bytes to whole MiB)
            memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
            memory_used_mb = memory.used // (1024 * 1024)
            memory_total_mb = memory.total // (1024 * 1024)

            # Get thermal info
            temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            temp_threshold = self._read_temperature_threshold(handle)
            fan_speed = self._read_fan_speed(handle)

            # Check throttling
            is_throttling, throttle_reasons = self._read_throttle_state(handle)

            return GPUTelemetry(
                timestamp=time.time(),
                gpu_index=gpu_index,
                gpu_model=name,
                gpu_uuid=uuid,
                power_draw_w=round(power_draw, 1),
                power_limit_w=round(power_limit, 1),
                power_limit_default_w=round(power_default, 1),
                power_limit_min_w=round(power_min, 1),
                power_limit_max_w=round(power_max, 1),
                gpu_clock_mhz=gpu_clock,
                gpu_clock_max_mhz=gpu_clock_max,
                memory_clock_mhz=mem_clock,
                gpu_utilization=utilization.gpu,
                memory_utilization=utilization.memory,
                memory_used_mb=memory_used_mb,
                memory_total_mb=memory_total_mb,
                temperature_c=temperature,
                temperature_throttle_c=temp_threshold,
                fan_speed_percent=fan_speed,
                is_throttling=is_throttling,
                throttle_reasons=throttle_reasons,
            )

        except pynvml.NVMLError as e:
            logger.error(f"NVML sampling error: {e}")
            return None

    def sample_all(self) -> List["GPUTelemetry"]:
        """Sample all detected GPUs, skipping any whose sample failed"""
        readings = (self.sample(i) for i in range(self._gpu_count))
        return [reading for reading in readings if reading is not None]


# Demo function for testing
def demo_sample():
    """Print five rounds of readings for every GPU, two seconds apart.

    Prints a failure message and returns early when the sampler cannot be
    initialized. The sampler is always shut down once the rounds finish or
    an exception interrupts them.
    """
    rounds = 5
    pause_seconds = 2

    nvml = NVMLSampler()
    ready = nvml.initialize()
    if not ready:
        print("Failed to initialize NVML")
        return

    try:
        print(f"\nDetected {nvml.get_gpu_count()} GPU(s)")
        print("-" * 60)

        completed = 0
        while completed < rounds:
            for reading in nvml.sample_all():
                print(f"\n[GPU {reading.gpu_index}] {reading.gpu_model}")
                print(f"  Power:       {reading.power_draw_w}W / {reading.power_limit_w}W limit")
                print(f"  Clock:       {reading.gpu_clock_mhz} MHz (max: {reading.gpu_clock_max_mhz})")
                print(f"  Utilization: GPU {reading.gpu_utilization}% | Memory {reading.memory_utilization}%")
                print(f"  Temperature: {reading.temperature_c}°C")
                print(f"  Throttling:  {reading.is_throttling} {reading.throttle_reasons}")

            time.sleep(pause_seconds)
            completed += 1
    finally:
        nvml.shutdown()


# Script entry point: configure root logging at INFO level, then run the demo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_sample()
