Source code for zoo.chronos.model.forecast.arima_forecaster

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from zoo.chronos.model.forecast.abstract import Forecaster
from zoo.chronos.model.arima import ARIMAModel


class ARIMAForecaster(Forecaster):
    """
    Forecaster that delegates fitting, prediction, evaluation and
    persistence to an internal ``ARIMAModel`` instance.

    Example:
        >>> #The dataset is split into data, validation_data
        >>> model = ARIMAForecaster(p=2, q=2, seasonality_mode=False)
        >>> model.fit(data, validation_data)
        >>> predict_result = model.predict(horizon=24)
    """

    def __init__(self,
                 p=2,
                 q=2,
                 seasonality_mode=True,
                 P=3,
                 Q=1,
                 m=7,
                 metric="mse",
                 ):
        """
        Build an ARIMA Forecast Model.
        User can customize p, q, seasonality_mode, P, Q, m, metric for the ARIMA model,
        the differencing term (d) and seasonal differencing term (D) are automatically
        estimated from the data. For details of the ARIMA model hyperparameters, refer to
        https://alkaline-ml.com/pmdarima/modules/generated/pmdarima.arima.ARIMA.html#pmdarima.arima.ARIMA.

        :param p: hyperparameter p for the ARIMA model.
        :param q: hyperparameter q for the ARIMA model.
        :param seasonality_mode: hyperparameter seasonality_mode for the ARIMA model.
        :param P: hyperparameter P for the ARIMA model.
        :param Q: hyperparameter Q for the ARIMA model.
        :param m: hyperparameter m for the ARIMA model.
        :param metric: the metric for validation and evaluation. For regression, we support
            Mean Squared Error: ("mean_squared_error", "MSE" or "mse"),
            Mean Absolute Error: ("mean_absolute_error","MAE" or "mae"),
            Mean Absolute Percentage Error: ("mean_absolute_percentage_error", "MAPE", "mape")
            Cosine Proximity: ("cosine_proximity", "cosine")
        """
        # Stored as a dict so that fit() can forward every hyperparameter to the
        # internal model as keyword arguments in a single call.
        self.model_config = {
            "p": p,
            "q": q,
            "seasonality_mode": seasonality_mode,
            "P": P,
            "Q": Q,
            "m": m,
            "metric": metric,
        }
        self.internal = ARIMAModel()
        super().__init__()

    def fit(self, data, validation_data):
        """
        Fit(Train) the forecaster.

        :param data: A 1-D numpy array as the training data
        :param validation_data: A 1-D numpy array as the evaluation data
        :return: the value returned by the internal model's ``fit_eval``.
        :raises AssertionError: if either array is not 1-D.
        """
        self._check_data(data, validation_data)
        # The internal model is handed column vectors of shape (n, 1).
        data = data.reshape(-1, 1)
        validation_data = validation_data.reshape(-1, 1)
        return self.internal.fit_eval(data=data,
                                      validation_data=validation_data,
                                      **self.model_config)

    def _check_data(self, data, validation_data):
        """
        Verify that both input arrays are one-dimensional.

        :param data: the training array; must expose ``ndim``.
        :param validation_data: the evaluation array; must expose ``ndim``.
        :raises AssertionError: if either array has ``ndim != 1``.
        """
        # Raised explicitly rather than through an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type is kept as
        # AssertionError so existing callers observe the same error class.
        for name, array in (("data", data), ("validation_data", validation_data)):
            if array.ndim != 1:
                raise AssertionError(
                    "{0} should be an 1-D array, Got {0} dimension of {1}."
                    .format(name, array.ndim))

    def predict(self, horizon, rolling=False):
        """
        Predict using a trained forecaster.

        :param horizon: the number of steps forward to predict
        :param rolling: whether to use rolling prediction
        :return: the value returned by the internal model's ``predict``.
        :raises RuntimeError: if the internal model has not been fitted or restored.
        """
        if self.internal.model is None:
            raise RuntimeError("You must call fit or restore first before calling predict!")
        return self.internal.predict(horizon=horizon, rolling=rolling)

    def evaluate(self, validation_data, metrics=None, rolling=False):
        """
        Evaluate using a trained forecaster.

        :param validation_data: A 1-D numpy array as the evaluation data
        :param metrics: A list contains metrics for test/valid data.
            Defaults to ``['mse']`` when omitted or None.
        :param rolling: flag passed through unchanged to the internal model's
            ``evaluate`` (default False).
        :return: the value returned by the internal model's ``evaluate``.
        :raises ValueError: if ``validation_data`` is None.
        :raises RuntimeError: if the internal model has not been fitted or restored.
        """
        if validation_data is None:
            raise ValueError("Input invalid validation_data of None")
        if self.internal.model is None:
            raise RuntimeError("You must call fit or restore first before calling evaluate!")
        # A fresh list per call avoids sharing one mutable default object
        # between invocations (and with the internal model).
        if metrics is None:
            metrics = ['mse']
        return self.internal.evaluate(validation_data, metrics=metrics, rolling=rolling)

    def save(self, checkpoint_file):
        """
        Save the forecaster.

        :param checkpoint_file: The location you want to save the forecaster.
        :raises RuntimeError: if the internal model has not been fitted or restored.
        """
        if self.internal.model is None:
            raise RuntimeError("You must call fit or restore first before calling save!")
        self.internal.save(checkpoint_file)

    def restore(self, checkpoint_file):
        """
        Restore the forecaster.

        :param checkpoint_file: The checkpoint file location you want to load the forecaster.
        """
        self.internal.restore(checkpoint_file)