100行纯JAX代码,完美复现Llama 3!

B站影视 2025-02-24 22:04 1

摘要:作为一个基于 Transformer 架构的解码器,LLaMA 3 在计算效率和可扩展性方面进行了创新。而复现大模型有多难?在最新的技术探索中,开发者 Saurabh 利用纯 JAX 成功实现了 LLaMA 3 模型,展示了如何通过这一高效工具构建和训练强大的

作为一个基于 Transformer 架构的解码器,LLaMA 3 在计算效率和可扩展性方面进行了创新。而复现大模型有多难?在最新的技术探索中,开发者 Saurabh 利用纯 JAX 成功实现了 LLaMA 3 模型,展示了如何通过这一高效工具构建和训练强大的语言模型。

作者 | saurabh 责编 | 苏宓

出品 | CSDN(ID:CSDNnews)

在这篇文章中,我们将使用纯 JAX 从零实现 Llama3,只需 100 行代码。

为什么选 JAX?

它是由 Google 开发的高性能计算库,它专注于自动微分(Autograd)和加速计算,尤其适用于机器学习和科学计算。我觉得它的设计很美观。此外,JAX 看起来像是 NumPy 的一个封装,但它拥有一些很酷的特性,比如 XLA(一个线性代数加速器)、JIT、vmap、pmap 等,让你的训练速度飞快。

JAX 是最早专注于纯函数式编程理念的库之一,这也让它成为众人关注的焦点之一。

注意:

本文假设你熟悉 Python,并对 Transformer 架构有基本了解。

这个实现仅用于教学目的,不适用于生产环境,但它涵盖了模型的所有组件。

LLaMA 3

LLaMA 3 本质上是一个仅包含解码器的 Transformer 语言模型,它通过逐个生成 token 来预测下一个内容,就像一句话逐词补全的过程。

所以,冲就完事了!

首先,我们要配置设备并初始化模型参数:

# Configure JAX to use GPU and prevent memory preallocationos.environ['JAX_PLATFORM_NAME'] = 'gpu'os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'print("JAX devices:", jax.devices)

这里是我们训练一个约 200 万参数模型所需的超参数。

# Define model hyperparametersargs = ModelArgs( vocab_size=enc.n_vocab, # Size of vocabulary dim=256, # Embedding dimension n_layers=6, # Number of transformer layers n_heads=8, # Number of attention heads n_kv_heads=4, # Number of key-value heads for GQA max_seq_len=512, # Maximum sequence length norm_eps=1e-5 # Normalization epsilon)

模型权重初始化

在纯 JAX 中,我们不会像 PyTorch 那样使用类(如 nn.Module),而是只使用纯函数。为什么?因为这样代码更可预测,更容易并行化。纯函数的特点是:相同的输入,永远得到相同的输出,且不会产生任何副作用。比如 F(x),无论调用多少次,结果始终是 y。

但因为我们不像 PyTorch 那样有 nn.Module 自动管理参数,所以需要手动初始化和更新权重。

如何处理随机性?

JAX 里的随机数处理方式和 PyTorch、NumPy 不一样。在 NumPy 或 PyTorch 中,我们通常设置一个全局随机种子,然后就能复现结果。但在 JAX 中,每个随机操作都需要一个独立的随机数 key。我们用父 key 生成子 key,确保可复现性和并行计算的高效性。

比如,下面的代码展示了如何创建一个随机数 key,并拆分出多个子 key,然后将这些 key 传入涉及随机数的函数中。

# Generate and split random keys for reproducibilitykey = jax.random.PRNGKey(42)# Create a new subkey for random operationskey, subkey = jax.random.split(key)# Initialize random weights using the subkeyweights = jax.random.normal(subkey, (784, 512))

接下来开始初始化模型权重。第一步,我们使用正态分布随机初始化模型参数。

# Initialize weights with optional scalingdef init_weight(key, shape, scale=None): # Calculate default scale if none provided scale = 1.0 / math.sqrt(shape[0]) if scale is None else scale # Return scaled normal distribution return jax.random.normal(key, shape) * scale

然后,我们需要识别模型中所有可训练的参数(LLaMA 3 的),给每个参数分配一个唯一的随机数 key,以保证可复现性,并执行初始化操作。

由于模型的权重本质上就是数组,所以我们可以用字典来存储它们,键是参数名称,值是对应的数值。

具体来说,我们的模型由以下几部分组成:

1.注意力机制(Attention 模块) —— 4 个可训练参数。

# Initialize attention weights for multi-head attentiondef init_attention_weights(key, dim, n_heads, n_kv_heads): # Split key for each weight matrix keys = jax.random.split(key, 4) head_dim = dim // n_heads # Return dictionary of weight matrices return { 'wq': init_weight(keys[0], (dim, n_heads head_dim)), # Query weights 'wk': init_weight(keys[1], (dim, n_kv_heads head_dim)), # Key weights 'wv': init_weight(keys[2], (dim, n_kv_heads head_dim)), # Value weights 'wo': init_weight(keys[3], (n_heads head_dim, dim)) # Output projection }

2.前馈网络(Feed-forward 网络) —— 3 个可训练参数。

# Initialize feed-forward network weightsdef init_ffn_weights(key, dim): # Split key into three for each weight matrix keys = jax.random.split(key, 3) return { 'w1': init_weight(keys[0], (dim, 4 * dim)), # First projection 'w2': init_weight(keys[1], (4 * dim, dim)), # Output projection 'w3': init_weight(keys[2], (dim, 4 * dim)) # Gate projection }

3.Transformer Block —— 由多个部分组成,并额外包含两个 RMSNorm 层的参数。

# Initialize a complete transformer blockdef init_transformer_block(key, dim, n_heads, n_kv_heads): # Split key for each component keys = jax.random.split(key, 4) return { 'attention': init_attention_weights(keys[0], dim, n_heads, n_kv_heads), # Self-attention 'ffn': init_ffn_weights(keys[1], dim), # Feed-forward network 'attention_norm': init_weight(keys[2], (dim,), scale=1.0), # Pre-attention norm 'ffn_norm': init_weight(keys[3], (dim,), scale=1.0) # Pre-ffn norm }

最后,我们把所有的权重初始化代码整合到一起,完成模型的权重初始化。

# Initialize complete model parametersdef init_model_params(key, vocab_size, dim, n_layers, n_heads, n_kv_heads): # Split keys for different components keys = jax.random.split(key, 4) params = { 'token_embedding': init_weight(keys[0], (vocab_size, dim)), # Token embeddings 'norm_f': init_weight(keys[1], (dim,), scale=1.0), # Final normalization 'output': init_weight(keys[2], (dim, vocab_size)) # Output projection } # Initialize transformer blocks block_keys = jax.random.split(keys[3], n_layers) params['blocks'] = [ init_transformer_block(k, dim, n_heads, n_kv_heads) for k in block_keys ] return params

标记化(Tokenization)

标记化就是把文本拆分成单词或子词(Token),让计算机能更好地处理。我们在训练模型时会使用 Byte Pair Encoding(BPE) 进行分词,LLaMA 3 训练时也是用的 BPE。

不过,我们不会从头开始构建 BPE,而是直接用 OpenAI 的 tiktoken 库来完成分词。

import jax.numpy as jnpimport tiktoken# Load GPT-2 BPE encodingenc = tiktoken.get_encoding("gpt2")# reading a line from with open('../shakespeare.txt', 'r') as f: text = f.readlines[0] # Take the first line# Encode the text into token IDstokens = enc.encode(text)data = jnp.array(tokens, dtype=jnp.int32) # Store as JAX array# Decode back to textdecoded_text = enc.decode(tokens)print("original Text:", text.strip)print("encoded Tokens:", tokens)print("decoded Text:", decoded_text)## Ouput ### Original Text: From fairest creatures we desire increase,# Encoded Tokens: [220, 3574, 37063, 301, 8109, 356, 6227, 2620, 11, 198]# Decoded Text: From fairest creatures we desire increase,

嵌入(Embeddings)

模型不能直接处理标记化后的 token,因为 token 是离散的(Discrete),但神经网络只能处理连续的数值(Continuous),这样才能进行数学计算。因此,我们需要一个 Embedding 层,把 token 转换成向量,映射到一个连续的向量空间。词向量不仅能让模型理解单词之间的语义和句法关系,还可以提高模型的表达能力。

词向量有静态(Static)和动态(Dynamic)两种:

静态词向量。适用于寻找相似单词,比如上面第一张图,会把语义相近的词放在相似的向量空间里。

动态词向量。适用于大语言模型(LLM),因为静态词向量不能根据上下文调整含义,导致多义词(比如 “bank” 既可以指银行,也可以指河岸)在不同语境下的含义混淆。

为了解决这个问题,我们使用自注意力机制(Self-Attention),它会根据上下文动态调整词向量。模型初始化时,词向量是随机的,然后在训练过程中不断优化,使其能更准确地表达词语的含义。

# Converting the input tokens into embeddingsh = params["token_embedding"][inputs]# token_embedding is a matrix of shape (vocab_size, dim).# inputs are token IDs (integers).

RMS 归一化(Root Mean Square Layer Normalization)

RMS 归一化(RMSNorm)是 LLaMA 3 模型中的一个重要层,它的作用是防止训练过程中的数值过大或过小,从而让训练更加稳定。这在深度网络中尤为重要,因为如果数值失衡,模型可能会收敛得很慢,甚至不收敛。

# RMS Norm function for stabilizing trainingdef rms_norm(x, weight, eps=1e-5): # Calculate variance across last dimension variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) # Normalize and scale return x * weight * jnp.reciprocal(jnp.sqrt(variance + eps))

旋转位置编码( ROPE)

Transformer 本身无法理解 token 的顺序,所以需要额外加入位置信息。在 LLaMA 3 里,我们用 ROPE(旋转位置编码) 来解决这个问题。它的做法是:不是像传统 Transformer 那样添加额外的位置向量,而是直接对 Query 和 Key 向量进行旋转变换,让模型自然地感知位置信息。

ROPE 的工作原理:

1.预计算旋转因子:先创建一个旋转因子表,每个 token 根据自己的位置分配一个独特的旋转角度。

# Compute rotary position embeddingsdef precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): # Generate frequency bands freqs = 1.0 / (theta ** (jnp.arange(0, dim // 2, dtype=jnp.float32) / dim)) # Generate position indices t = jnp.arange(end, dtype=jnp.float32) # Compute outer product freqs = jnp.outer(t, freqs) # Convert to complex exponential return jnp.complex64(jnp.exp(1j * freqs))

2.配对特征:把词向量重新排列,让每两个数值形成一对,可以想象成复数的实部和虚部。

3.执行旋转:将这些复数乘以预计算的旋转因子,使它们在复平面上旋转。

4.转换回去:把旋转后的复数拆分回原来的形状,恢复到模型可用的格式。

背后的数学原理:对于每对数值 (x₂ᵢ, x₂ᵢ₊₁),旋转后的计算公式如下:

其中,θᵢ 是该 token 的旋转角度。简单来说,ROPE 直接把位置信息编码进 token 的特征里,这样注意力机制(Attention)在计算时就能自然地感知 token 的顺序,无需额外的位置向量。

# Apply rotary embeddings to queries and keysdef apply_rotary_emb(xq, xk, freqs_cis): # Reshape inputs for complex multiplication xq_r, xk_r = jnp.reshape(xq, (*xq.shape[:-1], -1, 2)), jnp.reshape(xk, (*xk.shape[:-1], -1, 2)) # Convert to complex numbers xq_complex = jnp.complex64(xq_r[..., 0] + 1j * xq_r[..., 1]) xk_complex = jnp.complex64(xk_r[..., 0] + 1j * xk_r[..., 1]) # Reshape frequencies for broadcasting freqs_cis = jnp.reshape(freqs_cis, (1, freqs_cis.shape[0], 1, freqs_cis.shape[1])) # Apply rotation through complex multiplication xq_out = xq_complex * freqs_cis xk_out = xk_complex * freqs_cis # Convert back to real numbers and reshape xq = jnp.stack([jnp.real(xq_out), jnp.imag(xq_out)], axis=-1).reshape(xq.shape) xk = jnp.stack([jnp.real(xk_out), jnp.imag(xk_out)], axis=-1).reshape(xk.shape) return xq, xk

分组查询注意力(GQA)

现在该讲注意力机制了!分组查询注意力(GQA) 是多头注意力(Multi-Head Attention,MHA) 的优化版本,它通过让多个查询头(Query Heads)共享键(Key)和值(Value) 来提高计算效率。

简单来说,普通的多头注意力是每个查询头都有自己的键和值,但 GQA 让多个查询头复用相同的键和值,这样可以减少计算量(因为不需要为每个查询头单独计算键和值)、降低内存占用(适用于大模型推理)、 加快推理速度(更高效的注意力机制)。本质上,GQA 还是自注意力(Self-Attention),只是做了一点优化调整,让它更适合大规模模型。

缩放点积注意力机制(Scaled Dot-Product Attention):

# Attention mechanism with grouped-query attentiondef attention(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0): # Get input dimensions B, T, C = x.shape head_dim = C // n_heads # Project inputs to queries, keys, and values q = jnp.dot(x, params['wq']).reshape(B, T, n_heads, head_dim) k = jnp.dot(x, params['wk']).reshape(B, T, n_kv_heads, head_dim) v = jnp.dot(x, params['wv']).reshape(B, T, n_kv_heads, head_dim) # Apply rotary embeddings q, k = apply_rotary_emb(q, k, freqs_cis[position:position + T]) # Handle cache for inference if cache is not None: k = jnp.concatenate([cache[0], k], axis=-1]) v = jnp.concatenate([cache[1], v], axis=-1]) new_cache = (k, v) # Repeat k/v heads for grouped-query attention k = repeat_kv(k, n_heads // n_kv_heads) v = repeat_kv(v, n_heads // n_kv_heads) # Compute attention scores and apply attention q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / math.sqrt(head_dim) # Apply attention mask if provided if mask is not None: scores = scores + mask[:, :, :T, :T] # Compute attention weights and final output scores = jax.nn.softmax(scores, axis=-1) output = jnp.matmul(scores, v) output = output.transpose(0, 2, 1, 3).reshape(B, T, -1) return jnp.dot(output, params['wo']), new_cache

KV 缓存(KV-cache):用于存储先前计算的键(K)和值(V)张量,以便在推理过程中复用。这种缓存机制可以显著提高推理效率。

if cache is not None: k = jnp.concatenate([cache[0], k], axis=-1) # Concatenate cached keys with new keys v = jnp.concatenate([cache[1], v], axis=-1) # Concatenate cached values with new valuesnew_cache = (k, v) # Create new cache with updated k/v pairs

前馈(Feed-forward)

这是一个带有 SiLU 激活函数的简单前馈网络(Feed-Forward Network,FFN)。

w3_ = jnp.dot(x, params['w3']) # SwiGLU(a,b)=SiLU(a)⊙b activated = jax.nn.silu(w3_) combined = activated * w1_ # Final output projection with w2 output = jnp.dot(combined, params['w2']) return output

Transformer-block

Transformer block 是所有关键组件组合在一起的地方。在这里,我们取出之前初始化好的权重,并分配给对应的层。

Transformer 块主要包含以下部分:

注意力机制(Attention):让模型关注不同位置的关键信息

归一化层(Normalization):保持数值稳定,防止训练失控

前馈网络(Feed-Forward):负责信息处理和特征转换

残差连接(Residual Connections):帮助梯度流动,避免信息丢失

简单来说,Transformer 块就像一个“信息处理单元”,每个块都按照这个流程运行,让模型更高效地理解和生成文本。

# Transformer block implementationdef transformer_block(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0): # Apply attention with normalization and residual connection attn_output, new_cache = attention( params['attention'], rms_norm(x, params['attention_norm']), mask, freqs_cis, n_heads, n_kv_heads, cache, position ) # First residual connection h = x + attn_output # Apply feed-forward network with normalization and residual ffn_output = feed_forward(params['ffn'], rms_norm(h, params['ffn_norm'])) # Second residual connection out = h + ffn_output return out, new_cache

前向传播(Forward-Pass)

前向传播就是把输入的数据从头到尾传递通过整个模型的过程,将输入的 tokens 转换为词向量(embeddings),再到通过多个 Transformer 块进行处理,最后输出结果。

换句话说,前向传播就是把所有层(词向量、Transformer、输出层)连接起来,最终生成模型的预测结果。

# Forward pass through the entire modeldef model_forward(params, inputs, config, cache=None, position=0): # Get batch dimensions B, T = inputs.shape # Convert input tokens to embeddings h = params['token_embedding'][inputs] # Compute freqs_cis for this forward pass freqs_cis = precompute_freqs_cis(config.dim // config.n_heads, config.max_seq_len) # Create causal mask mask = jnp.tril(jnp.ones((config.max_seq_len, config.max_seq_len))) mask = jnp.where(mask == 0, -1e9, 0.0) mask = mask.astype(h.dtype) mask = mask[None, None, :, :] # Process through transformer blocks new_caches = for i, block in enumerate(params['blocks']): layer_cache = cache[i] if cache is not None else None h, layer_cache = transformer_block( block, h, mask, freqs_cis, config.n_heads, config.n_kv_heads, layer_cache, position, training=False) new_caches.append(layer_cache) # Final normalization and output projection h = rms_norm(h, params['norm_f']) logits = jnp.dot(h, params['output']) return logits, new_caches

数据集(Dataset)

现在模型部分已经完成,接下来是时候在 Shakespeare dataset 上训练模型了。首先,我们会从 .txt 文件中读取数据,然后使用 BPE 对数据进行编码,最后将其转换为 JAX 数组。

# Initialize tokenizer and load dataenc = tiktoken.get_encoding("gpt2")# Read text filewith open('shakespeare.txt', 'r') as f: text = f.read# Convert text to token IDstokens = enc.encode(text)# Convert to JAX arraydata = jnp.array(tokens)

获取批次(Get Batches)

get_batch 函数从 Shakespeare dataset 创建训练批次。我们需要将数据分成小块,喂给模型。对于每个批次,我们会随机选择文本中的起始位置,这样模型就能看到多种上下文。

接下来,这里就是 JAX 中非常酷的 vmap 特性派上用场的地方。我们不需要写循环来提取每个数据块,而是使用 vmap 来自动化这个过程。

它是如何工作的呢?

vmap 就像是一个向量化的循环,它接收一个处理单个索引的函数(用 lax.dynamic_slice 提取一系列 tokens),然后将这个函数应用到数组中的每一个元素。这样,我们的输入序列(x)和对应的目标序列(y,目标序列是通过将原始序列右移一个 token 来实现下一个词预测)就能一次性创建出来。

def get_batch(key, data, batch_size, seq_len): # Generate random starting indices ix = random.randint(key, (batch_size,), 0, len(data) - seq_len) # Vectorized operation to get input and target sequences x = vmap(lambda i: lax.dynamic_slice(data, (i,), (seq_len,)))(ix) y = vmap(lambda i: lax.dynamic_slice(data, (i + 1,), (seq_len,)))(ix) return x, y

损失函数(Loss Function)

这个函数在训练过程中计算每个批次的交叉熵损失。它的工作流程如下:

1.首先,通过模型进行前向传播,生成输入数据的 logits(预测值)。

2.然后,将 logits 和目标值(targets)重新调整形状,把批次和序列维度合并在一起。

3.对 logits 应用 log softmax,得到每个 token 对应的对数概率。

4.最后,计算正确目标 token 的对数概率的负均值,作为最终的损失值。

交叉熵损失的定义如下:

然后,P(yᵢ) 是正确类别的概率,通过 softmax 函数计算得到:

# Compute cross-entropy lossdef compute_loss(params, batch): # Split batch into inputs and targets inputs, targets = batch # Forward pass to get logits logits, = model_forward(params, inputs, config) # Reshape for loss computation logits = logits.reshape(-1, config.vocab_size) targets = targets.reshape(-1) # Calculate negative log likelihood loss = -jnp.mean(jnp.take_along_axis(jax.nn.log_softmax(logits), targets[:, None], axis=1)) return loss

更新函数(Update function)

现在我们需要写一个函数来更新我们的权重。为了简单起见,这里我们使用随机梯度下降(SGD),不过你也可以选择 Adam 或 AdamW,它们通常能更快地收敛。

在代码中,你会看到 @jax.jit 装饰器。这是 JAX 的一个特色功能。JIT(即时编译) 通过将 Python 代码转化为优化过的机器码,从而加速执行。

它是如何工作的呢?

当你使用 JAX 的 jit 装饰一个函数时,它会改变函数的执行方式。通常,当你调用一个函数时,Python 会逐行执行它。例如,如果你有:

每次调用 sqr 函数时,它会打印 “HI jiited” 然后返回数字的平方。然而,当你添加 @jax.jit 装饰器后:

@jax.jitdef sqr(x): print("HI jiited") # side effect return x * xprint(sqr(2)) print(sqr(3)) print(sqr(4))

JAX 首先会对你的函数进行追踪,构建一个优化过的计算图。这个追踪过程发生在函数第一次被调用时,并将 Python 代码转化为机器码。

由于这个追踪过程,像 print 语句这样的副作用只会在第一次追踪时执行。函数一旦被编译,之后的调用都会使用编译后的版本,因此你可能不会每次都看到 print 输出。

@jax.jitdef update_step(params, batch): # Compute both loss and gradients in a single pass using value_and_grad # This is more efficient than computing them separately loss, grads = jax.value_and_grad(compute_loss)(params, batch) # Update parameters using gradient descent # jax.tree.map applies the update rule to each parameter in the model # The lambda function implements: p_new = p_old - learning_rate * gradient params = jax.tree.map( lambda p, g: p - config.learning_rate * g, params, grads ) # Return updated parameters and the loss value for monitoring training return params, loss

在我们的 update_step 函数中,@jax.jit 会编译代码。这个函数通过 jax.value_and_grad 同时计算损失和梯度,借助 jax.tree.map 使用梯度下降法更新参数,并返回更新后的参数和损失。

训练循环(Training-Loop)

最后,是时候在 Shakespeare dataset 上训练我们这个 200万参数的模型了。我们首先使用 get_batch 函数准备批次,它会将数据拆分成多个批次,这样我们就能在有限的计算资源下更快地进行训练。

for epoch in range(num_epochs): epoch_loss = 0.0 for step in range(steps_per_epoch): # Generate new random keys for reproducibility key, batch_key = random.split(key) # Sample random batch of sequences batch = get_batch(batch_key, data, config.batch_size, config.max_seq_len) # Forward pass, compute loss and update parameters params_state, loss = update_step(params_state, batch) # loss for epoch average epoch_loss += loss if step % 100 == 0: print(f"epoch {epoch + 1}, step {step}/{steps_per_epoch}: loss = {loss:.4f}") avg_epoch_loss = epoch_loss / steps_per_epoch epoch_losses.append(avg_epoch_loss) print(f"\nepoch {epoch + 1} | average loss: {avg_epoch_loss:.4f}")

来源:CSDN

相关推荐