Source code for zoo.chronos.forecaster.seq2seq_forecaster

# Copyright 2018 Analytics Zoo Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from zoo.chronos.forecaster.base_forecaster import BasePytorchForecaster
from zoo.chronos.forecaster.utils import set_pytorch_seed
from zoo.chronos.model.Seq2Seq_pytorch import Seq2SeqPytorch
from zoo.chronos.model.Seq2Seq_pytorch import model_creator, optimizer_creator, loss_creator

[docs]class Seq2SeqForecaster(BasePytorchForecaster): """ Example: >>> #The dataset is split into x_train, x_val, x_test, y_train, y_val, y_test >>> forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=2, input_feature_num=1, output_feature_num=1, ...) >>>, y_train)) >>> forecaster.to_local() # if you set distributed=True >>> test_pred = forecaster.predict(x_test) >>> test_eval = forecaster.evaluate((x_test, y_test)) >>>{ckpt_name}) >>> forecaster.restore({ckpt_name}) """ def __init__(self, past_seq_len, future_seq_len, input_feature_num, output_feature_num, lstm_hidden_dim=64, lstm_layer_num=2, teacher_forcing=False, dropout=0.1, optimizer="Adam", loss="mse", lr=0.001, metrics=["mse"], seed=None, distributed=False, workers_per_node=1, distributed_backend="torch_distributed"): """ Build a TCN Forecast Model. TCN Forecast may fall into local optima. Please set repo_initialization to False to alleviate the issue. You can also change a random seed to work around. :param past_seq_len: Specify the history time steps (i.e. lookback). :param future_seq_len: Specify the output time steps (i.e. horizon). :param input_feature_num: Specify the feature dimension. :param output_feature_num: Specify the output dimension. :param lstm_hidden_dim: LSTM hidden channel for decoder and encoder. The value defaults to 64. :param lstm_layer_num: LSTM layer number for decoder and encoder. The value defaults to 2. :param teacher_forcing: If use teacher forcing in training. The value defaults to False. :param dropout: Specify the dropout close possibility (i.e. the close possibility to a neuron). This value defaults to 0.1. :param optimizer: Specify the optimizer used for training. This value defaults to "Adam". :param loss: Specify the loss function used for training. This value defaults to "mse". You can choose from "mse", "mae" and "huber_loss". :param lr: Specify the learning rate. This value defaults to 0.001. :param metrics: A list contains metrics for evaluating the quality of forecasting. You may only choose from "mse" and "mae" for a distributed forecaster. You may choose from "mse", "me", "mae", "mse","rmse","msle","r2", "mpe", "mape", "mspe", "smape", "mdape" and "smdape" for a non-distributed forecaster. :param seed: int, random seed for training. This value defaults to None. :param distributed: bool, if init the forecaster in a distributed fashion. If True, the internal model will use an Orca Estimator. If False, the internal model will use a pytorch model. The value defaults to False. :param workers_per_node: int, the number of worker you want to use. The value defaults to 1. The param is only effective when distributed is set to True. :param distributed_backend: str, select from "torch_distributed" or "horovod". The value defaults to "torch_distributed". """ # config setting self.data_config = { "past_seq_len": past_seq_len, "future_seq_len": future_seq_len, "input_feature_num": input_feature_num, "output_feature_num": output_feature_num } self.config = { "lr": lr, "loss": loss, "lstm_hidden_dim": lstm_hidden_dim, "lstm_layer_num": lstm_layer_num, "teacher_forcing": teacher_forcing, "optim": optimizer, "dropout": dropout } # model creator settings self.local_model = Seq2SeqPytorch self.model_creator = model_creator self.optimizer_creator = optimizer_creator self.loss_creator = loss_creator # distributed settings self.distributed = distributed self.distributed_backend = distributed_backend self.workers_per_node = workers_per_node # other settings = lr self.metrics = metrics self.seed = seed super().__init__()