Facebook Pixel

Knowledge Distillation: Kỹ thuật truyền Tri Thức giữa các Mô hình AI

11 Mar, 2025

Knowledge Distillation là một kỹ thuật machine learning giúp chuyển giao kiến thức từ một mô hình phức tạp sang một mô hình đơn giản hơn

Knowledge Distillation: Kỹ thuật truyền Tri Thức giữa các Mô hình AI

Mục Lục

Knowledge Distillation là một công cụ giúp các developer triển khai các mô hình AI hiệu quả cao trên những thiết bị có tài nguyên hạn chế. Bằng cách "chưng cất" kiến thức từ mô hình lớn sang mô hình nhỏ, chúng ta có thể đạt được sự cân bằng tốt giữa hiệu suất và hiệu quả tài nguyên.

1. Knowledge Distillation là gì?

Knowledge Distillation (KD), hay còn gọi là "chưng cất kiến thức", là một kỹ thuật machine learning giúp chuyển giao kiến thức từ một mô hình phức tạp (gọi là "teacher") sang một mô hình đơn giản hơn (gọi là "student"). Ý tưởng chính là tạo ra những mô hình nhỏ gọn có thể chạy nhanh hơn trên các thiết bị có tài nguyên hạn chế mà vẫn giữ được phần lớn hiệu suất của mô hình lớn.

Điểm đặc biệt của KD nằm ở cách tri thức được truyền đạt.

  • Giáo viên (mô hình lớn) không chỉ đưa ra câu trả lời dạng nhãn cứng (hard label), ví dụ: "Đây là một con mèo."
  • Thay vào đó, giáo viên còn cung cấp thông tin chi tiết, chẳng hạn: "Đây 90% là con mèo, 8% có thể là chó, và 2% giống cáo" (soft label). Những con số này mang thông tin phong phú về mức độ tự tin (confident level) và mối quan hệ giữa các loại dữ liệu mà mô hình lớn đã học được.

Knowledge Distillation đã phát triển đáng kể trong những năm gần đây với nhiều kỹ thuật tiên tiến:

  1. Dual-Space Knowledge Distillation (DSKD): Phát triển năm 2024, DSKD giúp căn chỉnh đầu ra teacher và student trên cả không gian từ vựng và biểu diễn token, giúp vượt qua khoảng cách chuyển giao kiến thức.
  2. Feature-based Distillation: Không chỉ sử dụng đầu ra cuối cùng mà còn trích xuất kiến thức từ các tầng trung gian của mô hình teacher.
  3. Self-Distillation: Một mô hình có thể "tự chưng cất" bằng cách sử dụng các phiên bản trước của chính nó làm teacher.
  4. Distillation cho Large Language Models (LLMs): Áp dụng KD để làm nhỏ các mô hình ngôn ngữ lớn như GPT và BERT, giúp chúng chạy hiệu quả hơn trên các thiết bị di động

2. Vì sao Knowledge Distillation ra đời?

Các mô hình AI tiên tiến như GPT-4, BERT hay Vision Transformers ngày nay đạt hiệu suất rất cao, nhưng đi kèm với đó là kích thước khổng lồ với hàng tỷ tham số. Điều này tạo ra nhiều vấn đề:

  1. Khó triển khai: Mô hình lớn không thể chạy trên các thiết bị nhỏ gọn như điện thoại, IoT, hoặc máy tính thông thường.
  2. Chi phí tính toán cao: Mô hình lớn yêu cầu phần cứng đắt đỏ (GPU, TPU), làm tăng chi phí vận hành.
  3. Độ trễ cao: Thời gian xử lý của các mô hình lớn quá chậm, không đáp ứng được các ứng dụng thời gian thực.
  4. Rủi ro bảo mật và quyền riêng tư: Mô hình lớn thường phải chạy trên cloud, yêu cầu gửi dữ liệu nhạy cảm của người dùng lên máy chủ.

Knowledge Distillation ra đời như một giải pháp thông minh cho những vấn đề này. Geoffrey Hinton và đồng nghiệp đã giới thiệu khái niệm này vào năm 2015, với ý tưởng rằng: có thể "chưng cất" tinh hoa trí tuệ từ mô hình phức tạp vào một mô hình nhỏ gọn hơn.

3. Knowledge Distillation hoạt động như thế nào

Ý tưởng cốt lõi của Knowledge Distillation xuất phát từ Hinton và các cộng sự năm 2015. Cách thức hoạt động đơn giản như sau:

  1. Train mô hình teacher: Đầu tiên, bạn train một mô hình lớn, phức tạp (teacher) để đạt hiệu suất cao nhất có thể.
  2. Trích xuất "soft targets": Mô hình teacher tạo ra các xác suất đầu ra (soft probabilities). Ví dụ, khi nhận dạng một con chó, mô hình không chỉ cho biết "đây là con chó" mà còn cung cấp xác suất phân bố cho tất cả các lớp (ví dụ: chó: 0.9, mèo: 0.08, ngựa: 0.02).
  3. Train mô hình student: Mô hình student nhỏ hơn được huấn luyện để bắt chước không chỉ nhãn thật mà còn cả phân bố xác suất của mô hình teacher

4. Ví dụ dễ hiểu về Knowledge Distillation

Giả sử chúng ta có một mô hình lớn (ResNet50) đạt độ chính xác 99% khi nhận dạng chữ số, và chúng ta muốn tạo một mô hình nhỏ (3 lớp MLP) để chạy trên điện thoại.

Cách tiếp cận truyền thống:

  • Lấy dữ liệu và nhãn: (hình ảnh số 7, nhãn "7")
  • Huấn luyện mô hình nhỏ trực tiếp: "Đây là số 7"
  • Kết quả: Mô hình nhỏ đạt độ chính xác 96%

Cách tiếp cận Knowledge Distillation:

  1. Đưa ảnh số 7 vào mô hình lớn
  2. Mô hình lớn đưa ra logits: [0.1, 0.05, 0.02, 0.08, 0.06, 0.03, 0.05, 0.5, 0.06, 0.05]
  3. Áp dụng softmax với T=3 để tạo soft target: [0.06, 0.05, 0.04, 0.05, 0.05, 0.04, 0.05, 0.52, 0.05, 0.05]
  4. Dạy mô hình nhỏ: "Đây là số 7, VÀ nó có một chút giống số 0, số 1, số 3, ..."
  5. Kết quả: Mô hình nhỏ đạt độ chính xác 98%

Mô hình nhỏ học được thông tin tinh tế hơn, chẳng hạn:

  • Số 7 có thể nhầm với số 1 do cùng có nét thẳng
  • Số 7 khác hoàn toàn với số 8 (xác suất thấp)

5. Demo đơn giản về Knowledge Distillation

Dưới đây là một ví dụ đơn giản về cách thực hiện Knowledge Distillation bằng PyTorch:

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Định nghĩa mô hình Teacher (lớn hơn)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 128 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Định nghĩa mô hình Student (nhỏ hơn)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.fc1 = nn.Linear(16 * 16 * 16, 10)
    
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = x.view(-1, 16 * 16 * 16)
        x = self.fc1(x)
        return x

# Hàm distillation loss
def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.5):
    # Softmax với temperature
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    soft_prob = F.log_softmax(student_logits / temperature, dim=1)
    
    # KL Divergence loss
    distillation = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
    
    # Cross-entropy với nhãn thật
    ce_loss = F.cross_entropy(student_logits, labels)
    
    # Kết hợp hai loss
    return alpha * ce_loss + (1 - alpha) * distillation

# Train mô hình student
def train_student(teacher_model, student_model, train_loader, optimizer, device, temperature=3.0, alpha=0.5):
    teacher_model.eval()  # Set teacher to evaluation mode
    student_model.train() # Set student to training mode
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # Tắt gradient calculation cho teacher
        with torch.no_grad():
            teacher_output = teacher_model(data)
        
        # Forward pass cho student
        student_output = student_model(data)
        
        # Tính toán loss
        loss = distillation_loss(student_output, teacher_output, target, temperature, alpha)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

6. Kết luận

Nếu bạn đang phát triển ứng dụng AI và đang gặp vấn đề về kích thước mô hình hoặc tốc độ suy luận, Knowledge Distillation có thể là giải pháp bạn đang tìm kiếm. Hãy bắt đầu với những ví dụ đơn giản, hiểu nguyên lý cơ bản, rồi mở rộng đến những kỹ thuật tiên tiến hơn.

Bài viết liên quan

Đăng ký nhận thông báo

Đừng bỏ lỡ những bài viết thú vị từ 200Lab