While many big industry labs have turned to secrecy about their LLM work (not just the model weights but even the science), Meta took the opposite approach. Kudos to them! In fact, the Llama 2 paper is so rich in detail, and its results are so strong, that it could become the go-to recipe for LLMs and RLHF for many people for a while.
This blog post covers the details that stood out to me personally. I also try to briefly explain less common background techniques in plain language, so that you can hopefully get the gist of the ideas without going down a rabbit hole of paper links.
Model
Transformer (duh?):
Pre-LayerNorm: standard trick nowadays; normalize the inputs of the multi-head attention (MHA) and feedforward network (FFN) sublayers in each transformer block, instead of normalizing after the skip-connection additions. In the forward pass, there's essentially a path through the skip-connections of all the layers that is unobstructed by normalization. In the backward pass, the gradient norm at initialization scales with the inverse square root of the overall depth, rather than staying large regardless of depth as with Post-LayerNorm. This makes training more stable.
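To make the ordering concrete, here is a minimal pure-Python sketch of one block in both variants. The sublayers are hypothetical stand-ins (functions returning zeros) rather than real attention and FFN, chosen so that only the skip path is visible, and the LayerNorm has no learnable parameters.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm without learnable gain/bias: subtract the mean, divide by the std."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(a, b):
    """Element-wise sum, i.e. the skip-connection."""
    return [u + v for u, v in zip(a, b)]

def post_ln_block(x, attn, ffn):
    # Post-LN: normalize AFTER each residual addition, so the residual
    # stream itself goes through a LayerNorm twice per block.
    x = layer_norm(add(x, attn(x)))
    return layer_norm(add(x, ffn(x)))

def pre_ln_block(x, attn, ffn):
    # Pre-LN: normalize only the sublayer INPUT; the residual stream x
    # flows through the block without being normalized.
    x = add(x, attn(layer_norm(x)))
    return add(x, ffn(layer_norm(x)))

# Hypothetical stand-in sublayer that outputs zeros, to isolate the skip path.
def zero_sublayer(x):
    return [0.0] * len(x)

x = [1.0, 2.0, 3.0]
print(pre_ln_block(x, zero_sublayer, zero_sublayer))   # [1.0, 2.0, 3.0]: skip path untouched
print(post_ln_block(x, zero_sublayer, zero_sublayer))  # re-normalized to ~zero mean, unit variance
```

When the sublayers contribute nothing, the Pre-LN block passes its input through unchanged, while the Post-LN block re-normalizes it at every layer. That re-normalization is the "obstruction" on the skip path.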
RMSNorm: instead of subtracting the mean and dividing by the standard deviation as in LayerNorm, just divide by the root mean square (the square root of the uncentered second moment), then apply a learned gain. It is cheaper to compute, with no noticeable loss in performance.
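Here is a small sketch contrasting the two normalizations on a single vector. It is pure Python for readability. The eps value is my own choice and is not taken from the paper.

```python
import math

def rms_norm(x, gain=None, eps=1e-6):
    """RMSNorm: divide x by its root mean square (no mean subtraction),
    then multiply by a per-feature gain (learned in a real model)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if gain is None:
        gain = [1.0] * len(x)
    return [g * v / rms for g, v in zip(gain, x)]

def layer_norm(x, eps=1e-6):
    """LayerNorm (without gain/bias) for comparison: center, then divide by the std."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

x = [3.0, 4.0]
print(rms_norm(x))    # approx [0.8485, 1.1314]: same direction as x, RMS of 1
print(layer_norm(x))  # approx [-1.0, 1.0]: centered, so the offset is lost
```

Skipping the mean saves one reduction over the feature dimension. The output keeps the direction of the input vector rather than re-centering it.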
SwiGLU activation: standard trick nowadays; a gated linear unit (GLU) with Swish as the gate nonlinearity. A GLU is essentially two parallel linear layers: one goes through a nonlinear gate, and the result is multiplied element-wise with the output of the other (purely linear) layer. Swish (x · sigmoid(x)) looks like a smooth ReLU with an odd non-monotonic dip in the negative range. The justification is mostly empirical, but the intuition is that it improves gradient flow for small negative inputs. Its first derivative looks like a sigmoid that overshoots and then bends back (hence also non-monotonic), and its second derivative looks like a bell curve that overshoots a bit on the way down from the peak and then bends back (again non-monotonic).
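A minimal sketch of a SwiGLU feed-forward layer on plain Python lists follows. The weight names `W_gate`, `W_up` and `W_down` are my own labels, and the layer is written without biases. It is meant to show the gating structure, not to reproduce Llama's actual code.

```python
import math

def sigmoid(z):
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def swish(x, beta=1.0):
    """Swish: x * sigmoid(beta * x); with beta = 1 this is also called SiLU."""
    return x * sigmoid(beta * x)

def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def swiglu_ffn(x, W_gate, W_up, W_down):
    """SwiGLU feed-forward: W_down @ (swish(W_gate @ x) * (W_up @ x)).
    One branch goes through the Swish gate, the other stays linear, and the
    two are multiplied element-wise before the output projection."""
    gate = [swish(v) for v in matvec(W_gate, x)]
    up = matvec(W_up, x)
    return matvec(W_down, [g * u for g, u in zip(gate, up)])

# Non-monotonic dip: Swish bottoms out around x = -1.28.
print(swish(-3.0), swish(-1.28), swish(-0.5))  # approx -0.142, -0.278, -0.189
```

The printed values show the dip. Swish is lower at -1.28 than at both -3 and -0.5, so it is not monotonic in the negative range.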
Rotary positional embeddings (RoPE): a way to position-encode the query and key vectors so that their dot product depends on the words' positions only through their relative distance, which is clearly a useful inductive bias. Classical sinusoidal position embeddings, which are simply added to the word embeddings, don't guarantee this property. The trick is to multiply by a rotation matrix: a block-diagonal matrix whose blocks are 2D rotation matrices. Each block has its own fixed frequency (wavelength), and its rotation angle is the token's position times that frequency. Classical sinusoidal position embeddings use the same setup of wavelengths and angles, but the result is added to the word embeddings instead of rotating them.
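The relative-position property can be checked numerically. The sketch below rotates adjacent pairs of dimensions using the common base of 10000. The query and key values are arbitrary made-up numbers.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle pos * theta_i,
    where theta_i = base ** (-2i / d) is the pair's fixed frequency."""
    d = len(x)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = []
    for i in range(d // 2):
        angle = pos * base ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]  # 2D rotation of the pair
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]
# Same relative offset (3) at very different absolute positions:
print(dot(rope(q, 5), rope(k, 2)))
print(dot(rope(q, 105), rope(k, 102)))  # equal up to floating-point error
```

This works because rotating q by angle m·θ and k by n·θ leaves their dot product depending only on (m − n)·θ. Which dimensions get paired is an implementation detail. Some implementations pair dimension i with i + d/2 instead of adjacent dimensions, and the relative-position property still holds.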
Grouped-query attention (GQA): a parameter-sharing trick that reduces parameters, memory and compute without sacrificing much model quality. Instead of giving every query head its own key and value heads as in MHA, GQA splits the query heads into groups, and each group shares a single key head and value head. The biggest saving is in the key-value cache at inference time. Used in the 34B and 70B variants to scale inference.
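Here is a rough sketch of the head grouping and of the KV-cache arithmetic behind "scale inference". The shape numbers (80 layers, 64 query heads, 8 KV heads, head dimension 128) are my assumption of a 70B-like configuration and are only illustrative. Storing values in 16 bits is also an assumption.

```python
def kv_head_for(q_head, n_q_heads, n_kv_heads):
    """Index of the shared key/value head used by a given query head.
    Consecutive query heads form one group of size n_q_heads // n_kv_heads."""
    assert n_q_heads % n_kv_heads == 0
    return q_head // (n_q_heads // n_kv_heads)

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_value=2):
    """KV-cache size: keys and values (factor 2) for every layer, KV head,
    position and sequence; bytes_per_value=2 assumes 16-bit storage."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_value

# Assumed 70B-like shape: 80 layers, 64 query heads, head_dim 128, 4096 tokens.
mha = kv_cache_bytes(80, 64, 128, 4096)  # one K/V head per query head
gqa = kv_cache_bytes(80, 8, 128, 4096)   # 8 shared K/V heads
print(mha / 2**30, gqa / 2**30)  # 10.0 vs 1.25 (GiB per sequence)
```

Under these assumptions, grouping 64 query heads onto 8 KV heads shrinks the cache 8x. That is what allows longer contexts or larger batches per GPU at inference time.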
Context length is 4096 (up from 2048 for Llama 1)
Tokenizer:
byte-pair encoding (BPE); numbers are split into individual digits, and unknown UTF-8 characters are decomposed into bytes
vocabulary size is 32k tokens
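The two special rules can be illustrated with a deliberately tiny toy. This is not BPE: the learned merges are skipped entirely, and every known character is treated as its own token. The function and vocabulary are hypothetical. To the best of my knowledge, the `<0xNN>` spelling mimics SentencePiece's byte-fallback tokens.

```python
def toy_tokenize(text, vocab):
    """Toy illustration of two rules only; the real tokenizer is a learned BPE
    model, and here every known character is simply its own token.
    - every digit becomes a separate token, so "2023" -> "2", "0", "2", "3"
    - characters not in `vocab` fall back to their UTF-8 bytes, written <0xNN>"""
    tokens = []
    for ch in text:
        if ch.isdigit() or ch in vocab:
            tokens.append(ch)
        else:
            tokens.extend(f"<0x{b:02X}>" for b in ch.encode("utf-8"))
    return tokens

print(toy_tokenize("ab 2023 €", vocab=set("ab ")))
# ['a', 'b', ' ', '2', '0', '2', '3', ' ', '<0xE2>', '<0x82>', '<0xAC>']
```

Splitting digits means every number is spelled out one digit at a time, so there are no arbitrary multi-digit chunks. Byte fallback means no input is ever mapped to an "unknown" token.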
Pre-training
Pre-training data:
2 trillion tokens, trained for 1 epoch only
no Meta user data; the most factual sources are upsampled (hmm, which ones?); "certain sites known to contain a high volume of personal information about private individuals" are excluded; there is no additional filtering (e.g., of hate speech) at the pre-training stage, which is intentional so that the models are more responsive to safety fine-tuning later. It's interesting, and intuitive, that to avoid being "evil", you need to have seen what "evil" looks like.
no further detail on pre-training dataset mix
based on the good overall performance but relatively poor coding performance, I suspect some major code sources were not included, or only partially included, for licensing/legal reasons, as Meta's goal is to release Llama 2 commercially
Pre-training hyperparameters are standard: AdamW(betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1), gradient clipping at 1.0, cosine lr schedule with 2000 warmup steps; lr of 3e-4 for the smaller variants (7B, 13B) and 1.5e-4 for the larger variants that use GQA (34B, 70B); global batch size of 4M tokens
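Here is a sketch of the learning-rate schedule, plus the step count implied by the numbers above. Several pieces are assumptions. The warmup shape is linear. `final_ratio=0.1` follows the paper's statement that the lr decays to 10% of its peak. The 500K total steps is my back-of-the-envelope estimate (2T tokens / 4M tokens per batch) and is not a reported figure.

```python
import math

TOKENS = 2_000_000_000_000            # 2T pre-training tokens
BATCH_TOKENS = 4_000_000              # ~4M tokens per global batch
TOTAL_STEPS = TOKENS // BATCH_TOKENS  # = 500,000 optimizer steps (rough estimate)

def lr_at(step, peak_lr=3e-4, warmup_steps=2000, total_steps=TOTAL_STEPS, final_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay to final_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0
    return peak_lr * (final_ratio + (1.0 - final_ratio) * cosine)

for s in (0, 1999, 2000, 250_000, TOTAL_STEPS):
    print(s, lr_at(s))
```

For the 34B and 70B variants, pass `peak_lr=1.5e-4` instead. At a context length of 4096, a 4M-token batch is roughly a thousand sequences per step.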
not stated in the paper, but I'm guessing they used 2K to 4K A100 GPUs to train the 70B model
Not stated clearly, but probably using Fully Sharded Data Parallel (FSDP) training
Pre-training model performance:
The 7B model's training perplexity is about 1.75 after 2T tokens, and the 70B model's is about 1.5
Far better than other open-source pre-trained LLMs; on par with GPT-3.5 and PaLM; compared to the best closed-source ones (GPT-4 and PaLM 2-L), fairly close on all evaluations except coding, math and hard reasoning, where Llama 2 is much worse. My guess is that this is due to the pre-training data mix, possibly constrained by licensing issues.
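As a reminder of what the perplexity numbers mean, perplexity is the exponential of the average per-token negative log-likelihood. Roughly, it is the effective number of equally likely choices the model is torn between at each token. Below is a minimal sketch using natural logs; the probabilities are made up.

```python
import math

def perplexity(token_log_probs):
    """Perplexity = exp(mean negative log-likelihood per token), natural log."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# A model assigning probability 0.5 to every correct next token has perplexity 2:
print(perplexity([math.log(0.5)] * 10))
# Per-token loss (nats) corresponding to perplexities of 1.75 and 1.5:
print(math.log(1.75), math.log(1.5))  # approx 0.560 vs 0.405
```

Because perplexity is exponential in the loss, the gap between 1.75 and 1.5 is a difference of about 0.15 nats per token.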
Fine-tuning
The overall recipe is supervised (instruction) fine-tuning (SFT) followed by RLHF
SFT data:
from prior public fine-tuning data: 1.8K diverse tasks
27,540 high-quality annotations from vendor-based annotation
they found that a smaller set of high-quality SFT data matters more than a larger, lower-quality one
SFT details:
lr of 2e-5 (roughly an order of magnitude smaller than in pre-training), the same weight decay of 0.1, a batch size of 64 and a context length of 4096; fine-tuned for 2 epochs
because instruction annotations are often short, they make full use of the context length by concatenating prompts and answers from the training set, separated by special tokens. The loss is zeroed out on prompt tokens, so only answer tokens are backpropagated.
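Here is a sketch of the packing-with-loss-mask trick on token-ID lists. The function names, the greedy packing strategy and `sep_id` are all hypothetical. Including the separator in the loss, so that the model learns where an answer ends, is my assumption. The sketch also doesn't address whether attention may cross example boundaries.

```python
def pack_examples(examples, context_len, sep_id):
    """Greedily concatenate (prompt_ids, answer_ids) pairs into sequences of at
    most context_len tokens. Each example is followed by sep_id. Returns a list
    of (tokens, loss_mask) where loss_mask is 0 on prompt tokens and 1 on answer
    tokens and on the separator (assumed here to act as end-of-answer)."""
    packed, tokens, mask = [], [], []
    for prompt, answer in examples:
        piece = prompt + answer + [sep_id]
        piece_mask = [0] * len(prompt) + [1] * len(answer) + [1]
        if len(piece) > context_len:
            raise ValueError("example longer than the context length")
        if len(tokens) + len(piece) > context_len:  # start a new sequence
            packed.append((tokens, mask))
            tokens, mask = [], []
        tokens += piece
        mask += piece_mask
    if tokens:
        packed.append((tokens, mask))
    return packed

def masked_mean_loss(token_losses, loss_mask):
    """Average per-token loss over unmasked (answer) positions only."""
    total = sum(loss * m for loss, m in zip(token_losses, loss_mask))
    return total / max(1, sum(loss_mask))

examples = [([1, 2], [3]), ([4], [5, 6]), ([7, 8, 9], [10])]
for tokens, mask in pack_examples(examples, context_len=8, sep_id=0):
    print(tokens, mask)
```

Here the first two examples share one 8-token sequence and the third starts a new one. Prompt tokens carry a mask of 0, so they add nothing to the averaged loss.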
RLHF:
Reward model(s):
novel bit: two separate reward models, one for safety and one for helpfulness
this avoids forcing a single reward model to juggle two often-conflicting objectives
pairwise preference data collected from humans, with 4 degrees of preference strength
Initialized from SFT'ed chat model
Trained with margin-based pairwise ranking loss, margin scaled to preference strength
trained for 1 epoch over the training data; training longer led to overfitting
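The margin-based ranking loss is −log(sigmoid(r(x, y_chosen) − r(x, y_rejected) − m(rating))). Here is a minimal sketch of it. The margin values are my reading of one of the paper's margin settings and should be treated as illustrative. The rating names are my own labels for the 4 preference strengths.

```python
import math

# Illustrative margins per preference strength (an assumption, see lead-in).
MARGIN = {
    "significantly_better": 1.0,
    "better": 2.0 / 3.0,
    "slightly_better": 1.0 / 3.0,
    "negligibly_better": 0.0,
}

def log_sigmoid(z):
    # numerically stable log(sigmoid(z))
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))

def ranking_loss(r_chosen, r_rejected, rating):
    """-log(sigmoid(r_chosen - r_rejected - margin)): the chosen answer must beat
    the rejected one by at least the margin before the loss gets small."""
    return -log_sigmoid(r_chosen - r_rejected - MARGIN[rating])

print(ranking_loss(2.0, 0.0, "significantly_better"))  # larger margin -> larger loss
print(ranking_loss(2.0, 0.0, "slightly_better"))
```

The margin makes clearly preferred pairs demand a bigger score gap than near-ties, which is how the preference strength enters training.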
Then iteratively improve over RLHF versions 1 to 5:
    for version in 1..5:
        if version <= 4:
            RL-Method = rejection sampling
        else:
            RL-Method = rejection sampling, followed by PPO
        fine-tune Llama 2-Chat with RLHF against the reward models, using RL-Method
        sample from the latest Llama 2-Chat and have humans annotate preferences
        fit the reward models on the new preference data
The iterative setup ensures that the reward model is always on-distribution
Reward dataset of ~3M annotations, about half from open-source datasets (which allows bootstrapping)
Reward models are trained with the same optimizer settings, but with a max lr of 5e-6 for 70B and 1e-5 for the others, i.e., 1/4 or 1/2 of the SFT lr.
Reward model accuracy is "one of the most important proxies for the final performance of Llama 2-Chat"
RL-Method:
Rejection sampling:
Only the largest model (70B) is used for sampling, and the smaller models are fine-tuned on its samples, so there's a distillation effect
For each prompt, K answers are sampled, and the best one according to the reward models at the time is kept for training
To avoid catastrophic forgetting, they used a replay buffer holding the best answers from previous chat-model iterations
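Here is a minimal sketch of best-of-K rejection sampling and a replay merge. The `generate` and `reward` callables are hypothetical stand-ins for the chat model and the reward models. Merging the replay buffer by keeping, for each prompt, whichever answer scores higher is my interpretation, not a confirmed detail.

```python
def best_of_k(prompts, generate, reward, k=4):
    """Rejection sampling (best-of-K): draw k answers per prompt from the
    current policy and keep the one the reward model scores highest."""
    kept = []
    for prompt in prompts:
        candidates = [generate(prompt) for _ in range(k)]
        kept.append((prompt, max(candidates, key=lambda a: reward(prompt, a))))
    return kept

def merge_with_replay(new_best, replay, reward):
    """Keep, per prompt, whichever is better: this round's best answer or the
    best answer remembered from earlier iterations (the replay buffer)."""
    merged = dict(replay)
    for prompt, answer in new_best:
        if prompt not in merged or reward(prompt, answer) > reward(prompt, merged[prompt]):
            merged[prompt] = answer
    return merged

# Hypothetical stand-ins: canned "generations" and a reward that reads the digit.
canned = iter(["a1", "a3", "a2", "b0", "b5", "b4"])
kept = best_of_k(["a", "b"], lambda p: next(canned), lambda p, a: int(a[1]), k=3)
print(kept)  # [('a', 'a3'), ('b', 'b5')]
```

The selected answers then become ordinary fine-tuning targets, so this step is closer to filtered SFT than to policy-gradient RL.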
PPO:
The safety and helpfulness rewards are combined by a piece-wise function: below a safety-score threshold, the safety reward is used; above it, the helpfulness reward. They also "whiten" the final linear scores. It's not clear what "whiten" means here, since each data point has a scalar reward rather than a vector; perhaps it's just mean and standard-deviation normalization over the samples?
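Here is a sketch of the reward combination. Three details are my reading of the paper and should be treated as assumptions: the threshold value of 0.15, the extra condition that safety-tagged prompts always use the safety score, and mapping the sigmoid scores back through a logit before whitening. Implementing "whiten" as batch standardization is just the guess from the paragraph above.

```python
import math

def combined_reward(r_safety, r_helpful, is_safety_prompt, threshold=0.15):
    """Piece-wise selection between the two reward models' scores, both assumed
    to be probabilities in (0, 1): use the safety score for safety-tagged
    prompts or when the safety score falls below the threshold, otherwise
    use the helpfulness score."""
    if is_safety_prompt or r_safety < threshold:
        return r_safety
    return r_helpful

def logit(p, eps=1e-6):
    """Inverse of the sigmoid, clamped to avoid infinities at 0 and 1."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))

def whiten(values, eps=1e-8):
    """One plausible reading of "whiten" for scalar rewards: standardize
    them to zero mean and unit variance across the batch."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [(v - mean) / (std + eps) for v in values]

batch = [(0.10, 0.90, False), (0.60, 0.30, False), (0.70, 0.95, True)]
scores = [combined_reward(s, h, tag) for s, h, tag in batch]
print(scores)  # [0.1, 0.3, 0.7]
print(whiten([logit(p) for p in scores]))
```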
Optimization hyperparameters for RL:
Standard AdamW, lr = 1e-6, i.e., 1/5 to 1/10 of the reward model lr
PPO batch size of 512 and mini-batch size of 64, PPO clip threshold of 0.2, KL penalty of 0.01 for 7B and 13B, 0.005 for 34B and 70B
FSDP makes training fast but generation slow, so model weights are consolidated to each node for generation, then freed after generation before training resumes.
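Here is a per-sample sketch of the two PPO ingredients the hyperparameters refer to: the clipped surrogate objective (clip threshold 0.2) and the KL-penalized reward (β = 0.01 or 0.005). Estimating the KL term by the log-probability ratio of the sampled answer is a common choice that I'm assuming here. Batching, advantage estimation and the value function are left out.

```python
import math

def kl_penalized_reward(rm_score, logp_policy, logp_ref, beta=0.01):
    """Reward fed to PPO: reward-model score minus beta times a KL estimate,
    here log pi(y|x) - log pi_ref(y|x) for the sampled answer y."""
    return rm_score - beta * (logp_policy - logp_ref)

def ppo_clip_objective(logp_new, logp_old, advantage, clip=0.2):
    """PPO clipped surrogate for one sample (to be maximized):
    min(ratio * A, clip(ratio, 1 - clip, 1 + clip) * A)."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - clip), 1.0 + clip)
    return min(ratio * advantage, clipped * advantage)

print(ppo_clip_objective(math.log(1.5), 0.0, 1.0))    # 1.2: gain capped at ratio 1.2
print(ppo_clip_objective(math.log(0.5), 0.0, -1.0))   # -0.8: pessimistic bound kept
print(kl_penalized_reward(1.0, -1.0, -2.0, beta=0.01))
```

With a PPO batch of 512 and a mini-batch of 64, each batch is split into 8 mini-batch updates of this objective.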
Ghost Attention (GAtt)
A novel trick to make the model keep respecting the system message across multiple turns; used from RLHF-V3 onward
During generation for RLHF, concatenate the system instruction to the user message of each turn (rather than just the first turn), before generation, so that the generation respects the system instruction. For training, drop the system instructions in all but the first turn.
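Here is a minimal sketch of the two versions of a dialogue that GAtt works with. The function name, the newline used to join the instruction with the user message, and the example instruction are all hypothetical. The sketch only builds the user messages and does not model how the loss is handled on earlier turns.

```python
def gatt_views(instruction, user_turns):
    """Build the two versions of a multi-turn dialogue's user messages.
    - generation view: the instruction is prepended to EVERY user turn, so the
      answers sampled from the model keep following it in later turns;
    - training view: the instruction stays only in the first turn, so the model
      must learn to carry it forward on its own."""
    for_generation = [f"{instruction}\n{turn}" for turn in user_turns]
    for_training = [for_generation[0]] + list(user_turns[1:])
    return for_generation, for_training

gen, train = gatt_views("Always answer in haiku.", ["Hi!", "What is RLHF?"])
print(gen)    # instruction repeated in both turns
print(train)  # instruction only in the first turn
```

The assistant answers sampled with the generation view are paired with the training view, so the model learns to follow an instruction it only saw at the start of the dialogue.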
RLHF results
PPO in RLHF-V5 definitely helps beyond rejection sampling
Human evaluation shows that Llama 2-Chat 70B is competitive with ChatGPT (actually slightly better) and significantly better than other open-source chat models
Safety
The authors put a lot of work into safety measures during data collection, training and analysis; this material takes up almost half of the paper. If there is enough interest, I'll write a separate post about it.