from typing import Optional, Tuple, Union
import numpy as np
import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.utils.datetime import get_seasonality
from autogluon.timeseries.utils.warning_filters import warning_filter


class TimeSeriesScorer:
    """Base class for all evaluation metrics used in AutoGluon-TimeSeries.

    This object always returns the metric value in greater-is-better format.

    Follows the design of ``autogluon.core.metrics.Scorer``.

    Attributes
    ----------
    greater_is_better_internal : bool, default = False
        Whether the internal method :meth:`~autogluon.timeseries.metrics.TimeSeriesScorer.compute_metric` computes a
        loss function (default), meaning lower is better, or a score function, meaning higher is better.
    optimum : float, default = 0.0
        The best value achievable by the metric, i.e., its maximum in case of a score function and its minimum in case
        of a loss function.
    optimized_by_median : bool, default = False
        Whether the given point forecast metric is optimized by the median (if True) or by the expected value (if
        False). If True, all models in AutoGluon-TimeSeries will attempt to place the median forecast in the "mean"
        column.
    needs_quantile : bool, default = False
        Whether the given metric uses the quantile predictions. Some models will modify the training procedure if they
        are trained to optimize a quantile metric.
    equivalent_tabular_regression_metric : str
        Name of an equivalent metric used by AutoGluon-Tabular with ``problem_type="regression"``. Used by models that
        train a TabularPredictor under the hood. This attribute should only be specified by point forecast metrics.
    """

    greater_is_better_internal: bool = False
    optimum: float = 0.0
    optimized_by_median: bool = False
    needs_quantile: bool = False
    equivalent_tabular_regression_metric: Optional[str] = None

    @property
    def sign(self) -> int:
        return 1 if self.greater_is_better_internal else -1

    @property
    def name(self) -> str:
        return self.__class__.__name__

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name

    @property
    def name_with_sign(self) -> str:
        if self.greater_is_better_internal:
            prefix = ""
        else:
            prefix = "-"
        return f"{prefix}{self.name}"
    def __call__(
        self,
        data: TimeSeriesDataFrame,
        predictions: TimeSeriesDataFrame,
        prediction_length: int = 1,
        target: str = "target",
        seasonal_period: Optional[int] = None,
        **kwargs,
    ) -> float:
        seasonal_period = get_seasonality(data.freq) if seasonal_period is None else seasonal_period
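        # Split the data into observed history and forecast horizon: the last
        # ``prediction_length`` timesteps of each series are held out as ground
        # truth (``data_future``), everything before them is ``data_past``.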
        data_past = data.slice_by_timestep(None, -prediction_length)
        data_future = data.slice_by_timestep(-prediction_length, None)
        assert not predictions.isna().any().any(), "Predictions contain NaN values."
        assert (
            predictions.num_timesteps_per_item() == prediction_length
        ).all(), "Predictions must contain exactly `prediction_length` timesteps for each item."
        assert data_future.index.equals(predictions.index), "Prediction and data indices do not match."
        try:
            with warning_filter():
                self.save_past_metrics(
                    data_past=data_past,
                    target=target,
                    seasonal_period=seasonal_period,
                    **kwargs,
                )
                metric_value = self.compute_metric(
                    data_future=data_future,
                    predictions=predictions,
                    target=target,
                    **kwargs,
                )
        finally:
            self.clear_past_metrics()
        return metric_value * self.sign
    score = __call__
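
    # Sign convention: ``__call__`` (and its alias ``score``) always returns the
    # metric in greater-is-better format. For example, if ``compute_metric``
    # returns a loss of 2.5, the reported score is ``-1 * 2.5 = -2.5``.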
    def compute_metric(
        self,
        data_future: TimeSeriesDataFrame,
        predictions: TimeSeriesDataFrame,
        target: str = "target",
        **kwargs,
    ) -> float:
        """Internal method that computes the metric for given forecast & actual data.
        This method should be implemented by all custom metrics.
        Parameters
        ----------
        data_future : TimeSeriesDataFrame
            Actual values of the time series during the forecast horizon (``prediction_length`` values for each time
            series in the dataset). Must have the same index as ``predictions``.
        predictions : TimeSeriesDataFrame
            Data frame with predictions for the forecast horizon. Contain columns "mean" (point forecast) and the
            columns corresponding to each of the quantile levels. Must have the same index as ``data_future``.
        target : str, default = "target"
            Name of the column in ``data_future`` that contains the target time series.
        Returns
        -------
        score : float
            Value of the metric for given forecast and data. If self.greater_is_better_internal is True, returns score
            in greater-is-better format, otherwise in lower-is-better format.
        """
        raise NotImplementedError 
    def save_past_metrics(
        self,
        data_past: TimeSeriesDataFrame,
        target: str = "target",
        seasonal_period: int = 1,
        **kwargs,
    ) -> None:
        """Compute auxiliary metrics on past data (before forecast horizon), if the chosen metric requires it.
        This method should only be implemented by metrics that rely on historic (in-sample) data, such as Mean Absolute
        Scaled Error (MASE) https://en.wikipedia.org/wiki/Mean_absolute_scaled_error.
        We keep this method separate from :meth:`compute_metric` to avoid redundant computations when fitting ensemble.
        """
        pass
    def clear_past_metrics(self) -> None:
        """Clear auxiliary metrics saved in :meth:`save_past_metrics`.
        This method should only be implemented if :meth:`save_past_metrics` has been implemented.
        """
        pass
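
    # Example pattern: a MASE-style metric computes the in-sample seasonal naive
    # error in ``save_past_metrics``, stores it as an attribute for use inside
    # ``compute_metric``, and resets that attribute in ``clear_past_metrics``.
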
    def error(self, *args, **kwargs):
        """Return error in lower-is-better format."""
        return self.optimum - self.score(*args, **kwargs)
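
    # Example: for a loss with ``optimum = 0.0``, ``error`` negates the
    # greater-is-better score, recovering the raw loss value (a score of -2.5
    # corresponds to an error of 2.5).
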
    @staticmethod
    def _safemean(array: Union[np.ndarray, pd.Series]) -> float:
        """Compute mean of a numpy array-like object, ignoring inf, -inf and nan values."""
        return float(np.mean(array[np.isfinite(array)]))
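
    # Example: ``TimeSeriesScorer._safemean(np.array([1.0, np.inf, np.nan, 3.0]))``
    # ignores the non-finite values and returns 2.0.
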
    @staticmethod
    def _get_point_forecast_score_inputs(
        data_future: TimeSeriesDataFrame, predictions: TimeSeriesDataFrame, target: str = "target"
    ) -> Tuple[pd.Series, pd.Series]:
        """Get inputs necessary to compute point forecast metrics.
        Returns
        -------
        y_true : pd.Series, shape [num_items * prediction_length]
            Target time series values during the forecast horizon.
        y_pred : pd.Series, shape [num_items * prediction_length]
            Predicted time series values during the forecast horizon.
        """
        y_true = data_future[target]
        y_pred = predictions["mean"]
        return y_true, y_pred
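
    # Example: if ``predictions`` has columns ["mean", "0.1", "0.5", "0.9"],
    # ``y_pred`` is the "mean" column; ``y_true`` and ``y_pred`` share the same
    # (item_id, timestamp) index, so they can be compared elementwise.
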
    @staticmethod
    def _get_quantile_forecast_score_inputs(
        data_future: TimeSeriesDataFrame, predictions: TimeSeriesDataFrame, target: str = "target"
    ) -> Tuple[pd.Series, pd.DataFrame, np.ndarray]:
        """Get inputs necessary to compute quantile forecast metrics.
        Returns
        -------
        y_true : pd.Series, shape [num_items * prediction_length]
            Target time series values during the forecast horizon.
        q_pred : pd.DataFrame, shape [num_items * prediction_length, num_quantiles]
            Quantile forecast for each predicted quantile level. Column order corresponds to ``quantile_levels``.
        quantile_levels : np.ndarray, shape [num_quantiles]
            Quantile levels for which the forecasts are generated (as floats).
        """
        quantile_columns = [col for col in predictions.columns if col != "mean"]
        y_true = data_future[target]
        q_pred = pd.DataFrame(predictions[quantile_columns])
        quantile_levels = np.array(quantile_columns, dtype=float)
        return y_true, q_pred, quantile_levels
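

# --- Example: defining a custom metric --------------------------------------
# A minimal sketch of a custom point forecast metric built on top of
# ``TimeSeriesScorer``. The class name ``CustomMAE`` is hypothetical and only
# illustrates the ``compute_metric`` contract; AutoGluon ships its own MAE
# implementation.
class CustomMAE(TimeSeriesScorer):
    """Mean absolute error of the point ("mean") forecast."""

    # The median minimizes the expected absolute error, so models should place
    # the median forecast in the "mean" column when optimizing this metric.
    optimized_by_median = True
    equivalent_tabular_regression_metric = "mean_absolute_error"

    def compute_metric(
        self,
        data_future: TimeSeriesDataFrame,
        predictions: TimeSeriesDataFrame,
        target: str = "target",
        **kwargs,
    ) -> float:
        y_true, y_pred = self._get_point_forecast_score_inputs(data_future, predictions, target=target)
        # Lower is better; ``__call__`` flips the sign because
        # ``greater_is_better_internal`` is False by default.
        return self._safemean((y_true - y_pred).abs())


# The custom metric can then be called like any built-in scorer, e.g.
# ``CustomMAE()(data, predictions, prediction_length=...)``, which returns the
# negated MAE in greater-is-better format.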