# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pickle
from pmdarima.arima import ARIMA
from pmdarima.arima import ndiffs
from pmdarima.arima import nsdiffs

from bigdl.chronos.metric.forecast_metrics import Evaluator


class ARIMAModel:
    """Thin wrapper around ``pmdarima.arima.ARIMA`` used by Chronos AutoML."""

    def __init__(self):
        """
        Initialize Model
        """
        self.seasonal = True
        self.metric = 'mse'
        # Optional callable metric, normally set by _build from config. Initialized here so that
        # fit_eval on a restored model (restore sets model_init and _build is skipped) does not
        # raise AttributeError.
        self.metric_func = None
        self.model = None
        self.model_init = False

    def _build(self, **config):
        """
        build the models and initialize.
        :param config: hyperparameters for the model. Keys read here: p, d, q (non-seasonal
               order, defaults 2, 0, 2), seasonality_mode (default True), P, D, Q, m (seasonal
               order, defaults 1, 0, 1, 7), metric and metric_func.
        """
        p = config.get('p', 2)
        d = config.get('d', 0)
        q = config.get('q', 2)
        self.seasonal = config.get('seasonality_mode', True)
        P = config.get('P', 1)
        D = config.get('D', 0)
        Q = config.get('Q', 1)
        m = config.get('m', 7)
        self.metric = config.get('metric', self.metric)
        self.metric_func = config.get('metric_func', None)

        order = (p, d, q)
        if not self.seasonal:
            seasonal_order = (0, 0, 0, 0)
        else:
            seasonal_order = (P, D, Q, m)

        self.model = ARIMA(order=order, seasonal_order=seasonal_order, suppress_warnings=True)

    def fit_eval(self, data, validation_data, **config):
        """
        Fit on the training data from scratch.

        On the first call (model not yet built or restored), the differencing terms d and D are
        estimated from ``data`` and override any d / D in ``config`` before the model is built.
        :param data: A 1-D numpy array as the training data
        :param validation_data: A 1-D numpy array as the evaluation data
        :param config: hyperparameters forwarded to ``_build`` on the first call only
        :return: a dict mapping the metric name to the evaluation metric value
        """
        if not self.model_init:
            # Estimate the differencing term (d) as the larger of the KPSS and ADF estimates.
            kpss_diffs = ndiffs(data, alpha=0.05, test='kpss', max_d=6)
            adf_diffs = ndiffs(data, alpha=0.05, test='adf', max_d=6)
            d = max(adf_diffs, kpss_diffs)
            # Use the same seasonality settings (and defaults) that _build will use, so the
            # seasonal differencing term (D) is estimated for the period the model is built with.
            seasonal = config.get('seasonality_mode', True)
            m = config.get('m', 7)
            D = nsdiffs(data, m=m, max_D=12) if seasonal else 0
            config.update(d=d, D=D)

            self._build(**config)
            self.model_init = True

        self.model.fit(data)
        metric = self.metric_func if self.metric_func else self.metric
        val_metric = self.evaluate(x=None, target=validation_data,
                                   metrics=[metric])[0].item()
        metric_name = self.metric_func.__name__ if self.metric_func else self.metric
        return {metric_name: val_metric}

    def predict(self, x=None, horizon=24, update=False, rolling=False):
        """
        Predict horizon time-points ahead the input x in fit_eval
        :param x: ARIMA predicts the horizon steps forward from the training data.
            So x should be None as it is not used.
        :param horizon: the number of steps forward to predict
        :param update: whether to update the original model (only allowed with rolling)
        :param rolling: whether to use rolling prediction
        :return: predicted result of length horizon
        """
        from bigdl.nano.utils.common import invalidInputError
        if x is not None:
            invalidInputError(False, "x should be None")
        if update and not rolling:
            invalidInputError(False,
                              "We don't support updating model without"
                              " rolling prediction currently")
        if self.model is None:
            invalidInputError(False,
                              "Needs to call fit_eval or restore first before calling predict")

        if not rolling:
            # update is necessarily False here (update without rolling is rejected above).
            return self.model.predict(n_periods=horizon)

        # Rolling prediction mutates the model through update(). When the caller does not want
        # the model updated, run on an in-memory copy so self.model is never touched: no shared
        # temp file on disk, and an exception mid-loop cannot leave self.model half-updated.
        model = self.model if update else pickle.loads(pickle.dumps(self.model))

        forecasts = []
        for _ in range(horizon):
            fc = model.predict(n_periods=1).item()
            forecasts.append(fc)
            # Updates the model with a small number of MLE steps for rolling prediction
            model.update(fc)
        return forecasts

    def evaluate(self, target, x=None, metrics=None, rolling=False):
        """
        Evaluate on the prediction results and target. We predict len(target) time-points ahead
        of the data passed to fit_eval before evaluation.
        :param target: target for evaluation.
        :param x: ARIMA predicts the horizon steps forward from the training data.
            So x should be None as it is not used.
        :param metrics: a list of metrics in string format or callable function with format
               it signature should be func(y_true, y_pred), where y_true and y_pred are numpy
               ndarray. The function should return a float value as evaluation result.
               Defaults to ['mse'].
        :param rolling: whether to use rolling prediction (the model itself is not updated)
        :return: a list of metric evaluation results
        """
        from bigdl.nano.utils.common import invalidInputError
        if metrics is None:
            metrics = ['mse']
        if x is not None:
            invalidInputError(False,
                              "We don't support input x currently")
        if target is None:
            invalidInputError(False,
                              "Input invalid target of None")
        if self.model is None:
            invalidInputError(False,
                              "Needs to call fit_eval or restore first before calling evaluate")

        forecasts = self.predict(horizon=len(target), rolling=rolling)

        return Evaluator.evaluate(metrics, target, forecasts, aggregate="mean")

    def save(self, checkpoint_file):
        """
        Pickle the underlying pmdarima model to ``checkpoint_file``.
        :param checkpoint_file: path of the file to write
        """
        from bigdl.nano.utils.common import invalidInputError
        if self.model is None:
            invalidInputError(False,
                              "Needs to call fit_eval or restore first before calling save")
        with open(checkpoint_file, 'wb') as fout:
            pickle.dump(self.model, fout)

    def restore(self, checkpoint_file):
        """
        Load a model previously written by ``save`` and mark this instance as initialized.

        SECURITY: this uses pickle.load; only restore checkpoints from trusted sources.
        :param checkpoint_file: path of the file to read
        """
        with open(checkpoint_file, 'rb') as fin:
            self.model = pickle.load(fin)
        self.model_init = True


class ARIMABuilder:
    """Factory that creates and initializes ARIMAModel instances."""

    def __init__(self, **arima_config):
        """
        Create an ARIMA model builder.
        :param arima_config: keyword hyperparameters; a shallow copy is stored in
               ``self.model_config``.
               NOTE(review): ``build`` only uses its own ``config`` argument and does not read
               ``self.model_config`` — confirm whether these values are meant to be merged in.
        """
        self.model_config = dict(arima_config)

    def build(self, config):
        """
        Create an ARIMAModel and initialize it from ``config``.
        :param config: dict of hyperparameters unpacked into ``ARIMAModel._build`` (keys it reads:
               p, d, q, seasonality_mode, P, D, Q, m, metric, metric_func). See
               https://alkaline-ml.com/pmdarima/modules/generated/pmdarima.arima.ARIMA.html#pmdarima.arima.ARIMA
               for the meaning of the ARIMA orders.
        :return: the initialized ARIMAModel
        """
        from bigdl.chronos.model.arima import ARIMAModel
        built = ARIMAModel()
        built._build(**config)
        return built
