논문리뷰

[논문리뷰] Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting (1/3)

매쓰도윤 2023. 8. 1. 16:17

SSM 논문에 이어 time series forecasting과 관련된 또다른 논문을 읽어보았다. 이번 논문은 23쪽에 육박하기 때문에 읽는 데 시간이 좀 걸릴 것 같지만... 차근차근 읽어보도록 하겠다.

 

1. Introduction

Multi-horizon forecasting은 multiple future time step에서의 관심 있는 변수를 예측하는 것으로, time series machine learning에서 중요한 문제이다. one-step-ahead prediction과는 달리, multi-horizon forecast는 사용자에게 전체에 걸친 추정치를 제공하여 미래의 여러 단계에서 행동을 최적화할 수 있도록 도와준다. multi-horizon forecasting은 소매, 의료 및 경제 분야에서 많이 응용되고 있기에 기존 방법의 성능을 개선하는 문제가 매우 중요하다.

 

practical multi-horizon forecasting application은 Fig.1에 나와 있는 것처럼 다양한 데이터 소스에 적용될 수 있다. 이러한 데이터 소스에는 미래에 대해 알려진 정보 (ex: 다가오는 휴일 날짜), 기타 외부 time series 데이터 (ex: 과거 고객의 데이터), 그리고 static metadata (ex: 상점의 위치)가 포함된다. 그러나 이러한 데이터들이 어떻게 상호작용하는지에 대한 사전 지식이 없는 경우도 있다. 이러한 데이터 소스의 다양성과 상호작용에 대한 정보의 부족으로 인해 multi-horizon time series forecasting은 특히 어려운 문제이다.

 

Figure 1: Illustration of multi-horizon forecasting with static covariates.



Deep neural network (DNN)는 multi-horizontal forecasting에서 점점 더 많이 사용되고 있으며, 기존의 time series model에 비해 강력한 성능 개선을 보여주고 있다. 많은 아키텍처들이 recurrent neural network (RNN) 아키텍처의 변형에 초점을 맞추었지만, 최근의 개선 사례들은 transformer-based model을 포함한 attention-based method를 사용하기도 한다. 그러나 이러한 접근법은 종종 multi-horizon forecasting에서 input의 다양성을 고려하지 않으며 복잡한 nonlinear interaction에 의해 예측이 제어되는 'black-box' model이라는 문제점이 있다. 그렇기 때문에 multi-horizon forecasting의 데이터 다양성에 대응하여 높은 성능을 내고 해석 가능하게 만들기 위한 새로운 방법이 필요하다.

 

본 논문에서는 Temporal Fusion Transformer (TFT)라는 multi-horizontal forecasting을 위한 attention-based DNN 아키텍처를 제안한다. TFT는 static covariate encoder, gating mechanism, sequence-to-sequence layer, temporal self-attention decoder를 포함하는 아키텍처로, 높은 성능을 달성하면서 새로운 형태의 input의 해석을 가능하게 한다. 특히 TFT는 사용자로 하여금 (i) prediction problem의 globally-important variables, (ii) persistent temporal patterns, (iii) significant events를 구할 수 있도록 도와준다.

 

2. Related Work

DNNs for Multi-horizon Forecasting: 최근의 deep learning method는 iterated approach와 direct method로 분류된다. iterated approach는 one-step-ahead prediction을 수행하고, 이를 통해 얻어진 prediction을 future input으로서 반복적으로 전달하여 multi-step prediction을 얻는 방법이다. Deep AR 및 Deep State-Space Model (DSSM)과 같이 Long Short-term Memory (LSTM)을 기반으로 한 방법이 이 범주에서 연구되었다. 그러나 이러한 방법들은 예측 시점에서 목표를 제외한 모든 변수가 알려진 것으로 가정하기 때문에, 실제 시나리오에서의 활용이 제한된다.

 

반면, direct method는 sequence-to-sequence model을 사용하여 미리 정의된 horizon에 대한 예측을 생성하는 방법으로, Multi-horizon Quantile Recurrent Forecaster (MQRNN)와 같은 방법이 사용된다. LSTM-based iterative method보다 성능이 뛰어나지만, 여전히 interpretability의 문제가 남아있다. 이에 반해 TFT는 attention pattern을 해석함으로써 temporal dynamics에 대한 설명을 제공하며, 다양한 데이터셋에서 좋은 성능을 유지할 수 있다.

 

 

Time Series Interpretability with Attention: attention mechanism은 translation, image classification, 또는 tabular learning에서 attention weight을 적용하여 중요한 부분을 식별하는 데 사용된다. 최근에는 time series에 attention mechanism이 적용되고 있으며, LSTM 기반과 transformer 기반 아키텍처를 사용한다. 그러나 이러한 방법들은 static covariate의 중요성을 간과하여 적용되었다. TFT는 이를 해결하기 위해 각 time step에서 static feature에 대해 별도의 encoder-decoder attention을 사용한다.

 

Instance-wise Variable Importance with DNNs: 샘플별 변수 중요도를 얻는 방법으로는 post-hoc explanation method와 inherently interpretable model이 있다. LIME, SHAP, RL-LIM과 같은 방법은 post-hoc explanation method로, pre-trained black-box model에 적용된다. 이러한 방법은 입력 데이터의 시간 순서를 고려하지 않아 복잡한 time series에 대해 효과적이지 않을 수 있다. 반면 inherently-interpretable modeling은 feature selection을 아키텍처에 포함한다. time series forecasting에 대해 이러한 접근 방법은 time-dependent variable의 기여도를 측정한다. 예를 들어, interpretable multi-variable LSTM은 hidden state를 분할하여 변수의 기여도를 결정한다. 기존 방법들이 샘플별 해석에 초점을 맞추는 반면, TFT는 전체적인 time series의 관계를 분석하고 전체 데이터셋에서 지속적인 패턴과 규칙을 찾아 모델을 해석할 수 있다.

 

3. Multi-horizon Forecasting

주어진 time series 데이터셋에 $I$개의 개체가 있다고 하자. 이러한 개체들은 소매업에서의 다른 상점이나 의료 분야에서의 환자들과 같은 것들을 의미한다. 각 개체 $i$는 static covariate $\mathbf{s}_i \in \mathbb{R}^{m_s}$와 time step $t \in [0, T_i]$에 대해 input $\chi_{i, t} \in \mathbb{R}^{m_\chi}$ scalar targets $y_{i, t} \in \mathbb{R}$의 값을 갖는다. Time-dependent input feature $\chi_{i, t}=[\mathbf{z}_{i, t}^T, \mathbf{x}_{i, t}^T]^T$로 나눠진다. 이때 $\mathbf{z}_{i, t} \in \mathbb{R}^{m_z}$는 observed input으로 각 step에서 측정되어 그 전에는 모르는 값이고, $\mathbf{x}_{i, t} \in \mathbb{R}^{m_x}$는 known input으로 predetermined되는 값이다.

 

prediction interval은 target이 가질 수 있는 best and worst-case value를 알려주기 때문에 이를 제공하는 것은 decision을 최적화하고 risk를 관리하는 데 도움이 될 수 있다. 따라서 우리는 multi-horizon forecasting setting에 quantile regression을 적용한다. quantile forecast의 식은 다음과 같다.

 

$\hat{y}_i(q, t, \tau)$는 time $t$로부터 $\tau$ step 이후의 값을 예측한 범위에서 $q^{th}$ sample quantile의 값을 의미한다. direct method에서는 $\tau_{max}$가 정해져 있고 $\tau \in \{1, \cdots, \tau_{max}\}$이다. $k$는 look-back window로 과거의 정보를 얼마나 확인할지 결정해주는 변수이다.