Bài blog này sẽ hướng dẫn cách xây dựng mô hình ngôn ngữ sparse mixture of experts từ đầu. Nó được lấy cảm hứng và phần lớn dựa trên dự án ‘makemore’ của Andrej Karpathy và sử dụng một số thành phần có thể tái sử dụng từ cách triển khai đó. Giống như makemore, makeMoE cũng là một mô hình ngôn ngữ tự hồi quy ở cấp độ ký tự, nhưng sử dụng kiến trúc sparse mixture of experts đã nói ở trên. Phần còn lại của blog tập trung vào các yếu tố chính của kiến trúc này và cách chúng được triển khai. Mục tiêu của tôi là bạn có được hiểu biết trực quan về cách mọi thứ hoạt động sau khi đọc blog này và xem qua code trong repo.

  • Repo Github: https://github.com/AviSoori1x/makeMoE/tree/main
  • Bài viết gốc: https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch

Với sự ra mắt của Mixtral và những thông tin về việc GPT-4 có thể là một mô hình ngôn ngữ lớn mixture of experts, có một sự quan tâm đáng kể đến kiến trúc mô hình này. Tuy nhiên, trong các mô hình ngôn ngữ sparse mixture of experts, phần lớn các thành phần được chia sẻ với các transformer truyền thống. Bất kể sự đơn giản dường như, bằng chứng thực nghiệm cho thấy rằng sự ổn định trong quá trình huấn luyện là một trong những vấn đề chính với các mô hình này. Các triển khai quy mô nhỏ có thể tùy chỉnh như thế này có thể giúp thử nghiệm nhanh chóng các phương pháp mới.

Trong triển khai này, tôi thực hiện một vài thay đổi đáng kể so với kiến trúc makemore:

  • Sparse mixture of experts thay vì mạng nơ-ron feed forward đơn lẻ.
  • Triển khai top-k gatingnoisy top-k gating.
  • Khởi tạo - Khởi tạo Kaiming He được sử dụng ở đây, nhưng mục đích của notebook này là có thể tùy chỉnh để bạn có thể thay thế bằng khởi tạo Xavier/Glorot, v.v. và thử nghiệm.

Tuy nhiên, những điều sau đây không thay đổi so với makemore:

  • Tập dữ liệu, tiền xử lý (tokenization) và nhiệm vụ mô hình hóa ngôn ngữ mà Andrej đã chọn ban đầu - tạo văn bản giống như Shakespeare.
  • Triển khai casual self attention.
  • Vòng lặp huấn luyện.
  • Logic suy luận.

Let’s get started!

Các mô hình ngôn ngữ sparse mixture of experts, như dự đoán, phụ thuộc vào self-attention để hiểu ngữ cảnh. Chẳng bao lâu nữa, chúng ta sẽ khám phá những điểm phức tạp của khối mixture of experts. Trước tiên, hãy đi sâu vào self-attention để làm mới sự hiểu biết của chúng ta.

Understanding the intuition of Causal Scaled Dot Product Self Attention

Đoạn code được cung cấp minh họa cơ chế và các khái niệm cơ bản của self-attention, đặc biệt tập trung vào scaled dot product self-attention cổ điển. Trong biến thể này, ma trận query, keyvalue đều xuất phát từ cùng một chuỗi đầu vào. Để đảm bảo tính toàn vẹn của quá trình tạo ngôn ngữ tự hồi quy, đặc biệt là trong mô hình chỉ có decoder, code triển khai masking. Kỹ thuật masking này rất quan trọng vì nó che khuất mọi thông tin sau vị trí token hiện tại, do đó hướng sự chú ý của mô hình chỉ đến các phần trước của chuỗi. Cơ chế attention như vậy được gọi là causal self-attention. Điều quan trọng cần lưu ý là mô hình Sparse Mixture of Experts không bị giới hạn trong kiến trúc Transformer chỉ có decoder. Trên thực tế, phần lớn công việc quan trọng trong lĩnh vực này, đặc biệt là công trình của Shazeer et al, xoay quanh kiến trúc T5, bao gồm cả thành phần encoder và decoder trong mô hình Transformer.

#This code is borrowed from Andrej Karpathy's makemore repository linked in the repo.
The self attention layers in Sparse mixture of experts models are the same as
in regular transformer models

torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1) #B,T,T

v = value(x) #B,T,H
out = wei @ v # (B,T,T) @ (B,T,H) -> (B,T,H)
out.shape

torch.Size([4, 8, 16])

Code cho causal self-attentionmulti-head causal self-attention có thể được tổ chức như sau. Multi-head self-attention áp dụng nhiều attention head song song, mỗi head tập trung vào một phần riêng biệt của channel (chiều embedding). Multi-head self-attention về cơ bản cải thiện quá trình học và nâng cao hiệu quả huấn luyện mô hình do việc triển khai song song vốn có. Lưu ý rằng tôi đã sử dụng dropout trong suốt quá trình triển khai này để điều chỉnh, tức là ngăn ngừa overfitting.

#Causal scaled dot product self-Attention Head
n_embd = 64
n_head = 4
n_layer = 4
head_size = 16
dropout = 0.1

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

Multi-head self attention is implemented as follows:

#Multi-Headed Self Attention
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

Creating an Expert module i.e. a simple Multi Layer Perceptron

Trong kiến trúc Sparse Mixture of Experts (MoE), cơ chế self-attention bên trong mỗi khối transformer vẫn không thay đổi. Tuy nhiên, một thay đổi đáng chú ý xảy ra trong cấu trúc của mỗi khối: mạng nơ-ron feed-forward tiêu chuẩn được thay thế bằng một số mạng feed-forward được kích hoạt thưa thớt, được gọi là experts. “Kích hoạt thưa thớt” (Sparse activation) đề cập đến quá trình mà mỗi token trong chuỗi chỉ được định tuyến đến một số lượng hạn chế của các experts này - thường là một hoặc hai - trong tổng số nhóm có sẵn. Điều này giúp tăng tốc độ huấn luyện và suy luận, vì một số ít experts được kích hoạt trong mỗi lần truyền xuôi. Tuy nhiên, tất cả các experts phải nằm trong bộ nhớ GPU, do đó tạo ra các vấn đề triển khai thú vị khi tổng số lượng tham số đạt đến hàng trăm tỷ hoặc thậm chí hàng nghìn tỷ.

#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

Top-k Gating Intuition through an Example

Mạng gating, còn được gọi là router, xác định mạng expert nào nhận đầu ra cho mỗi token từ multi-head attention. Hãy xem xét một ví dụ đơn giản: giả sử có 4 experts, và token cần được định tuyến đến 2 experts hàng đầu. Ban đầu, chúng ta đưa token vào mạng gating thông qua một lớp tuyến tính. Lớp này chiếu tensor đầu vào từ hình dạng (2, 4, 32) — biểu diễn (Kích thước batch, Số lượng tokens, n_embed, trong đó n_embed là chiều channel của đầu vào) — đến một hình dạng mới (2, 4, 4), tương ứng với (Kích thước batch, Số lượng tokens, num_experts), trong đó num_experts là số lượng mạng expert. Tiếp theo, chúng ta xác định k=2 giá trị cao nhất và các chỉ số tương ứng của chúng dọc theo chiều cuối cùng.

#Understanding how gating works
num_experts = 4
top_k=2
n_embed=32


#Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length=4 and batch_size=2
mh_output = torch.randn(2, 4, n_embed)

topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32, 4)

logits = topkgate_linear(mh_output)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # Get top-k experts
top_k_logits, top_k_indices

#output:
(tensor([[[ 0.0246, -0.0190],
          [ 0.1991,  0.1513],
          [ 0.9749,  0.7185],
          [ 0.4406, -0.8357]],
 
         [[ 0.6206, -0.0503],
          [ 0.8635,  0.3784],
          [ 0.6828,  0.5972],
          [ 0.4743,  0.3420]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 3],
          [2, 1],
          [3, 1],
          [2, 1]],
 
         [[0, 2],
          [0, 3],
          [3, 2],
          [3, 0]]]))

Lấy đầu ra sparse gating bằng cách chỉ giữ lại k giá trị hàng đầu tại chỉ số tương ứng của chúng dọc theo chiều cuối cùng. Điền phần còn lại bằng ‘-inf’ và truyền qua hàm kích hoạt softmax. Điều này đẩy các giá trị ‘-inf’ về không, làm cho hai giá trị hàng đầu được nhấn mạnh hơn và tổng bằng 1. Tổng bằng 1 này giúp ích cho việc tính trọng số của đầu ra expert.

zeros = torch.full_like(logits, float('-inf')) #full_like clones a tensor and fills it with a specified value (like infinity) for masking or calculations.
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

#output
tensor([[[   -inf,    -inf,  0.0246, -0.0190],
         [   -inf,  0.1513,  0.1991,    -inf],
         [   -inf,  0.7185,    -inf,  0.9749],
         [   -inf, -0.8357,  0.4406,    -inf]],

        [[ 0.6206,    -inf, -0.0503,    -inf],
         [ 0.8635,    -inf,    -inf,  0.3784],
         [   -inf,    -inf,  0.5972,  0.6828],
         [ 0.3420,    -inf,    -inf,  0.4743]]], grad_fn=<ScatterBackward0>)

gating_output= F.softmax(sparse_logits, dim=-1)
gating_output


#ouput
tensor([[[0.0000, 0.0000, 0.5109, 0.4891],
         [0.0000, 0.4881, 0.5119, 0.0000],
         [0.0000, 0.4362, 0.0000, 0.5638],
         [0.0000, 0.2182, 0.7818, 0.0000]],

        [[0.6617, 0.0000, 0.3383, 0.0000],
         [0.6190, 0.0000, 0.0000, 0.3810],
         [0.0000, 0.0000, 0.4786, 0.5214],
         [0.4670, 0.0000, 0.0000, 0.5330]]], grad_fn=<SoftmaxBackward0>)

Generalizing and Modularizing above code and adding noisy top-k Gating for load balancing

# First define the top k router module 
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear =nn.Linear(n_embed, num_experts)
    
    def forward(self, mh_ouput):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.linear(mh_output)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

Let’s test the functionality with some sample inputs:

#Testing this out:
num_experts = 4
top_k = 2
n_embd = 32

mh_output = torch.randn(2, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#And it works!!

#output
(torch.Size([2, 4, 4]),
 tensor([[[0.5284, 0.0000, 0.4716, 0.0000],
          [0.0000, 0.4592, 0.0000, 0.5408],
          [0.0000, 0.3529, 0.0000, 0.6471],
          [0.3948, 0.0000, 0.0000, 0.6052]],
 
         [[0.0000, 0.5950, 0.4050, 0.0000],
          [0.4456, 0.0000, 0.5544, 0.0000],
          [0.7208, 0.0000, 0.0000, 0.2792],
          [0.0000, 0.0000, 0.5659, 0.4341]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[0, 2],
          [3, 1],
          [3, 1],
          [3, 0]],
 
         [[1, 2],
          [2, 0],
          [0, 3],
          [2, 3]]]))

Mặc dù bài báo về Mixtral được phát hành gần đây không đề cập đến nó, tôi tin rằng Noisy top-k Gating là một công cụ quan trọng trong việc huấn luyện các mô hình MoE. Về cơ bản, bạn không muốn tất cả các token được gửi đến cùng một tập hợp các experts ‘ưu tiên’. Bạn muốn có một sự cân bằng tốt giữa khai thác (exploitation) và khám phá (exploration). Vì mục đích này, để cân bằng tải, việc thêm nhiễu chuẩn (standard normal noise) vào các logits từ lớp tuyến tính gating là hữu ích. Điều này làm cho quá trình huấn luyện hiệu quả hơn.

#Changing the above to accomodate noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)

    
    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

Let’s test this implementation out again

#Testing this out, again:
num_experts = 8
top_k = 2
n_embd = 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices
#It works!!

#output
(torch.Size([2, 4, 8]),
 tensor([[[0.4181, 0.0000, 0.5819, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4693, 0.5307, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4985, 0.5015, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.2641, 0.0000, 0.7359, 0.0000, 0.0000]],
 
         [[0.0000, 0.0000, 0.0000, 0.6301, 0.0000, 0.3699, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.4766, 0.0000, 0.0000, 0.0000, 0.5234],
          [0.0000, 0.0000, 0.0000, 0.6815, 0.0000, 0.0000, 0.3185, 0.0000],
          [0.4482, 0.5518, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[2, 0],
          [1, 0],
          [2, 1],
          [5, 3]],
 
         [[3, 5],
          [7, 3],
          [3, 6],
          [1, 0]]]))

Creating a sparse Mixture of Experts module

Khía cạnh chính của quá trình này liên quan đến đầu ra của mạng gating. Sau khi có được các kết quả này, k giá trị hàng đầu được nhân một cách chọn lọc với đầu ra từ các experts hàng đầu tương ứng cho một token nhất định. Phép nhân chọn lọc này tạo thành tổng trọng số, tạo thành đầu ra của khối SparseMoe. Phần quan trọng và đầy thách thức của quá trình này là tránh các phép nhân không cần thiết. Điều cần thiết là chỉ thực hiện các lần truyền xuôi cho các experts top_k và sau đó tính tổng trọng số này. Thực hiện các lần truyền xuôi cho mỗi expert sẽ làm mất mục đích sử dụng sparse MoE, vì nó sẽ không còn thưa thớt nữa.

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                final_output[expert_mask] += weighted_output.squeeze(1)

        return final_output

Việc kiểm tra bằng các đầu vào mẫu xem cách triển khai trên có hoạt động hay không là hữu ích. Sau khi chạy đoạn code sau, chúng ta có thể thấy nó hoạt động!

import torch
import torch.nn as nn

#Let's test this out
num_experts = 8
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)

Shape of the final output: torch.Size([4, 8, 16])

Để nhấn mạnh, điều quan trọng là phải nhận ra rằng độ lớn của đầu ra top_k experts từ mạng Router/gating, như được minh họa trong code trên, cũng rất quan trọng. Các chỉ số top_k này xác định các experts nào được kích hoạt, và độ lớn của các giá trị trong các chiều top_k đó xác định trọng số tương ứng của chúng. Khái niệm về tổng trọng số này được nêu bật hơn nữa trong sơ đồ bên dưới.

Putting it all together

Multi-head self-attentionsparse mixture of experts được kết hợp để tạo thành một khối transformer sparse mixture of experts. Giống như trong một khối transformer vanilla, các kết nối tắt (skip connections) được thêm vào để đảm bảo quá trình huấn luyện ổn định và tránh các vấn đề như gradient biến mất (vanishing gradient). Ngoài ra, layer normalization được sử dụng để ổn định hơn nữa quá trình học.

#Create a self attention + mixture of experts block, that may be repeated several number of times 
class Block(nn.Module):
    """ Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """

    def __init__(self, n_embed, n_head, num_experts, top_k):
        # n_embed: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x

Finally putting it all together to crease a sparse mixture of experts language model

class SparseMoELanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

Khởi tạo (Initialization) rất quan trọng để huấn luyện hiệu quả các mạng nơ-ron sâu. Khởi tạo Kaiming He được sử dụng ở đây vì có sự hiện diện của các kích hoạt ReLU trong các experts. Vui lòng thử nghiệm với khởi tạo Glorot, thường được sử dụng hơn trong các transformers. Phần 2 của Fastai của Jeremy Howard có một bài giảng xuất sắc triển khai chúng từ đầu: https://course.fast.ai/Lessons/lesson17.html. Người ta lưu ý trong tài liệu rằng khởi tạo Glorot thường được sử dụng trong các mô hình transformer, vì vậy đây là cơ hội để có thể cải thiện hiệu suất mô hình.

def kaiming_init_weights(m):
    if isinstance (m, (nn.Linear)): 
        init.kaiming_normal_(m.weight)

model = SparseMoELanguageModel()
model.apply(kaiming_init_weights)

Tôi đã sử dụng mlflow để theo dõi và ghi lại các số liệu quan trọng và các siêu tham số huấn luyện. Vòng lặp huấn luyện tôi đã trình bày ở đây bao gồm code này. Nếu bạn thích chỉ huấn luyện mà không sử dụng mlflow, các notebook trong repo makeMoE trên github có các khối code không có MLFlow. Cá nhân tôi thấy rất tiện lợi khi theo dõi các tham số và số liệu, đặc biệt là khi thử nghiệm.

#Using MLFlow
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#mlflow.set_experiment("makeMoE")
with mlflow.start_run():
    #If you use mlflow.autolog() this will be automatically logged. I chose to explicitly log here for completeness
    params = {"batch_size": batch_size , "block_size" : block_size, "max_iters": max_iters, "eval_interval": eval_interval,
              "learning_rate": learning_rate, "device": device, "eval_iters": eval_iters, "dropout" : dropout, "num_experts": num_experts, "top_k": top_k }
    mlflow.log_params(params)
    for iter in range(max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            metrics = {"train_loss": losses['train'], "val_loss": losses['val']}
            mlflow.log_metrics(metrics, step=iter)


        # sample a batch of data
        xb, yb = get_batch('train')

        # evaluate the loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

8.996545 M parameters
step 0: train loss 5.3223, val loss 5.3166
step 100: train loss 2.7351, val loss 2.7429
step 200: train loss 2.5125, val loss 2.5233
.
.
.

step 4999: train loss 1.5712, val loss 1.7508

Ghi lại train lossvalidation loss cho bạn một dấu hiệu tốt về quá trình huấn luyện diễn ra như thế nào. Biểu đồ cho thấy rằng tôi có lẽ nên dừng lại ở khoảng 4500 bước (khi validation loss tăng nhẹ).

Bây giờ chúng ta có thể tạo văn bản bằng mô hình này theo từng ký tự, một cách tự hồi quy. Đối với một mô hình ~9M tham số được kích hoạt thưa thớt, tôi không có gì để phàn nàn.

# generate from the model. Not great. Not too bad either
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

DUKE VINCENVENTIO:
If it ever fecond he town sue kigh now,
That thou wold'st is steen 't.

SIMNA:
Angent her; no, my a born Yorthort,
Romeoos soun and lawf to your sawe with ch a woft ttastly defy,
To declay the soul art; and meart smad.

CORPIOLLANUS:
Which I cannot shall do from by born und ot cold warrike,
What king we best anone wrave's going of heard and good
Thus playvage; you have wold the grace.
...

Tôi hy vọng lời giải thích này đã giúp xây dựng sự hiểu biết của bạn về kiến trúc mô hình Sparse Mixture of Experts và cách nó kết hợp với nhau.

Tôi đã tham khảo rất nhiều các ấn phẩm sau cho việc triển khai này:

Code được phát triển hoàn toàn trên Databricks bằng một A100 duy nhất. Nếu bạn đang chạy code này trên Databricks, bạn có thể масштабировать nó trên một cụm GPU lớn tùy ý mà không gặp vấn đề gì, trên nhà cung cấp đám mây bạn chọn. Tôi đã chọn sử dụng MLFlow (đã được cài đặt sẵn trong Databricks. Nó hoàn toàn là mã nguồn mở và bạn có thể cài đặt pip một cách dễ dàng ở những nơi khác) vì tôi thấy nó hữu ích để theo dõi và ghi lại tất cả các số liệu cần thiết. Điều này là hoàn toàn tùy chọn. Xin lưu ý rằng việc triển khai nhấn mạnh vào tính dễ đọc và khả năng tùy chỉnh so với hiệu suất, vì vậy có nhiều cách mà bạn có thể cải thiện điều này.

Với những điều đó, đây là một vài điều bạn có thể thử:

  • Làm cho module Mixture of Experts hiệu quả hơn. Tôi tin rằng có thể có những cải tiến đáng kể trong cách triển khai ở trên cho việc kích hoạt thưa thớt của các experts chính xác.
  • Thử các chiến lược khởi tạo mạng nơ-ron khác nhau. Nguồn tôi đã liệt kê (Fastai phần 2) là tuyệt vời.
  • Chuyển từ cấp độ ký tự sang tokenization cấp độ từ vựng con.
  • Thực hiện tìm kiếm siêu tham số Bayesian cho số lượng experts và top_k (số lượng experts được kích hoạt cho mỗi token). Điều này có thể được phân loại lỏng lẻo như là tìm kiếm kiến trúc nơ-ron.
  • Expert Capacity không được thảo luận hoặc triển khai ở đây. Nó chắc chắn đáng để khám phá.

Với lượng quan tâm đến mixture of experts và đa phương thức, cũng sẽ rất thú vị để xem những gì sẽ được phát triển tại giao điểm của cả hai. Chúc bạn hack vui vẻ!!

PS: Part 2 of this blog with the implementation of Expert Capacity for more efficient training is given here: https://huggingface.co/blog/AviSoori1x/makemoe2