Overview
N-BEATS (Neural Basis Expansion Analysis for Time Series) is a state-of-the-art deep neural network architecture designed specifically for univariate time series forecasting. Unlike recurrent (RNN) or convolutional (CNN) models, N-BEATS is built entirely from fully connected feedforward layers. Its key innovation is combining high accuracy with interpretability: the network can explicitly decompose a time series into components such as trend and seasonality, with the remainder treated as noise.
Architecture & Components
The N-BEATS architecture is built upon a system of "blocks" and "stacks," utilizing backward and forward residual links to iteratively refine predictions. Each block specializes in modeling specific time series patterns, and these blocks are organized into stacks, where each stack can focus on a particular component (e.g., trend or seasonality).
- Blocks: The fundamental unit is the block, a multi-layer fully connected network with ReLU non-linearities. Each block produces two outputs:
- Backcast: A reconstruction of the block's input window, representing the patterns the block has already captured. The backcast is subtracted from the block's input, so that subsequent blocks focus on the remaining residual.
- Forecast: Predictions for future time series values, extrapolating the identified patterns.
- Stacks: Blocks are organized into stacks. N-BEATS offers two main configurations:
- Generic Architecture: Uses as little prior knowledge as possible, with no explicit time-series-specific components. It demonstrates that deep learning primitives like residual blocks are sufficient for a wide range of forecasting problems.
- Interpretable Architecture: Specifically designed to provide interpretable outputs. It typically consists of two main stacks:
- Trend Stack: Uses polynomial basis expansion to model long-term directional changes. The number of trend coefficients is a configurable hyperparameter.
- Seasonality Stack: Leverages Fourier series basis expansion to model periodic patterns. The number of seasonal coefficients is also configurable, allowing it to capture multiple seasonalities (e.g., hourly, weekly, monthly).
- Double Residual Stacking: This mechanism ensures that each block focuses on the variance left unexplained by previous blocks, letting the model incrementally capture finer details of the data, including trend, seasonality, and residual noise. The final forecast is the sum of the forecasts of all blocks. (Both the stacking loop and the trend/seasonality bases are sketched in the code below.)
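To make the block and stack mechanics concrete, here is a minimal PyTorch sketch of a generic block and the double residual stacking loop. The class names (NBeatsBlock, NBeatsStack), layer widths, and block counts are illustrative assumptions rather than the reference implementation.

import torch
import torch.nn as nn

class NBeatsBlock(nn.Module):
    """One fully connected block producing a backcast and a forecast."""
    def __init__(self, backcast_length, forecast_length, width=256, n_layers=4):
        super().__init__()
        layers, in_features = [], backcast_length
        for _ in range(n_layers):
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        self.fc_stack = nn.Sequential(*layers)
        # Generic architecture: the basis is learned, so two linear heads map
        # the hidden state directly to a backcast and a forecast.
        self.backcast_head = nn.Linear(width, backcast_length)
        self.forecast_head = nn.Linear(width, forecast_length)

    def forward(self, x):
        h = self.fc_stack(x)
        return self.backcast_head(h), self.forecast_head(h)

class NBeatsStack(nn.Module):
    """Double residual stacking over a sequence of blocks."""
    def __init__(self, backcast_length, forecast_length, n_blocks=3):
        super().__init__()
        self.blocks = nn.ModuleList(
            [NBeatsBlock(backcast_length, forecast_length) for _ in range(n_blocks)]
        )

    def forward(self, x):
        residual, forecast = x, 0.0
        for block in self.blocks:
            backcast, block_forecast = block(residual)
            residual = residual - backcast        # next block sees only the unexplained part
            forecast = forecast + block_forecast  # block forecasts are summed
        return residual, forecast

# A batch of 32 input windows: backcast of 30 steps, forecast of 10 steps
stack = NBeatsStack(backcast_length=30, forecast_length=10)
_, forecast = stack(torch.randn(32, 30))
print(forecast.shape)  # torch.Size([32, 10])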
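The interpretable stacks differ only in the basis used to turn a block's expansion coefficients into a forecast. The NumPy sketch below builds the polynomial and Fourier bases; the exact time-grid normalization and number of harmonics vary across implementations, so treat the constants as assumptions.

import numpy as np

H = 10                # forecast horizon
t = np.arange(H) / H  # normalized forecast time grid (conventions vary slightly)

# Trend stack: low-degree polynomial basis; a block only has to learn
# the p + 1 expansion coefficients theta.
p = 2
trend_basis = np.stack([t ** i for i in range(p + 1)])  # shape (p + 1, H)

# Seasonality stack: Fourier basis; one coefficient per harmonic.
n_harmonics = 4
seasonality_basis = np.concatenate([
    np.stack([np.cos(2 * np.pi * k * t) for k in range(n_harmonics)]),
    np.stack([np.sin(2 * np.pi * k * t) for k in range(1, n_harmonics)]),
])                                                      # shape (2 * n_harmonics - 1, H)

theta = np.array([5.0, 2.0, -1.0])     # coefficients a block might predict
trend_forecast = theta @ trend_basis   # forecast = theta^T @ basis, shape (H,)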
Because N-BEATS is purely feedforward, it avoids the limitations associated with recurrent memory, such as strictly sequential processing and difficult gradient propagation, which contributes to its robustness and generalization.
Conceptual diagram of N-BEATS architecture with blocks and stacks.
When to Use N-BEATS
N-BEATS is a highly competitive and versatile option, particularly suitable for:
- Achieving state-of-the-art accuracy on univariate time series forecasting problems.
- Scenarios where interpretability is desired, as it can explicitly decompose the series into trend, seasonality, and residuals.
- Long-term forecasting, as its forecasts do not seem to deteriorate significantly with extended horizons.
- Handling structurally diverse and non-linear data, including those with multiple orders of seasonality.
- As a robust alternative to recurrent or convolutional networks for sequence modeling.
Pros and Cons
Pros
- State-of-the-Art Performance: Achieved top results in major forecasting competitions (M3, M4, TOURISM), often outperforming statistical and hybrid methods.
- Interpretable by Design: Can explicitly decompose time series into trend, seasonality, and noise components, providing insights into predictions.
- Robustness & Generalization: Its feedforward design and residual stacking enhance robustness to noise, outliers, and complex dynamics, generalizing well across diverse datasets.
- No Recurrent Connections: Avoids vanishing/exploding gradient issues common in RNNs, and can be faster to train than LSTMs in some cases.
- Probabilistic Forecasting: Can provide probabilistic forecasts (e.g., quantile estimates) at low computational cost (see the sketch after this list).
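To illustrate the probabilistic-forecasting point, the Darts wrapper used in the example below can attach a quantile-regression likelihood to N-BEATS. This is a minimal sketch with illustrative hyperparameters, assuming a training TimeSeries named train_series prepared as in the example later in this section.

from darts.models import NBEATSModel
from darts.utils.likelihood_models import QuantileRegression

# Train N-BEATS with a quantile-regression likelihood (illustrative hyperparameters)
prob_model = NBEATSModel(
    input_chunk_length=30,
    output_chunk_length=10,
    likelihood=QuantileRegression(quantiles=[0.05, 0.5, 0.95]),
    n_epochs=50,
)
prob_model.fit(train_series)                      # train_series: a darts TimeSeries, as in the example below
pred = prob_model.predict(n=10, num_samples=200)  # sampled (stochastic) forecast
lower, upper = pred.quantile_timeseries(0.05), pred.quantile_timeseries(0.95)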
Cons
- Complex Architecture: While powerful, its deep stacking architecture can be complex to understand and implement from scratch.
- Computationally Intensive: Training and hyperparameter tuning can be time-consuming, especially for large values of the hyperparameters that determine tensor sizes (input/output chunk lengths, layer widths, number of blocks and stacks).
- Hyperparameter Sensitivity: Performance can be sensitive to hyperparameters like input/output layer sizes, number of blocks/stacks, and layer widths.
- Batch Size Sensitivity: Very large batch sizes can hurt generalization, while very small ones make gradient estimates noisy and might reduce accuracy.
- Output Chunk Length Constraints: The output chunk length cannot be arbitrarily increased due to memory and tensor size constraints.
Example Implementation
N-BEATS has implementations in both PyTorch and TensorFlow. Here are conceptual examples for both frameworks, highlighting key aspects of their usage. For full runnable examples, refer to the official repositories.
PyTorch Example (using Darts/PyTorch Lightning)
import pandas as pd
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
from darts.models import NBEATSModel
from darts.dataprocessing.transformers import Scaler  # wraps sklearn's MinMaxScaler by default
from darts.utils.timeseries_generation import gaussian_timeseries, linear_timeseries, sine_timeseries
# 1. Generate sample data (or load your own)
# Create a simple daily series with trend, seasonality, and noise
n_samples = 300
start = pd.Timestamp("2020-01-01")
trend_data = linear_timeseries(start_value=0, end_value=100, length=n_samples, start=start, freq="D")
seasonal_data = sine_timeseries(length=n_samples, value_frequency=1 / 30, value_amplitude=10, start=start, freq="D")  # ~monthly cycle
noise_data = gaussian_timeseries(length=n_samples, mean=0, std=2, start=start, freq="D")
series = trend_data + seasonal_data + noise_data
# 2. Preprocess data (scaling is often beneficial for neural networks)
scaler = Scaler()  # min-max scaling by default
series_scaled = scaler.fit_transform(series)
# 3. Split data into train and validation sets (80% / 20%)
train_series, val_series = series_scaled.split_after(0.8)
# 4. Define N-BEATS model hyperparameters
forecast_length = 10                   # number of steps to forecast
backcast_length = 3 * forecast_length  # number of past steps fed to the model
# 5. Instantiate the N-BEATS model
# generic_architecture=False selects the interpretable configuration,
# which always uses two stacks (trend + seasonality)
model = NBEATSModel(
    input_chunk_length=backcast_length,
    output_chunk_length=forecast_length,
    generic_architecture=False,        # interpretable architecture
    num_blocks=3,                      # blocks per stack
    num_layers=4,
    layer_widths=512,
    expansion_coefficient_dim=5,       # basis expansion size (only used by the generic architecture)
    trend_polynomial_degree=2,         # degree of the trend stack's polynomial basis
    dropout=0.1,
    n_epochs=50,                       # reduced epochs for a quick demo
    batch_size=32,
    optimizer_kwargs={"lr": 1e-3},
    random_state=42,
    pl_trainer_kwargs={
        "accelerator": "auto",         # use GPU if available, otherwise CPU
        "callbacks": [ModelCheckpoint(monitor="val_loss", mode="min", filename="nbeats_best_model")],
        "logger": pl_loggers.TensorBoardLogger("lightning_logs/", name="nbeats_experiment"),
    },
)
# 6. Train the model
# Darts handles windowing and batching internally from the TimeSeries objects
model.fit(train_series, val_series=val_series, verbose=False)  # verbose=False suppresses per-epoch output
# 7. Make a forecast for the steps following the training series
prediction = model.predict(n=forecast_length)
# 8. Inverse-transform the forecast back to the original scale
prediction_original_scale = scaler.inverse_transform(prediction)
# 9. Display results
print("N-BEATS PyTorch (Darts) model training complete.")
print("\nForecast (original scale):")
print(prediction_original_scale.pd_series())
# Plotting (uncomment to visualize; requires matplotlib)
# import matplotlib.pyplot as plt
# plt.figure(figsize=(14, 7))
# series.plot(label='Original Series')
# prediction_original_scale.plot(label='N-BEATS Forecast', linestyle='--')
# plt.title('N-BEATS Model Forecast (PyTorch/Darts)')
# plt.xlabel('Date')
# plt.ylabel('Value')
# plt.legend()
# plt.grid(True)
# plt.show()
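As a quick sanity check of the example above, the forecast can be compared with the held-out data on the original scale using Darts' metrics; MAPE is one reasonable choice (metrics are computed over the overlapping time span), reusing the variables defined above.

from darts.metrics import mape

# Compare the forecast with the actual series over their overlapping dates
print(f"MAPE over the forecast horizon: {mape(series, prediction_original_scale):.2f}%")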