N-BEATS Model

Neural Basis Expansion Analysis for Interpretable Time Series Forecasting

Overview

N-BEATS (Neural Basis Expansion Analysis for Time Series) is a deep neural network architecture designed for univariate time series forecasting. Unlike recurrent (RNN) or convolutional (CNN) models, N-BEATS uses a fully connected feedforward architecture. In its interpretable configuration, it decomposes the forecast into distinct trend and seasonality components while retaining strong accuracy on standard benchmarks (M3, M4, TOURISM).

Architecture & Components

The N-BEATS architecture is built upon a system of "blocks" and "stacks," utilizing backward and forward residual links to iteratively refine predictions. Each block specializes in modeling specific time series patterns, and these blocks are organized into stacks, where each stack can focus on a particular component (e.g., trend or seasonality).

  • Blocks: The fundamental building block is a multi-layer fully connected network with ReLU non-linearities. Each block predicts two outputs:
    • Backcast: A reconstruction of the input time series, modeling the patterns already captured. This backcast is subtracted from the original input, allowing subsequent blocks to focus on the residual errors. This iterative refinement helps the model incrementally capture finer details.
    • Forecast: Predictions for future time series values, extrapolating the identified patterns.
  • Stacks: Blocks are organized into stacks. N-BEATS offers two main configurations:
    • Generic Architecture: Uses as little prior knowledge as possible, with no explicit time-series-specific components. It demonstrates that deep learning primitives like residual blocks are sufficient for a wide range of forecasting problems.
    • Interpretable Architecture: Specifically designed to provide interpretable outputs. It typically consists of two main stacks:
      • Trend Stack: Uses polynomial basis expansion to model long-term directional changes. The number of trend coefficients is a configurable hyperparameter.
      • Seasonality Stack: Leverages Fourier series basis expansion to model periodic patterns. The number of seasonal coefficients is also configurable, allowing it to capture multiple seasonalities (e.g., hourly, weekly, monthly).
  • Double Residual Stacking: This mechanism ensures that each block focuses on the unexplained variance from previous blocks, allowing the model to incrementally capture finer details of the data patterns, including trend, seasonality, and residual noise. The final forecast is the aggregation (sum) of the forecasts from all blocks.
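The double residual mechanism described above can be sketched in a few lines of NumPy. This is a minimal illustration, not a trained model: `run_blocks`, `make_mean_block`, and `make_linear_block` are hypothetical names, and the two toy blocks stand in for the fully connected networks a real N-BEATS block would use.

```python
import numpy as np

def run_blocks(x, blocks):
    """Apply doubly residual stacking to one lookback window."""
    residual = np.asarray(x, dtype=float)
    forecast = 0.0
    for block in blocks:
        backcast, block_forecast = block(residual)
        residual = residual - backcast        # backward link: remove what this block explained
        forecast = forecast + block_forecast  # forward link: sum the partial forecasts
    return residual, forecast

def make_mean_block(horizon):
    """Toy block (assumption): models the average level of its input."""
    def block(x):
        level = x.mean()
        return np.full(len(x), level), np.full(horizon, level)
    return block

def make_linear_block(horizon):
    """Toy block (assumption): fits a straight line and extrapolates it."""
    def block(x):
        t = np.arange(len(x))
        coeffs = np.polyfit(t, x, deg=1)
        t_future = np.arange(len(x), len(x) + horizon)
        return np.polyval(coeffs, t), np.polyval(coeffs, t_future)
    return block

x = np.array([1.0, 3.0, 5.0, 7.0])
residual, forecast = run_blocks(x, [make_mean_block(2), make_linear_block(2)])
print("forecast:", forecast)
print("unexplained residual:", residual)
```

Here the first block removes the level and the second fits the line that remains, so the summed forecast continues the series (close to 9 and 11) and the final residual is close to zero.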

The feedforward design of N-BEATS avoids the limitations often associated with recurrent architectures, which contributes to its robustness and generalization capabilities.
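To make the trend and seasonality stacks of the interpretable architecture more concrete, the sketch below builds a polynomial basis and a Fourier basis over a normalized time grid and turns a coefficient vector into a forecast. The function names, the grid from 0 to just under 1, and the column ordering are assumptions made for illustration; actual implementations may differ in these details.

```python
import numpy as np

def trend_basis(horizon, degree):
    """Polynomial basis with columns 1, t, t^2, ..., t^degree."""
    t = np.arange(horizon) / horizon
    return np.stack([t ** p for p in range(degree + 1)], axis=1)

def seasonality_basis(horizon, num_harmonics):
    """Fourier basis: a constant, then cosine columns, then sine columns."""
    t = np.arange(horizon) / horizon
    columns = [np.ones(horizon)]
    columns += [np.cos(2 * np.pi * k * t) for k in range(1, num_harmonics + 1)]
    columns += [np.sin(2 * np.pi * k * t) for k in range(1, num_harmonics + 1)]
    return np.stack(columns, axis=1)

# In a trained model the coefficients come from the block's network;
# here they are hand-picked so the result is easy to inspect.
trend_forecast = trend_basis(4, degree=2) @ np.array([1.0, 2.0, 0.0])    # 1 + 2t
seasonal_forecast = seasonality_basis(8, num_harmonics=1) @ np.array([0.0, 1.0, 0.0])  # cos(2*pi*t)
print(trend_forecast)
print(seasonal_forecast)
```

Because each block only outputs a small coefficient vector, the shape of its contribution is constrained to be smooth (trend) or periodic (seasonality), which is what makes the per-stack outputs readable.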

Figure: Conceptual diagram of the N-BEATS architecture with blocks and stacks.

When to Use N-BEATS

N-BEATS is a highly competitive and versatile option, particularly suitable for:

  • Achieving state-of-the-art accuracy on univariate time series forecasting problems.
  • Scenarios where interpretability is desired, as it can explicitly decompose the series into trend, seasonality, and residuals.
  • Long-term forecasting, as its forecasts do not seem to deteriorate significantly with extended horizons.
  • Handling structurally diverse and non-linear data, including those with multiple orders of seasonality.
  • As a robust alternative to recurrent or convolutional networks for sequence modeling.

Pros and Cons

Pros

  • State-of-the-Art Performance: Reported state-of-the-art accuracy on the M3, M4, and TOURISM benchmark datasets, outperforming established statistical and hybrid methods in the original paper's evaluation.
  • Interpretable by Design: Can explicitly decompose time series into trend, seasonality, and noise components, providing insights into predictions.
  • Robustness & Generalization: Its feedforward design and residual stacking enhance robustness to noise, outliers, and complex dynamics, generalizing well across diverse datasets.
  • No Recurrent Connections: Avoids vanishing/exploding gradient issues common in RNNs, and can be faster to train than LSTMs in some cases.
  • Probabilistic Forecasting: Can provide probabilistic forecasts (e.g., quantile estimates) with low computational cost.
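The page does not say how quantile estimates are produced. One common approach, shown here only as an assumption, is to train the network with a quantile (pinball) loss, which penalizes under- and over-prediction asymmetrically. `pinball_loss` is a hypothetical helper, not part of any N-BEATS library.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantile):
    """Mean pinball loss for a single quantile level in (0, 1)."""
    error = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    # Under-prediction (error > 0) costs quantile * error;
    # over-prediction (error < 0) costs (1 - quantile) * |error|.
    return float(np.mean(np.maximum(quantile * error, (quantile - 1) * error)))

# For the 0.9 quantile, predicting too low is penalized more than predicting too high.
print(pinball_loss([10.0], [8.0], quantile=0.9))
print(pinball_loss([10.0], [12.0], quantile=0.9))
```

Minimizing this loss for several quantile levels yields a set of forecasts that together describe a prediction interval.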

Cons

  • Complex Architecture: While powerful, its deep stacking architecture can be complex to understand and implement from scratch.
  • Computationally Intensive: Training and hyperparameter tuning can be time-consuming, especially with large layer widths or long input and output chunk lengths, which determine the tensor sizes.
  • Hyperparameter Sensitivity: Performance can be sensitive to hyperparameters like input/output layer sizes, number of blocks/stacks, and layer widths.
  • Batch Size Sensitivity: Very large batch sizes can hurt generalization, while very small ones produce noisy gradient estimates and may reduce accuracy.
  • Output Chunk Length Constraints: The output chunk length cannot be arbitrarily increased due to memory and tensor size constraints.
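A rough parameter count shows why layer width and chunk lengths drive the cost. The sketch below assumes a generic block made of `num_layers` fully connected layers, two projections to expansion coefficients, and two learned linear basis layers, with a bias on every layer. Those structural details and the function name are assumptions; real implementations may count slightly differently.

```python
def generic_block_params(input_len, output_len, width, num_layers, theta_dim):
    """Approximate number of trainable parameters in one generic block."""
    # Fully connected stack: input -> width, then (num_layers - 1) width -> width layers.
    fc = (input_len * width + width) + (num_layers - 1) * (width * width + width)
    # Two projections to expansion coefficients (backcast and forecast).
    theta = 2 * (width * theta_dim + theta_dim)
    # Learned linear bases mapping coefficients to backcast and forecast.
    basis = (theta_dim * input_len + input_len) + (theta_dim * output_len + output_len)
    return fc + theta + basis

small = generic_block_params(input_len=30, output_len=10, width=512, num_layers=4, theta_dim=5)
wide = generic_block_params(input_len=30, output_len=10, width=1024, num_layers=4, theta_dim=5)
print(small, wide, round(wide / small, 2))
```

Under these assumptions the hidden layers dominate, so doubling the width roughly quadruples the block size, and the total scales further with the number of blocks and stacks.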

Example Implementation

N-BEATS has implementations in both PyTorch and TensorFlow. Here are conceptual examples for both frameworks, highlighting key aspects of their usage. For full runnable examples, refer to the official repositories.

PyTorch Example (using Darts/PyTorch Lightning)


import numpy as np
import pandas as pd
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
from darts import TimeSeries
from darts.models import NBEATSModel
from darts.dataprocessing.transformers import Scaler

# 1. Generate sample data (or load your own)
# Create a simple daily series with trend, seasonality, and noise
n_samples = 300
rng = np.random.default_rng(42) # Seeded for reproducibility
time_index = pd.date_range(start='2020-01-01', periods=n_samples, freq='D')
trend_data = np.linspace(0, 100, n_samples)
seasonal_data = 10 * np.sin(np.arange(n_samples) * 2 * np.pi / 30) # Monthly-like seasonality
noise_data = rng.normal(loc=0, scale=2, size=n_samples)
series_data = trend_data + seasonal_data + noise_data
series = TimeSeries.from_series(pd.Series(series_data, index=time_index))

# 2. Preprocess data (scaling is often beneficial for neural networks)
# Darts' Scaler wraps scikit-learn's MinMaxScaler by default
scaler = Scaler()
series_scaled = scaler.fit_transform(series)

# 3. Split data into train and validation sets
train_size = int(0.8 * len(series_scaled))
train_series, val_series = series_scaled.split_after(train_size)

# 4. Define N-BEATS model hyperparameters
forecast_length = 10 # Number of steps to forecast
backcast_length = 3 * forecast_length # Number of past steps to consider

# 5. Instantiate the N-BEATS model
# Using interpretable architecture with Trend and Seasonality stacks
model = NBEATSModel(
    input_chunk_length=backcast_length,
    output_chunk_length=forecast_length,
    generic_architecture=False, # Set to False for interpretable architecture
    num_stacks=2, # Only used by the generic architecture; the interpretable one always has a trend and a seasonality stack
    num_blocks=3, # Number of blocks in each stack
    num_layers=4,
    layer_widths=512,
    expansion_coefficient_dim=5, # Only used by the generic architecture
    trend_polynomial_degree=2, # For trend stack
    dropout=0.1,
    n_epochs=50, # Reduced epochs for quick demo
    batch_size=32,
    optimizer_kwargs={"lr": 1e-3},
    random_state=42,
    pl_trainer_kwargs={
        "accelerator": "auto", # Use GPU if available, otherwise CPU
        "callbacks": [ModelCheckpoint(monitor="val_loss", mode="min", filename="nbeats_best_model")],
        "logger": pl_loggers.TensorBoardLogger("lightning_logs/", name="nbeats_experiment")
    }
)

# 6. Train the model
# Darts handles data loading and batching internally with the TimeSeries objects
model.fit(train_series, val_series=val_series, verbose=False) # verbose=False to suppress detailed output

# 7. Make a forecast
prediction = model.predict(n=forecast_length)

# 8. Inverse transform the forecast to original scale
prediction_original_scale = scaler.inverse_transform(prediction)

# 9. Display results (conceptual)
print("N-BEATS PyTorch (Darts) model training complete.")
print("\nForecast (original scale):")
print(prediction_original_scale.pd_series())

# Plotting (conceptual)
# plt.figure(figsize=(14, 7))
# series.plot(label='Original Series')
# prediction_original_scale.plot(label='N-BEATS Forecast', linestyle='--')
# plt.title('N-BEATS Model Forecast (PyTorch/Darts)')
# plt.xlabel('Date')
# plt.ylabel('Value')
# plt.legend()
# plt.grid(True)
# plt.show()
                        

TensorFlow Example (using flaviagiammarino/nbeats-tensorflow)


import numpy as np
import pandas as pd
import tensorflow as tf
from nbeats_tensorflow.model import NBeats
# from nbeats_tensorflow.plots import plot # Requires plotly and kaleido for image export

# 1. Generate sample data
N = 1000
t = np.linspace(0, 1, N)
trend = 30 + 20 * t + 10 * (t ** 2)
seasonality = 5 * np.cos(2 * np.pi * (10 * t - 0.5)) + 3 * np.sin(2 * np.pi * (20 * t))
noise = np.random.normal(0, 1, N)
y_data = trend + seasonality + noise

# 2. Define model parameters
forecast_period = 200
lookback_period = 400

# 3. Fit the N-BEATS model
model = NBeats(
    y=y_data,
    forecast_period=forecast_period,
    lookback_period=lookback_period,
    units=30, # Number of units in FC layers
    stacks=['trend', 'seasonality'], # Use interpretable stacks
    num_trend_coefficients=3,
    num_seasonal_coefficients=5,
    num_blocks_per_stack=2,
    share_weights=True,
    share_coefficients=False,
)

model.fit(
    loss='mse',
    epochs=50, # Reduced epochs for quick demo
    batch_size=32,
    learning_rate=0.003,
    backcast_loss_weight=0.5, # Weight for the backcast loss
    verbose=0 # Set to 1 or True for verbose output
)

# 4. Generate forecasts and backcasts
df_results = model.forecast(y=y_data, return_backcast=True)

# 5. Display results (conceptual)
print("N-BEATS TensorFlow model training complete.")
print("\nForecast (first 5 rows):")
print(df_results[['forecast']].head())
print("\nBackcast (first 5 rows):")
print(df_results[['backcast']].head())

# Plotting (conceptual, requires matplotlib or plotly)
# import matplotlib.pyplot as plt
# plt.figure(figsize=(14, 7))
# plt.plot(df_results.index[:len(y_data)], y_data, label='Original Data')
# plt.plot(df_results.index[len(y_data):], df_results['forecast'].iloc[len(y_data):], label='N-BEATS Forecast', linestyle='--')
# plt.plot(df_results.index[:len(y_data)], df_results['backcast'].iloc[:len(y_data)], label='N-BEATS Backcast', linestyle=':')
# plt.title('N-BEATS Model Forecast and Backcast (TensorFlow)')
# plt.xlabel('Time Step')
# plt.ylabel('Value')
# plt.legend()
# plt.grid(True)
# plt.show()
                        
