# fiberweave/algorithms/power_budget_batch.py
"""
Power Budget Batch Analysis Algorithm
Calculates power budgets for every active ONU connected to an OLT and
writes the results to fttx.itu_power_budget.
"""

from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingParameterString,
    QgsProcessingParameterNumber,
    QgsProcessingOutputNumber,
    QgsProcessingOutputString,
    QgsMessageLog,
    Qgis
)
from qgis.PyQt.QtCore import QCoreApplication
import psycopg2


class PowerBudgetBatchAlgorithm(QgsProcessingAlgorithm):
    """Batch calculates power budgets for all ONU connections.

    For every active ONU with a connected OLT, a downstream (OLT -> ONU) and
    an upstream (ONU -> OLT) budget is computed from a fixed fiber/splice/
    connector loss model plus the splitter insertion loss read from
    ``fttx.itu_odn_nodes``. The rows of ``fttx.itu_power_budget`` are
    replaced inside one transaction that is committed only when the whole
    batch has been processed.
    """

    # Parameters
    DB_HOST = 'DB_HOST'
    DB_PORT = 'DB_PORT'
    DB_NAME = 'DB_NAME'
    DB_USER = 'DB_USER'
    DB_PASSWORD = 'DB_PASSWORD'
    OLT_TX_POWER = 'OLT_TX_POWER'
    ONU_TX_POWER = 'ONU_TX_POWER'

    # Outputs
    OUTPUT_CALCULATED = 'OUTPUT_CALCULATED'
    OUTPUT_COMPLIANT = 'OUTPUT_COMPLIANT'
    OUTPUT_NON_COMPLIANT = 'OUTPUT_NON_COMPLIANT'
    OUTPUT_STATUS = 'OUTPUT_STATUS'

    # Fixed loss model applied to every path (same loss in both directions).
    # The unit interpretations below are inferred from the names only --
    # TODO confirm against the network design rules.
    FIBER_LOSS_DB = 2.5 * 0.25      # presumably 2.5 km at 0.25 dB/km
    SPLICE_LOSS_DB = 3 * 0.05       # presumably 3 splices at 0.05 dB each
    CONNECTOR_LOSS_DB = 2 * 0.5     # presumably 2 connectors at 0.5 dB each
    # Used when an ONU has no splitter, or its splitter row has a NULL loss.
    DEFAULT_SPLITTER_LOSS_DB = 17.5
    # Fallback receiver sensitivities used when the equipment rows are NULL.
    DEFAULT_ONU_RX_SENSITIVITY_DBM = -27.0
    DEFAULT_OLT_RX_SENSITIVITY_DBM = -28.0
    # Both directions need at least this margin for a path to count as compliant.
    MIN_MARGIN_DB = 3.0

    def tr(self, string):
        """Translate ``string`` in the 'Processing' context."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """Return a fresh instance, as required by the QGIS Processing framework."""
        return PowerBudgetBatchAlgorithm()

    def name(self):
        """Return the unique algorithm identifier."""
        return 'power_budget_batch'

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr('Power Budget Batch Analysis')

    def group(self):
        """Return the translated name of the toolbox group."""
        return self.tr('Network Analysis')

    def groupId(self):
        """Return the identifier of the toolbox group."""
        return 'analysis'

    def shortHelpString(self):
        """Return the translated help text shown in the algorithm dialog."""
        return self.tr(
            'Calculates ITU-T compliant power budgets for all network paths.\n\n'
            'For each ONU connection, calculates:\n'
            '- Downstream power budget (OLT → ONU)\n'
            '- Upstream power budget (ONU → OLT)\n'
            '- Total path loss\n'
            '- Compliance with ITU-T G.984.2 standards\n\n'
            'Results are stored in itu_power_budget table.'
        )

    def initAlgorithm(self, config=None):
        """Declare the database/power input parameters and the result outputs."""
        # Database parameters
        self.addParameter(
            QgsProcessingParameterString(
                self.DB_HOST,
                self.tr('Database Host'),
                defaultValue='localhost'
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.DB_PORT,
                self.tr('Database Port'),
                defaultValue='5432'
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.DB_NAME,
                self.tr('Database Name'),
                defaultValue='fiberweave_network'
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.DB_USER,
                self.tr('Database User'),
                defaultValue='postgres'
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.DB_PASSWORD,
                self.tr('Database Password'),
                defaultValue=''
            )
        )

        # Power parameters
        self.addParameter(
            QgsProcessingParameterNumber(
                self.OLT_TX_POWER,
                self.tr('OLT TX Power (dBm)'),
                type=QgsProcessingParameterNumber.Double,
                defaultValue=3.0,
                minValue=-10.0,
                maxValue=10.0
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.ONU_TX_POWER,
                self.tr('ONU TX Power (dBm)'),
                type=QgsProcessingParameterNumber.Double,
                defaultValue=2.0,
                minValue=-5.0,
                maxValue=10.0
            )
        )

        # Outputs
        self.addOutput(
            QgsProcessingOutputNumber(
                self.OUTPUT_CALCULATED,
                self.tr('Paths Calculated')
            )
        )

        self.addOutput(
            QgsProcessingOutputNumber(
                self.OUTPUT_COMPLIANT,
                self.tr('Compliant Paths')
            )
        )

        self.addOutput(
            QgsProcessingOutputNumber(
                self.OUTPUT_NON_COMPLIANT,
                self.tr('Non-Compliant Paths')
            )
        )

        self.addOutput(
            QgsProcessingOutputString(
                self.OUTPUT_STATUS,
                self.tr('Analysis Status')
            )
        )

    def _result(self, calculated, compliant, non_compliant, status):
        """Build the output dictionary returned by ``processAlgorithm``."""
        return {
            self.OUTPUT_CALCULATED: calculated,
            self.OUTPUT_COMPLIANT: compliant,
            self.OUTPUT_NON_COMPLIANT: non_compliant,
            self.OUTPUT_STATUS: status,
        }

    @classmethod
    def _lookup_splitter_loss(cls, cur, splitter_id, cache):
        """Return the splitter insertion loss (dB) to use for ``splitter_id``.

        Looks up ``insertion_loss_typical_db`` in ``fttx.itu_odn_nodes`` and
        memoises the answer in ``cache``, so ONUs sharing a splitter trigger
        only one query. Falls back to DEFAULT_SPLITTER_LOSS_DB when
        ``splitter_id`` is None, no row matches, or the stored value is NULL.
        """
        if splitter_id is None:
            return cls.DEFAULT_SPLITTER_LOSS_DB
        if splitter_id not in cache:
            cur.execute("""
                SELECT insertion_loss_typical_db 
                FROM fttx.itu_odn_nodes 
                WHERE node_id = %s
            """, (splitter_id,))
            row = cur.fetchone()
            # 0 dB (Decimal('0')) is falsy but is a real stored value, so test
            # for NULL explicitly instead of relying on truthiness.
            if row is not None and row[0] is not None:
                cache[splitter_id] = float(row[0])
            else:
                cache[splitter_id] = cls.DEFAULT_SPLITTER_LOSS_DB
        return cache[splitter_id]

    @classmethod
    def _calculate_budget(cls, olt_tx_power, onu_tx_power, onu_rx_sens,
                          olt_rx_sens, splitter_loss):
        """Compute power levels, margins and compliance for one path.

        The same total loss (fixed model + ``splitter_loss``) is applied
        downstream and upstream. Returns a dict with keys ``onu_rx_power``,
        ``downstream_margin``, ``olt_rx_power``, ``upstream_margin``,
        ``meets_requirements`` and ``limiting_factor``.
        """
        total_loss = (cls.FIBER_LOSS_DB + cls.SPLICE_LOSS_DB
                      + cls.CONNECTOR_LOSS_DB + splitter_loss)

        # Downstream: OLT transmits, ONU receives.
        onu_rx_power = olt_tx_power - total_loss
        downstream_margin = onu_rx_power - onu_rx_sens

        # Upstream: ONU transmits, OLT receives.
        olt_rx_power = onu_tx_power - total_loss
        upstream_margin = olt_rx_power - olt_rx_sens

        return {
            'onu_rx_power': onu_rx_power,
            'downstream_margin': downstream_margin,
            'olt_rx_power': olt_rx_power,
            'upstream_margin': upstream_margin,
            'meets_requirements': (downstream_margin >= cls.MIN_MARGIN_DB
                                   and upstream_margin >= cls.MIN_MARGIN_DB),
            # The direction with the smaller margin limits the path; ties are
            # attributed to upstream.
            'limiting_factor': ('downstream'
                                if downstream_margin < upstream_margin
                                else 'upstream'),
        }

    @classmethod
    def _insert_budget(cls, cur, onu_id, olt_tx_power, onu_tx_power,
                       splitter_loss, budget):
        """Insert one ``fttx.itu_power_budget`` row for ``onu_id``."""
        cur.execute("""
            INSERT INTO fttx.itu_power_budget (
                onu_id, power_budget_class,
                olt_tx_power_dbm,
                downstream_fiber_loss_db,
                downstream_splice_loss_db,
                downstream_connector_loss_db,
                downstream_splitter_loss_db,
                onu_rx_power_dbm,
                downstream_margin_db,
                onu_tx_power_dbm,
                upstream_fiber_loss_db,
                upstream_splice_loss_db,
                upstream_connector_loss_db,
                upstream_splitter_loss_db,
                olt_rx_power_dbm,
                upstream_margin_db,
                meets_itu_requirements,
                limiting_factor,
                calculation_method
            ) VALUES (
                %s, 'B+',
                %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s,
                %s, %s, 'batch_algorithm'
            )
        """, (
            onu_id,
            olt_tx_power, cls.FIBER_LOSS_DB, cls.SPLICE_LOSS_DB,
            cls.CONNECTOR_LOSS_DB, splitter_loss,
            budget['onu_rx_power'], budget['downstream_margin'],
            onu_tx_power, cls.FIBER_LOSS_DB, cls.SPLICE_LOSS_DB,
            cls.CONNECTOR_LOSS_DB, splitter_loss,
            budget['olt_rx_power'], budget['upstream_margin'],
            budget['meets_requirements'], budget['limiting_factor'],
        ))

    def processAlgorithm(self, parameters, context, feedback):
        """Recalculate ``fttx.itu_power_budget`` for every active ONU.

        The existing rows are deleted and the new rows inserted in a single
        transaction, which is committed only after every ONU has been
        processed. If there are no ONUs to analyze, the run is canceled, or
        an error occurs, nothing is committed, so previously stored results
        are left untouched. The database connection is always closed.

        Returns:
            dict mapping the OUTPUT_* keys to the calculated / compliant /
            non-compliant path counts and a status string. Failures are
            reported through ``feedback`` and an ``ERROR: ...`` status
            instead of being raised.
        """
        # Get parameters
        host = self.parameterAsString(parameters, self.DB_HOST, context)
        port = self.parameterAsString(parameters, self.DB_PORT, context)
        database = self.parameterAsString(parameters, self.DB_NAME, context)
        user = self.parameterAsString(parameters, self.DB_USER, context)
        password = self.parameterAsString(parameters, self.DB_PASSWORD, context)

        olt_tx_power = self.parameterAsDouble(parameters, self.OLT_TX_POWER, context)
        onu_tx_power = self.parameterAsDouble(parameters, self.ONU_TX_POWER, context)

        feedback.pushInfo('Starting power budget batch analysis...')

        conn = None
        try:
            feedback.pushInfo(f'Connecting to database: {database}')
            conn = psycopg2.connect(
                host=host, port=port, database=database,
                user=user, password=password
            )

            with conn.cursor() as cur:
                # Clear existing power budget calculations (uncommitted until
                # the whole batch succeeds).
                feedback.setProgress(10)
                feedback.pushInfo('Clearing existing calculations...')
                cur.execute('DELETE FROM fttx.itu_power_budget')

                # Get all active ONUs with their connections
                feedback.setProgress(20)
                feedback.pushInfo('Retrieving active ONUs...')
                cur.execute("""
                    SELECT 
                        o.onu_id,
                        o.connected_olt_id,
                        o.splitter_node_id,
                        o.onu_rx_sensitivity_dbm,
                        olt.receiver_sensitivity_dbm as olt_rx_sensitivity_dbm
                    FROM fttx.itu_onu_equipment o
                    LEFT JOIN fttx.itu_olt_equipment olt ON o.connected_olt_id = olt.olt_id
                    WHERE o.onu_status = 'active'
                    AND o.connected_olt_id IS NOT NULL
                """)
                onus = cur.fetchall()
                total_onus = len(onus)

                if total_onus == 0:
                    # Nothing to recompute: undo the DELETE so previously
                    # stored results are kept.
                    conn.rollback()
                    feedback.pushInfo('⚠ No active ONUs with OLT connections found')
                    return self._result(0, 0, 0, 'No ONUs to analyze')

                feedback.pushInfo(f'Found {total_onus} active ONU(s) to analyze')

                calculated = 0
                compliant = 0
                non_compliant = 0
                splitter_cache = {}

                for idx, onu_data in enumerate(onus):
                    if feedback.isCanceled():
                        # Committing here would replace the old table
                        # contents with a partial batch; discard instead.
                        conn.rollback()
                        feedback.pushInfo('Analysis canceled; no changes were committed')
                        return self._result(
                            calculated, compliant, non_compliant,
                            f'Canceled after {calculated} of {total_onus} paths; '
                            'no changes committed'
                        )

                    onu_id, _olt_id, splitter_id, onu_rx_sens, olt_rx_sens = onu_data

                    # Convert to float to handle Decimal types from PostgreSQL
                    onu_rx_sens = (float(onu_rx_sens) if onu_rx_sens is not None
                                   else self.DEFAULT_ONU_RX_SENSITIVITY_DBM)
                    olt_rx_sens = (float(olt_rx_sens) if olt_rx_sens is not None
                                   else self.DEFAULT_OLT_RX_SENSITIVITY_DBM)

                    progress = 20 + int((idx / total_onus) * 70)
                    feedback.setProgress(progress)
                    feedback.pushInfo(f'Processing ONU {idx+1}/{total_onus}...')

                    splitter_loss = self._lookup_splitter_loss(
                        cur, splitter_id, splitter_cache)
                    budget = self._calculate_budget(
                        olt_tx_power, onu_tx_power,
                        onu_rx_sens, olt_rx_sens, splitter_loss)

                    if budget['meets_requirements']:
                        compliant += 1
                    else:
                        non_compliant += 1

                    self._insert_budget(
                        cur, onu_id, olt_tx_power, onu_tx_power,
                        splitter_loss, budget)

                    calculated += 1

                # Commit changes
                feedback.setProgress(95)
                feedback.pushInfo('Committing results to database...')
                conn.commit()

            feedback.setProgress(100)

            # Summary
            feedback.pushInfo('')
            feedback.pushInfo('=' * 50)
            feedback.pushInfo('POWER BUDGET ANALYSIS COMPLETE')
            feedback.pushInfo('=' * 50)
            feedback.pushInfo(f'Total ONUs analyzed: {calculated}')
            feedback.pushInfo(f'✓ Compliant paths: {compliant}')
            feedback.pushInfo(f'✗ Non-compliant paths: {non_compliant}')

            if non_compliant > 0:
                compliance_rate = (compliant / calculated * 100) if calculated > 0 else 0
                feedback.pushInfo(f'Compliance rate: {compliance_rate:.1f}%')

            feedback.pushInfo('=' * 50)

            status = f'Analyzed {calculated} paths: {compliant} compliant, {non_compliant} non-compliant'
            return self._result(calculated, compliant, non_compliant, status)

        except Exception as e:
            # Top-level boundary: failures are surfaced as an ERROR status.
            # Any uncommitted work is discarded when the connection is closed
            # in ``finally``.
            feedback.reportError(f'Analysis failed: {str(e)}')
            return self._result(0, 0, 0, f'ERROR: {str(e)}')
        finally:
            if conn is not None:
                conn.close()