"""
LLM Load Simulator - GPU Stress Test for Power Optimization Testing

Simulates LLM training/inference workloads to test DeepOptiFlex optimization.
Uses matrix operations to generate realistic GPU load patterns.
"""
import logging
import signal
import sys
import time
from dataclasses import dataclass, replace
from typing import List, Optional

import click

try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

# Module-level logger; output format and level are set by logging.basicConfig in main().
logger = logging.getLogger(__name__)


@dataclass
class WorkloadConfig:
    """
    Parameters for one simulated workload run.

    Instances are plain mutable dataclasses; LLMLoadSimulator reads these
    fields to size its tensors and pace its loop.
    """

    # Workload kind: 'training', 'inference', 'burst' or 'mixed'.
    mode: str
    # Load level, nominally in the range 0.0 - 1.0.
    intensity: float
    # Length of the run in seconds.
    duration_seconds: int
    # Number of rows in the simulated activation tensor.
    batch_size: int
    # Model width of each simulated layer.
    hidden_size: int
    # Number of simulated layers.
    num_layers: int


# Built-in workload presets, keyed by their ``mode`` string.
# Weight memory for a preset is num_layers * 2 * hidden * (4 * hidden) float16
# values: roughly 8 GiB for 'training'/'inference' and 48 GiB for 'burst'
# (training additionally needs gradients of the same size). The 'mixed'
# preset's sizes are not used for allocation -- its phases run the other presets.
DEFAULT_CONFIGS = {
    preset.mode: preset
    for preset in (
        # Sustained forward + backward load.
        WorkloadConfig(mode='training', intensity=0.9, duration_seconds=300,
                       batch_size=64, hidden_size=4096, num_layers=32),
        # Forward-only load with gaps between requests.
        WorkloadConfig(mode='inference', intensity=0.6, duration_seconds=300,
                       batch_size=32, hidden_size=4096, num_layers=32),
        # Short, maximum-throughput load on a larger model.
        WorkloadConfig(mode='burst', intensity=1.0, duration_seconds=60,
                       batch_size=128, hidden_size=8192, num_layers=48),
        # Alternating phases.
        WorkloadConfig(mode='mixed', intensity=0.7, duration_seconds=300,
                       batch_size=48, hidden_size=4096, num_layers=24),
    )
}


class LLMLoadSimulator:
    """
    Simulates LLM workloads on GPU for testing power optimization.

    Each simulated layer is a transformer-style MLP block
    (RMS-normalize -> x @ W1 -> GELU -> @ W2) in float16, so GPU time is
    dominated by large matrix multiplications. A workload runs until its
    configured duration elapses or ``stop()`` is called (e.g. from a signal
    handler).
    """

    # How often (seconds) idle waits re-check the stop flag.
    _IDLE_POLL_SECONDS = 0.1

    def __init__(self, gpu_index: int = 0):
        """
        Initialize the simulator.

        Args:
            gpu_index: CUDA device index

        Raises:
            RuntimeError: if PyTorch is not installed, CUDA is unavailable,
                or ``gpu_index`` is not a valid device index.
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch not installed. Run: pip install torch")

        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        device_count = torch.cuda.device_count()
        if not 0 <= gpu_index < device_count:
            # Validate up front so a bad index surfaces as the RuntimeError the
            # CLI handles, instead of whatever error torch raises later.
            raise RuntimeError(
                f"Invalid GPU index {gpu_index}: {device_count} CUDA device(s) available"
            )

        self.gpu_index = gpu_index
        self.device = torch.device(f'cuda:{gpu_index}')
        self._running = False
        self._stats = self._empty_stats()

        # Check GPU
        gpu_name = torch.cuda.get_device_name(gpu_index)
        gpu_memory = torch.cuda.get_device_properties(gpu_index).total_memory / (1024**3)
        logger.info(f"GPU {gpu_index}: {gpu_name} ({gpu_memory:.1f} GB)")

    @staticmethod
    def _empty_stats() -> dict:
        """Return a fresh statistics dict with all counters cleared."""
        return {
            'iterations': 0,
            'total_flops': 0,
            'start_time': None,
            'end_time': None,
        }

    def _sync(self):
        """Block until all queued work on *this simulator's* GPU has finished."""
        # An argument-less torch.cuda.synchronize() targets the current device,
        # which is not necessarily self.device when gpu_index != 0.
        torch.cuda.synchronize(self.device)

    @staticmethod
    def _forward_flops(config: WorkloadConfig) -> int:
        """Matmul FLOPs of one forward pass: 2 matmuls per layer, 2*m*n*k each."""
        hidden = config.hidden_size
        per_matmul = 2 * config.batch_size * hidden * (4 * hidden)
        return 2 * per_matmul * config.num_layers

    def _record_iterations(self, count: int, flops_per_iteration: int):
        """Add completed iterations to the cumulative run statistics."""
        self._stats['iterations'] += count
        self._stats['total_flops'] += count * flops_per_iteration

    @staticmethod
    def _rms_norm(x, eps: float = 1e-5):
        """Scale each row of ``x`` to unit root-mean-square (RMSNorm without a gain)."""
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    def _create_tensors(self, config: WorkloadConfig):
        """
        Allocate the input activations and the per-layer MLP weights.

        Weights are drawn from N(0, 1) and scaled by 1/sqrt(fan_in) so each
        matmul keeps activations at roughly unit scale. Unscaled N(0, 1)
        weights grow activations by about sqrt(fan_in) per matmul, which
        overflows float16 (max 65504) within the first few layers at the
        preset sizes; every later kernel would then compute on inf/NaN.

        Returns:
            ``(x, weights)``: ``x`` has shape (batch_size, hidden_size) and
            ``weights`` is a list of ``num_layers`` ``(W1, W2)`` pairs with
            shapes (hidden, 4*hidden) and (4*hidden, hidden).
        """
        batch = config.batch_size
        hidden = config.hidden_size
        inner = hidden * 4
        factory = {'device': self.device, 'dtype': torch.float16}

        weights = []
        for _ in range(config.num_layers):
            # Scale in place so no second full-size temporary is allocated.
            w1 = torch.randn(hidden, inner, **factory).mul_(hidden ** -0.5)
            w2 = torch.randn(inner, hidden, **factory).mul_(inner ** -0.5)
            weights.append((w1, w2))

        x = torch.randn(batch, hidden, **factory)
        return x, weights

    def _forward_pass(self, x, weights, backward: bool = False):
        """
        Run ``x`` through the simulated layers, optionally followed by backward.

        Args:
            x: Input tensor
            weights: List of (W1, W2) layer weights
            backward: If True, backpropagate ``x.sum()`` through the layers
                and wait for the GPU to finish.

        Returns:
            The output activations of the last layer.
        """
        if backward:
            x = x.requires_grad_(True)

        for w1, w2 in weights:
            # Normalizing each layer's input keeps values inside float16 range
            # regardless of depth, as RMSNorm does in real LLM blocks.
            h = torch.matmul(self._rms_norm(x), w1)
            h = torch.nn.functional.gelu(h)
            x = torch.matmul(h, w2)

        if backward:
            loss = x.sum()
            loss.backward()
            self._sync()

        return x

    def run_training(self, config: WorkloadConfig):
        """
        Run training workload (forward + backward).

        Args:
            config: Workload configuration. Below intensity 0.9 a short sleep
                is inserted after every iteration.
        """
        logger.info(f"Starting TRAINING workload: {config.duration_seconds}s, intensity={config.intensity}")

        x, weights = self._create_tensors(config)
        params = [w for pair in weights for w in pair]

        # Enable gradients on weights
        for w in params:
            w.requires_grad_(True)

        # Backward computes both input and weight gradients for every matmul,
        # i.e. twice the forward matmul work.
        flops_per_iteration = 3 * self._forward_flops(config)

        start_time = time.monotonic()
        iteration = 0

        while self._running and (time.monotonic() - start_time) < config.duration_seconds:
            # Drop the previous iteration's gradients (like zero_grad with
            # set_to_none) so float16 grads don't accumulate without bound.
            for w in params:
                w.grad = None
            self._forward_pass(x.clone(), weights, backward=True)
            iteration += 1
            self._record_iterations(1, flops_per_iteration)

            # Add some variation based on intensity
            if config.intensity < 0.9:
                time.sleep(0.01 * (1 - config.intensity))

        self._log_stats(iteration, time.monotonic() - start_time)

    def run_inference(self, config: WorkloadConfig):
        """
        Run inference workload (forward only, with gaps between requests).

        Args:
            config: Workload configuration. Below intensity 0.8 a sleep is
                inserted after every completed request.
        """
        logger.info(f"Starting INFERENCE workload: {config.duration_seconds}s, intensity={config.intensity}")

        x, weights = self._create_tensors(config)
        flops_per_iteration = self._forward_flops(config)

        start_time = time.monotonic()
        iteration = 0

        with torch.no_grad():
            while self._running and (time.monotonic() - start_time) < config.duration_seconds:
                self._forward_pass(x, weights, backward=False)
                # Wait for the GPU so the sleep below is a real idle gap and
                # the iteration count reflects completed (not just queued) work.
                self._sync()
                iteration += 1
                self._record_iterations(1, flops_per_iteration)

                # Simulate request intervals
                if config.intensity < 0.8:
                    time.sleep(0.05 * (1 - config.intensity))

        self._log_stats(iteration, time.monotonic() - start_time)

    def run_burst(self, config: WorkloadConfig):
        """
        Run burst workload (maximum intensity, no sleeps).

        Args:
            config: Workload configuration (intensity is not used).
        """
        logger.info(f"Starting BURST workload: {config.duration_seconds}s")

        x, weights = self._create_tensors(config)
        flops_per_iteration = self._forward_flops(config)
        burst_size = 10  # forward passes queued between synchronizations

        start_time = time.monotonic()
        iteration = 0

        with torch.no_grad():
            while self._running and (time.monotonic() - start_time) < config.duration_seconds:
                for _ in range(burst_size):
                    self._forward_pass(x, weights, backward=False)
                # Syncing once per burst (not per pass) keeps GPU gaps negligible
                # while making counts and timing reflect completed work.
                self._sync()
                iteration += burst_size
                self._record_iterations(burst_size, flops_per_iteration)

        self._log_stats(iteration, time.monotonic() - start_time)

    def _idle(self, seconds: float):
        """Wait ``seconds`` with the GPU idle, returning early if stop() is called."""
        deadline = time.monotonic() + seconds
        while self._running:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(remaining, self._IDLE_POLL_SECONDS))

    def run_mixed(self, config: WorkloadConfig):
        """
        Run mixed workload: cycle training -> inference -> idle -> burst.

        Each phase fills the rest of its 30 s slot using the matching
        DEFAULT_CONFIGS preset (including that preset's intensity). The final
        phase is shortened so the run ends at about ``config.duration_seconds``
        (tensor allocation time for a phase is not counted against it).

        Args:
            config: Workload configuration (only duration_seconds is used).
        """
        logger.info(f"Starting MIXED workload: {config.duration_seconds}s")

        phase_duration = 30  # seconds per phase
        phases = ['training', 'inference', 'idle', 'burst']
        runners = {
            'training': self.run_training,
            'inference': self.run_inference,
            'burst': self.run_burst,
        }

        start_time = time.monotonic()

        while self._running:
            elapsed = time.monotonic() - start_time
            remaining = config.duration_seconds - elapsed
            if remaining <= 0:
                break

            phase = phases[int(elapsed // phase_duration) % len(phases)]
            # Never let a phase run past the overall deadline.
            phase_time = min(phase_duration - (elapsed % phase_duration), remaining)
            logger.info(f"Phase: {phase.upper()} ({phase_time:.0f}s)")

            # Sub-second slivers are not worth allocating multi-GB tensors for;
            # wait them out like an idle phase (stop() still interrupts).
            if phase == 'idle' or phase_time < 1:
                self._idle(phase_time)
                continue

            # Work on a copy: assigning into DEFAULT_CONFIGS[phase] would
            # mutate the shared module-level preset.
            sub_config = replace(DEFAULT_CONFIGS[phase], duration_seconds=int(phase_time))
            runners[phase](sub_config)
            # Release this phase's cached blocks before the next phase allocates
            # (the burst preset is much larger than the others).
            self._cleanup()

    def run(self, config: WorkloadConfig):
        """
        Run the workload selected by ``config.mode``, then release GPU memory.

        Statistics are reset at the start of every call and are cumulative
        across all phases of a mixed run.

        Args:
            config: Workload configuration
        """
        runners = {
            'training': self.run_training,
            'inference': self.run_inference,
            'burst': self.run_burst,
            'mixed': self.run_mixed,
        }
        runner = runners.get(config.mode)

        self._running = True
        self._stats = self._empty_stats()
        self._stats['start_time'] = time.time()

        try:
            if runner is None:
                logger.error(f"Unknown mode: {config.mode}")
            else:
                runner(config)
        finally:
            self._stats['end_time'] = time.time()
            self._running = False
            self._cleanup()

    def stop(self):
        """Request the running workload to stop at its next loop check."""
        self._running = False

    def _cleanup(self):
        """Wait for queued work on this GPU, then return cached memory to the driver."""
        self._sync()
        torch.cuda.empty_cache()

    def _log_stats(self, iterations: int, duration: float):
        """Log iteration count and throughput for one workload (or phase)."""
        if duration > 0:
            throughput = iterations / duration
            logger.info(f"Completed {iterations} iterations in {duration:.1f}s ({throughput:.1f} iter/s)")

    def get_stats(self):
        """Return a shallow copy of the cumulative statistics for the last run()."""
        return self._stats.copy()


@click.command()
@click.option('--mode', '-m',
              type=click.Choice(['training', 'inference', 'burst', 'mixed']),
              default='mixed',
              help='Workload mode')
# Range types enforce the documented bounds at parse time instead of letting
# out-of-range values reach the simulator.
@click.option('--duration', '-d', default=300, type=click.IntRange(min=0),
              help='Duration in seconds')
@click.option('--intensity', '-i', default=0.7, type=click.FloatRange(0.0, 1.0),
              help='Intensity 0.0-1.0')
@click.option('--gpu', '-g', default=0, type=click.IntRange(min=0),
              help='GPU index')
def main(mode: str, duration: int, intensity: float, gpu: int):
    """LLM Load Simulator for GPU Power Optimization Testing"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s | %(levelname)-7s | %(message)s',
        datefmt='%H:%M:%S'
    )

    logger.info("=" * 60)
    logger.info("LLM Load Simulator")
    logger.info("=" * 60)

    # Start from the preset for the chosen mode and apply the CLI overrides to
    # a copy, so the shared module-level DEFAULT_CONFIGS entry is not mutated.
    config = replace(
        DEFAULT_CONFIGS.get(mode, DEFAULT_CONFIGS['mixed']),
        duration_seconds=duration,
        intensity=intensity,
    )

    try:
        simulator = LLMLoadSimulator(gpu_index=gpu)
    except RuntimeError as e:
        logger.error(str(e))
        sys.exit(1)

    # SIGINT/SIGTERM only request a stop; the workload loops check the flag
    # and then run() finishes its normal cleanup.
    def signal_handler(signum, frame):
        logger.info("Stopping workload...")
        simulator.stop()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Run workload
    logger.info(f"Mode: {mode}, Duration: {duration}s, Intensity: {intensity}")
    logger.info("-" * 60)

    simulator.run(config)

    logger.info("=" * 60)
    logger.info("Simulation complete")


if __name__ == '__main__':
    # Run the click command-line interface when executed as a script (not on import).
    main()
