Yongil Kim

Seoul National University

[ACL2023] ReAugKD: Retrieval-Augmented Knowledge Distillation For Pre-trained Language Models

18 Mar 2024 » Retrieval, LLM, PLM

[pdf][github]

Jianyi Zhang1, Aashiq Muhamed2, Aditya Anantharaman2, Guoyin Wang2, Changyou Chen3, Kai Zhong2, Qingjun Cui2, Yi Xu2, Belinda Zeng2, Trishul Chilimbi2, Yiran Chen1
1 Duke University, 2 Amazon 3 University at Buffalo, SUNY

image

Abstract

  • (Knowledge Distillation) Large-scale pre-trained LM 을 작은 모델로 distillation 하는 연구가 성행하고 있다. 기존의 KD 접근 방식은 teacher model 의 soft label 과 intermediate activation 을 trasnfer learning 하여 student model 을 학습시킨다.
  • ( ReAugKD ) 이 논문에서는 teacher 의 soft model 에 더불어 kowledge base 형태의 non-parametric memory 를 같이 활용할 경우 더 좋은 generalization 성능을 보이는 distillation 방법을 제안한다. Student 모델로 하여금 knowledge base 를 효과적으로 retrieve 하는 이 ReAugKD framework 은 teacher 와 student embedding space 에서 relational knowledge 를 align 하는 loss 로 학습한다.
  • (Experiment) 실험 결과, GLUE benchmark 에 대해 State-of-the-Art 성능을 보인다.

1. Introduction

▶ Knowledge Distillation (KD)
BERT, RoBERTa, Electra 등의 LM 이 좋은 성능을 보이지만, 이 것들은 M ~ B 단위의 param 을 가지고 있어 제한된 환경에서 가동이 힘들다. 이에 성능 좋은 위의 모델들은 teacher model 로 하고, param 수가 더 적은 student model 로 지식을 전달하는 knowledge distillation (KD) 연구가 활발하다. 기존의 KD 모델들은 typically student param 속의 지식과 teacher 의 output prediction 의 divergence 를 최소화 하는 방식으로 학습한다. 이러한 단순한 KD 방법은 student 모델의 작은 param 떄문에 어느정도 한계점이 있다. 특히 LLM 에서 많이 나타나는 task-specific knowledge 를 distill 하여 학습하기는 힘들다.

▶ Retrieval-Augmented Knowledge Distillation (ReAugKD)
저자들은 이 문제를 해결하기 위해 Retrieval-Augmented Knowledge Distillation (ReAugKD) 방법론을 제안한다. ReAugKD 방법은 implicit parametric memory 에 더하여 non-parametric external memory 를 가져와서, kNN retrieval 을 통해 retrieve 를 한다. Key intuition 은 teacher 의 task-specific knowledge 로부터 가져올 수 있는 external memory 를 studnet 모델이 활용할 수 있는 능력을 갖추게 하는 것이다.

▶ Experiment
실험 결과, GLUE benchmark 에서 State-of-the-Art 를 달성했으며, retrieval 을 하지 않은 방법보다 단 3% 의 latency overhead 만 존재함을 보인다. 또한, ReAugKD 방식을 통한 학습이 student model 의 generalization 성능 향상을 이끄는 것을 확인한다.

2. Methodology

ReAugKD 방법은 두 개의 main phase : Training phase 와 Inference phase 가 존재한다.

image

2.1. Training Phase

Training phase 는 두 개의 step 이 존재한다.

첫 번째 step 은 sepcific downstream task 에 finetuned 된 teacher model 에 linear projection head $L$ 을 붙인다. 이 projection 의 input dimension 의 teacher embedding dim 이고, output dimension 의 student embedding dim 이다. Teacher model 의 다른 param 은 freeze 하고, head $L$ 의 param 만 supervised contrastive loss 를 통해 학습시킨다.

두 번쨰 step 은 teacher embedding with head $L$ 과 teacher osft label 을 가지고 Knowledge Distillation 을 하는 것이다.

2.2. Loss Function

Notations

  • $N$: batch
  • $x_i$ : student embedding
  • $y_i$ : student prediction
  • $\hat{y_i}$ : teacher soft label
  • $z_i$ : teacher’s prediction head

$z_i$ 와 anchor $z_j$ 의 similarity distribution $q_{i,j}$ 는 아래와 같다.

image

$q_{i,j}$ 는 batch 속의 다른 embedding 들과의 relational knowledge 의 cosine distance 를 담고 있다고 해석할 수 있다.

아래의 $\hat{q_{i,j}}$ 는 teacher embedding 과 student embedding 사이의 similarity matrix 이다.

image

Loss function 은 두 distribution $q_{i,j}$ 과 $\hat{q_{i,j}}$ 의 divergence 를 줄이는 것으로 학습된다. 추가적으로, corss-entropy loss 로 distillation 학습을 진행하여 최종적인 Loss 는 아래와 같다.

image

2.3. Inference Phase

Teacher embedding 들과 prediction 들을 comprise 하여 Knowledge base (KB) 를 구성한다. 이후, HNSW 알고리즘을 활용하여 K-nearest neighbor (KNN) 방법을 활용한다. 즉 student 의 embedding 과 prediction ($x_i$, $y_i$) 를 토대로, 가장 비슷한 teacher embedding 과 prediction ($z_i$, $\hat{y_i}$) 을 KB 에서 KNN classifier 를 통해 retrieval 해온다. K 개의 결과를 retrieval 해온 후 Average 하여 아래의 weigthed average of soft label 을 얻고,

image

hyperparameter $\beta$ 를 통해 두 prediction 을 섞어준다.

image

3. Experimental Results

Experiment setting

  • Backbone model : BERT-base -> 6-layer BERT (768 dim)
  • Benchmark : GLUE
  • Baseline method : vanilla KD, TAKD, RCO, RKD, DML, PKD, ProKT, SFTN, MetaDistil

Experimental Results on GLUE
image

  • 기존 SOTA 인 MetaDistil 을 0.34% 앞서는 SOTA 를 달성한다.
  • Metadistil 이 MRPC 에서는 더 좋지만, ReAugKD는 meta-learning 을 필요로 하지 않기 때문에 더 효율적이다.
  • Inference 단계의 retrieval 을 붙였을 때 0.37% 정도 성능향상이 있고, retrieval 을 하지 않아도 SOTA 급의 성능이다.

Number of Neighbors Retrieved (k)
image

  • Original inference time 에 비하여 3% 정도의 additional time overhead 가 있다. (CPU로만 했음에도)

Conclusion

In this paper, we present ReAugKD, a knowledge distillation framework with a retrieval mechanism that shows state-of-the-art performance on the GLUE benchmark. In the future, we plan to expand the knowledge base with more information from the teacher and extend it to additional tasks.

Limitations

Our method relies on having access to teacher embeddings and prediction which may not always be possible in a black-box distillation setting. Retrieval augmentation also requires maintaining a knowledge base that is memory intensive.
The cost of the retrieval process is dependent on the size of the training corpus, which can be a limitation when dealing with very large training datasets.
Conducting dataset distillation (Wang et al., 2018b) on the training corpus to further reduce memory cost and retrieval time is an important future step for our framework.