import argparse
import sys
import json
import warnings
from collections import deque
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
from scipy.stats import norm
from sklearn.linear_model import HuberRegressor
from sklearn.model_selection import TimeSeriesSplit

class FusionForecaster:
    """
    Implementation of Fusion Forecasting for residential PV systems.
    Blends AI-based and physics-based (GTI) forecasts with online calibration.
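
    Model (as implemented in ``forecast_day``):
        y_fusion = clip(scale * (w_ai * y_ai + w_gti * y_gti + b),
                        0, p_max * daylight_hours)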
    """
    
    def __init__(self, 
                 train_window_size=90, 
                 calibration_window_size=14,
                 alpha=0.05,
                 p_max=10.0,  # Maximum power capacity in kW
                 min_train_samples=30,
                 fallback_weights=(0.7, 0.3)):
        """
        Initialize the Fusion Forecasting model.
        
        Parameters:
        -----------
        train_window_size : int
            Number of days to use for training the regression model
        calibration_window_size : int
            Number of days to use for online calibration
        alpha : float
            Winsorization parameter for outlier removal (0-0.5)
        p_max : float
            Maximum power capacity of the PV system (kW)
        min_train_samples : int
            Minimum number of samples required for training
        fallback_weights : tuple
            (w_AI, w_GTI) weights to use when falling back
        """
        self.train_window_size = train_window_size
        self.calibration_window_size = calibration_window_size
        self.alpha = alpha
        self.p_max = p_max
        self.min_train_samples = min_train_samples
        self.fallback_weights = fallback_weights
        
        # Model parameters (to be learned)
        self.w_ai = None
        self.w_gti = None
        self.b = None
        self.scale = 1.0
        self.prev_scale = 1.0
        
        # Data storage
        self.history = deque(maxlen=max(train_window_size, calibration_window_size) * 2)
        
        # State flags
        self.is_trained = False
        
    def add_observation(self, date, y_ai, y_gti, y_actual, daylight_hours):
        """
        Add a new observation to the history.
        
        Parameters:
        -----------
        date : datetime or comparable
            Date identifier for the observation
        y_ai : float
            AI-based forecast for the day
        y_gti : float
            GTI-based forecast for the day
        y_actual : float
            Actual energy production (kWh)
        daylight_hours : float
            Number of daylight hours for the day
        """
        self.history.append({
            'date': date,
            'y_ai': y_ai,
            'y_gti': y_gti,
            'y_actual': y_actual,
            'daylight_hours': daylight_hours
        })
    
    def train_fusion(self, lambda_values=None):
        """
        Train the fusion model using historical data.
        Implements the TRAIN_FUSION procedure from the paper.
        
        Parameters:
        -----------
        lambda_values : array-like, optional
            Candidate values for the L2 regularization parameter
            (default: np.logspace(-3, 2, 10))
            
        Returns:
        --------
        w_ai, w_gti, b : learned parameters
        """
        # Default grid of candidate regularization strengths
        if lambda_values is None:
            lambda_values = np.logspace(-3, 2, 10)

        # Check if we have enough data
        if len(self.history) < self.min_train_samples:
            warnings.warn(f"Insufficient data ({len(self.history)} < {self.min_train_samples}). Using fallback weights.")
            return self._fallback_training()

        # Prepare training data from the most recent training window
        X = []
        y = []

        for obs in list(self.history)[-self.train_window_size:]:
            if (obs['y_ai'] is not None and obs['y_gti'] is not None and
                    obs['y_actual'] is not None and obs['y_actual'] >= 0):
                X.append([obs['y_ai'], obs['y_gti']])
                y.append(obs['y_actual'])
        
        if len(X) < self.min_train_samples:
            warnings.warn(f"Insufficient valid data ({len(X)} < {self.min_train_samples}). Using fallback weights.")
            return self._fallback_training()
        
        X = np.array(X)
        y = np.array(y)
        
        # Time-series cross-validation to select lambda
        best_lambda = self._select_lambda_cv(X, y, lambda_values)
        
        # Fit final model with best lambda
        try:
            # HuberRegressor gives a robust fit; epsilon=1.35 (sklearn's
            # default) is the standard threshold between the quadratic and
            # linear regions of the Huber loss.
            model = HuberRegressor(alpha=best_lambda, epsilon=1.35)
            model.fit(X, y)
            
            self.w_ai, self.w_gti = model.coef_
            self.b = model.intercept_
            self.is_trained = True
            
            return self.w_ai, self.w_gti, self.b
            
        except Exception as e:
            warnings.warn(f"Regression failed: {e}. Using fallback weights.")
            return self._fallback_training()
    
    def _select_lambda_cv(self, X, y, lambda_values):
        """
        Select the best regularization parameter using time-series cross-validation.
        """
        tscv = TimeSeriesSplit(n_splits=min(5, len(X) - 1))
        best_score = float('inf')
        best_lambda = lambda_values[0]
        
        for lambda_val in lambda_values:
            scores = []
            for train_idx, test_idx in tscv.split(X):
                X_train, X_test = X[train_idx], X[test_idx]
                y_train, y_test = y[train_idx], y[test_idx]
                
                try:
                    model = HuberRegressor(alpha=lambda_val, epsilon=1.35)
                    model.fit(X_train, y_train)
                    y_pred = model.predict(X_test)
                    
                    # Use MAE as in the paper's pseudocode
                    score = np.mean(np.abs(y_test - y_pred))
                    scores.append(score)
                except Exception:
                    scores.append(float('inf'))
            
            avg_score = np.mean(scores) if scores else float('inf')
            if avg_score < best_score:
                best_score = avg_score
                best_lambda = lambda_val
        
        return best_lambda
    
    def _fallback_training(self):
        """
        Fallback training procedure when insufficient data or regression fails.
        """
        self.w_ai, self.w_gti = self.fallback_weights
        
        # Calculate intercept as median residual
        residuals = []
        for obs in self.history:
            if (obs['y_ai'] is not None and obs['y_gti'] is not None and 
                obs['y_actual'] is not None and obs['y_actual'] >= 0):
                y_pred = self.w_ai * obs['y_ai'] + self.w_gti * obs['y_gti']
                residuals.append(obs['y_actual'] - y_pred)
        
        self.b = np.median(residuals) if residuals else 0.0
        self.is_trained = True
        
        return self.w_ai, self.w_gti, self.b
    
    def calibrate_scale(self, min_points=3):
        """
        Perform online calibration using median-ratio scaling.
        Implements the CALIBRATE_SCALE procedure from the paper.
        
        Parameters:
        -----------
        min_points : int
            Minimum number of valid points required for calibration
            
        Returns:
        --------
        scale : float
            Calibration scale factor
        """
        # Calibration requires trained fusion weights
        if not self.is_trained:
            return self.scale

        # Get recent observations for calibration
        recent_data = list(self.history)[-self.calibration_window_size:]
        
        ratios = []
        for obs in recent_data:
            if (obs['y_ai'] is not None and obs['y_gti'] is not None and 
                obs['y_actual'] is not None and obs['y_actual'] > 0):
                
                # Calculate linear prediction
                y_lin = self.w_ai * obs['y_ai'] + self.w_gti * obs['y_gti'] + self.b
                
                if y_lin > 0:
                    ratios.append(obs['y_actual'] / y_lin)
        
        # Check if we have enough valid points
        min_required = max(min_points, int(0.3 * self.calibration_window_size))
        if len(ratios) < min_required:
            warnings.warn(f"Insufficient data for calibration ({len(ratios)} < {min_required}). Using previous scale.")
            return self.prev_scale
        
        # Winsorize to reduce outlier impact
        ratios_sorted = np.sort(ratios)
        n = len(ratios_sorted)
        lower_idx = int(self.alpha * n)
        upper_idx = int((1 - self.alpha) * n) - 1
        
        winsorized = np.clip(ratios_sorted, 
                            ratios_sorted[lower_idx], 
                            ratios_sorted[upper_idx])
        
        # Update scale as median of winsorized ratios
        self.prev_scale = self.scale
        self.scale = np.median(winsorized)
        
        return self.scale
    
    def forecast_day(self, y_ai, y_gti, daylight_hours):
        """
        Generate a forecast for a single day.
        Implements the FORECAST_DAY procedure from the paper.
        
        Parameters:
        -----------
        y_ai : float
            AI-based forecast for the day
        y_gti : float
            GTI-based forecast for the day
        daylight_hours : float
            Number of daylight hours for the day
            
        Returns:
        --------
        y_fusion : float
            Fusion forecast for the day
        """
        if not self.is_trained:
            warnings.warn("Model not trained yet. Using fallback weights.")
            self._fallback_training()
        
        # Handle missing inputs conservatively
        if y_ai is None:
            # Carry forward the mean of past AI forecasts (simple imputation)
            past_ai = [obs['y_ai'] for obs in self.history if obs['y_ai'] is not None]
            y_ai = np.mean(past_ai) if past_ai else 0.0

        if y_gti is None:
            # Clear-sky-capped proxy: historical mean, capped at 80% of the
            # theoretical maximum (simple implementation)
            past_gti = [obs['y_gti'] for obs in self.history if obs['y_gti'] is not None]
            y_gti = min(self.p_max * daylight_hours * 0.8,
                        np.mean(past_gti) if past_gti else 0.0)
        
        # Calculate linear prediction
        y_lin = self.w_ai * y_ai + self.w_gti * y_gti + self.b
        
        # Apply calibration
        y_fusion = self.scale * y_lin
        
        # Apply physical plausibility constraints
        y_fusion = max(0, min(y_fusion, self.p_max * daylight_hours))
        
        return y_fusion
    
    def estimate_uncertainty(self, min_samples=5):
        """
        Estimate prediction uncertainty using residual dispersion.
        Implements the UNCERTAINTY procedure from the paper.
        
        Parameters:
        -----------
        min_samples : int
            Minimum number of samples required for uncertainty estimation
            
        Returns:
        --------
        sigma : float or None
            Estimated standard deviation of prediction errors, or None if insufficient data
        """
        # Get recent observations
        recent_data = list(self.history)[-self.calibration_window_size:]
        
        errors = []
        for obs in recent_data:
            if (obs['y_ai'] is not None and obs['y_gti'] is not None and 
                obs['y_actual'] is not None and obs['y_actual'] > 0):
                
                # Calculate calibrated prediction
                y_lin = self.w_ai * obs['y_ai'] + self.w_gti * obs['y_gti'] + self.b
                y_fus = self.scale * y_lin
                
                if y_fus > 0:
                    errors.append(obs['y_actual'] - y_fus)
        
        if len(errors) < min_samples:
            return None
        
        # Calculate robust sigma using MAD
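        # (For Gaussian errors, MAD = Phi^{-1}(3/4) * sigma ~= 0.6745 * sigma,
        #  hence the 1.4826 ~= 1/0.6745 consistency factor below.)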
        errors = np.array(errors)
        med_error = np.median(errors)
        mad = np.median(np.abs(errors - med_error))
        sigma = 1.4826 * mad  # Convert MAD to sigma for normal distribution
        
        return sigma
    
    def prediction_interval(self, y_ai, y_gti, daylight_hours, confidence=0.95):
        """
        Generate a prediction interval for the forecast.
        
        Parameters:
        -----------
        y_ai : float
            AI-based forecast for the day
        y_gti : float
            GTI-based forecast for the day
        daylight_hours : float
            Number of daylight hours for the day
        confidence : float
            Confidence level for the interval (0-1)
            
        Returns:
        --------
        (point_forecast, lower_bound, upper_bound) : tuple
        """
        point_forecast = self.forecast_day(y_ai, y_gti, daylight_hours)
        sigma = self.estimate_uncertainty()
        
        if sigma is None:
            return point_forecast, None, None
        
        # Calculate z-score for given confidence level
        z = norm.ppf(0.5 + confidence/2)
        
        lower = max(0, point_forecast - z * sigma)
        upper = min(self.p_max * daylight_hours, point_forecast + z * sigma)
        
        return point_forecast, lower, upper
    
    def daily_operation(self, date, y_ai, y_gti, daylight_hours, y_actual=None,
                        retrain_frequency=30, confidence=0.95):
        """
        Complete daily operation as described in the paper.
        
        Parameters:
        -----------
        date : datetime or comparable
            Date for the forecast/observation
        y_ai : float
            AI-based forecast for the day
        y_gti : float
            GTI-based forecast for the day
        daylight_hours : float
            Number of daylight hours for the day
        y_actual : float, optional
            Actual energy production, if available
        retrain_frequency : int
            How often to retrain the model (in days)
        confidence : float
            Confidence level for the prediction interval (0-1)
            
        Returns:
        --------
        result : dict
            Contains forecast, uncertainty, and model parameters
        """
        # Add observation if available
        if y_actual is not None:
            self.add_observation(date, y_ai, y_gti, y_actual, daylight_hours)
        
        # Retrain on the configured schedule, or if the model has never been trained
        if not self.is_trained or (self.history and len(self.history) % retrain_frequency == 0):
            self.train_fusion()
        
        # Update calibration
        if len(self.history) >= self.calibration_window_size:
            self.calibrate_scale()
        
        # Generate forecast
        forecast = self.forecast_day(y_ai, y_gti, daylight_hours)
        
        # Estimate uncertainty
        sigma = self.estimate_uncertainty()
        
        # Prepare result
        result = {
            'date': date,
            'forecast': forecast,
            'uncertainty': sigma,
            'w_ai': self.w_ai,
            'w_gti': self.w_gti,
            'b': self.b,
            'scale': self.scale
        }
        
        # Add prediction interval if uncertainty is available
        if sigma is not None:
            z = norm.ppf(0.5 + confidence / 2)
            result['confidence'] = confidence
            result['lower'] = max(0, forecast - z * sigma)
            result['upper'] = min(self.p_max * daylight_hours, forecast + z * sigma)
        
        return result
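
# Illustrative, self-contained usage sketch (not one of the paper's
# procedures): drives the daily loop on synthetic data so the class can be
# smoke-tested without a CSV file. All numeric inputs below are fabricated.
def run_synthetic_demo(seed=0, n_days=60):
    """Feed the forecaster synthetic AI/GTI forecasts and noisy actuals."""
    rng = np.random.default_rng(seed)
    forecaster = FusionForecaster(p_max=8.0, min_train_samples=20)
    start = datetime(2023, 6, 1).date()

    result = None
    for i in range(n_days):
        day = start + timedelta(days=i)
        truth = 30.0 + 10.0 * np.sin(2 * np.pi * i / 30.0)  # synthetic seasonal signal (kWh)
        y_ai = truth + rng.normal(0.0, 2.0)                 # noisy "AI" forecast
        y_gti = 0.9 * truth + rng.normal(0.0, 3.0)          # biased "physics" forecast
        y_actual = max(0.0, truth + rng.normal(0.0, 1.5))   # observed production
        result = forecaster.daily_operation(day, y_ai, y_gti,
                                            daylight_hours=14.0,
                                            y_actual=y_actual)

    print(f"Demo forecast after {n_days} days: {result['forecast']:.2f} kWh "
          f"(scale={result['scale']:.3f})")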

def main():
    parser = argparse.ArgumentParser(
        description="Fusion Forecasting for Residential PV - Command Line Interface",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train and forecast using CSV data
  python fusion_forecast.py --train-data historical.csv --forecast-dates 2023-08-01,2023-08-02
  
  # Run with specific parameters
  python fusion_forecast.py --train-data data.csv --train-window 60 --calibration-window 14 --p-max 8.5
  
  # Save results to file
  python fusion_forecast.py --train-data data.csv --forecast-dates 2023-08-01 --output results.json
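
Expected CSV schema for --train-data (one row per day):
  date,y_ai,y_gti,y_actual,daylight_hours
  2023-05-01,32.1,29.8,31.4,14.2    (illustrative values only)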
        """
    )
    
    # Input/output options
    parser.add_argument('--train-data', type=str, required=True,
                       help='CSV file containing historical training data')
    parser.add_argument('--forecast-dates', type=str, required=True,
                       help='Comma-separated list of dates to forecast (YYYY-MM-DD)')
    parser.add_argument('--output', type=str,
                       help='Output file to save results (JSON format)')
    
    # Model parameters
    parser.add_argument('--train-window', type=int, default=90,
                       help='Number of days to use for training (default: 90)')
    parser.add_argument('--calibration-window', type=int, default=14,
                       help='Number of days to use for calibration (default: 14)')
    parser.add_argument('--p-max', type=float, default=10.0,
                       help='Maximum power capacity of PV system in kW (default: 10.0)')
    parser.add_argument('--min-train-samples', type=int, default=30,
                       help='Minimum samples required for training (default: 30)')
    parser.add_argument('--alpha', type=float, default=0.05,
                       help='Winsorization parameter for outlier removal (default: 0.05)')
    parser.add_argument('--fallback-ai-weight', type=float, default=0.7,
                       help='Fallback weight for AI forecast (default: 0.7)')
    parser.add_argument('--fallback-gti-weight', type=float, default=0.3,
                       help='Fallback weight for GTI forecast (default: 0.3)')
    
    # Operational options
    parser.add_argument('--retrain-frequency', type=int, default=30,
                       help='How often to retrain the model (in days, default: 30)')
    parser.add_argument('--confidence', type=float, default=0.95,
                       help='Confidence level for prediction intervals (default: 0.95)')
    parser.add_argument('--verbose', action='store_true',
                       help='Enable verbose output')
    
    args = parser.parse_args()
    
    # Parse forecast dates
    try:
        forecast_dates = [datetime.strptime(d.strip(), '%Y-%m-%d').date() 
                         for d in args.forecast_dates.split(',')]
    except ValueError as e:
        print(f"Error parsing dates: {e}")
        sys.exit(1)
    
    # Load training data
    try:
        if args.verbose:
            print(f"Loading training data from {args.train_data}...")
        
        data = pd.read_csv(args.train_data, parse_dates=['date'])
        
        # Check required columns
        required_cols = ['date', 'y_ai', 'y_gti', 'y_actual', 'daylight_hours']
        missing_cols = [col for col in required_cols if col not in data.columns]
        if missing_cols:
            print(f"Missing columns in data: {missing_cols}")
            sys.exit(1)
            
    except Exception as e:
        print(f"Error loading data: {e}")
        sys.exit(1)
    
    # Initialize the forecaster
    if args.verbose:
        print("Initializing Fusion Forecaster...")
    
    forecaster = FusionForecaster(
        train_window_size=args.train_window,
        calibration_window_size=args.calibration_window,
        alpha=args.alpha,
        p_max=args.p_max,
        min_train_samples=args.min_train_samples,
        fallback_weights=(args.fallback_ai_weight, args.fallback_gti_weight)
    )
    
    # Add historical observations
    if args.verbose:
        print(f"Adding {len(data)} historical observations...")
    
    for _, row in data.iterrows():
        forecaster.add_observation(
            date=row['date'].date(),
            y_ai=row['y_ai'],
            y_gti=row['y_gti'],
            y_actual=row['y_actual'],
            daylight_hours=row['daylight_hours']
        )
    
    # Generate forecasts
    results = []
    
    for forecast_date in forecast_dates:
        if args.verbose:
            print(f"Generating forecast for {forecast_date}...")
        
        # In a real application these inputs would come from the live
        # forecast sources; here the last available row serves as a proxy.
        last_row = data.iloc[-1]
        
        result = forecaster.daily_operation(
            date=forecast_date,
            y_ai=last_row['y_ai'],
            y_gti=last_row['y_gti'],
            daylight_hours=last_row['daylight_hours'],
            y_actual=None,  # no actual value for future dates
            retrain_frequency=args.retrain_frequency,
            confidence=args.confidence
        )
        
        results.append(result)
    
    # Output results
    if args.output:
        if args.verbose:
            print(f"Saving results to {args.output}...")
        
        # Convert results to JSON-serializable format
        serializable_results = []
        for result in results:
            serializable_result = {
                'date': result['date'].isoformat() if hasattr(result['date'], 'isoformat') else str(result['date']),
                'forecast': float(result['forecast']),
                'uncertainty': float(result['uncertainty']) if result['uncertainty'] is not None else None,
                'w_ai': float(result['w_ai']),
                'w_gti': float(result['w_gti']),
                'b': float(result['b']),
                'scale': float(result['scale'])
            }
            
            # Add prediction interval if available
            if 'lower' in result and 'upper' in result:
                serializable_result['prediction_interval'] = {
                    'lower': float(result['lower']),
                    'upper': float(result['upper']),
                    'confidence': result['confidence']
                }
                
            serializable_results.append(serializable_result)
        
        with open(args.output, 'w') as f:
            json.dump(serializable_results, f, indent=2)
    
    # Print results to console
    print("\nFusion Forecasting Results:")
    print("=" * 80)
    for result in results:
        print(f"Date: {result['date']}")
        print(f"  Forecast: {result['forecast']:.2f} kWh")
        if result['uncertainty'] is not None:
            print(f"  Uncertainty: ±{result['uncertainty']:.2f} kWh (1σ)")
        if 'lower' in result and 'upper' in result:
            print(f"  {result['confidence']*100:.0f}% Prediction Interval: [{result['lower']:.2f}, {result['upper']:.2f}] kWh")
        print(f"  Model Parameters: w_ai={result['w_ai']:.3f}, w_gti={result['w_gti']:.3f}, b={result['b']:.3f}")
        print(f"  Calibration Scale: {result['scale']:.3f}")
        print("-" * 80)

if __name__ == "__main__":
    main()