Layer Norm Code
# Assumed imports for the code in this post: torch, einops, and the pattern-first
# einsum from the fancy_einsum package (which matches the einsum calls used below).
import torch as t
import torch.nn as nn
import einops
from fancy_einsum import einsum

class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        # output: [batch, position, d_model]
        # Normalize each position's vector to zero mean and unit variance over d_model
        # (eps added for numerical stability), then apply the learned scale w and translation b
        mean = residual.mean(dim=-1, keepdim=True)
        variance = residual.var(dim=-1, keepdim=True, correction=0) + self.cfg.layer_norm_eps
        residual = (residual - mean) / (variance ** 0.5)
        return residual * self.w + self.b
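As a quick sanity check, here is a minimal sketch of a Config with the fields used throughout this post (the values are illustrative, roughly GPT-2 small; your actual Config may differ), plus a comparison against torch.nn.LayerNorm, which a freshly initialized LayerNorm should match:

from dataclasses import dataclass

@dataclass
class Config:
    d_model: int = 768          # illustrative values only; swap in your own Config
    d_vocab: int = 50257
    n_ctx: int = 1024
    n_heads: int = 12
    d_head: int = 64
    d_mlp: int = 3072
    init_range: float = 0.02
    layer_norm_eps: float = 1e-5

cfg = Config()
ln = LayerNorm(cfg)
x = t.randn(2, 4, cfg.d_model)
# A fresh LayerNorm has w = ones and b = zeros, so it should agree with PyTorch's built-in
print(t.allclose(ln(x), nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)(x), atol=1e-5))  # True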
Embedding
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        # output: [batch, position, d_model]
        # Each token id indexes a row of W_E, i.e. a simple embedding lookup
        return self.W_E[tokens]
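The embedding is just an index lookup: each token id selects a row of W_E. A short usage sketch with the cfg defined above (the token ids are made up for illustration); the same lookup can also be written as a one-hot matrix multiplication:

embed = Embed(cfg)
tokens = t.tensor([[4, 7, 4]])  # hypothetical token ids, shape [batch=1, position=3]
tok_embeds = embed(tokens)
print(tok_embeds.shape)  # torch.Size([1, 3, 768])
# Indexing W_E is equivalent to multiplying by a one-hot encoding of the tokens, just cheaper
one_hot = nn.functional.one_hot(tokens, num_classes=cfg.d_vocab).float()
print(t.allclose(tok_embeds, one_hot @ embed.W_E))  # True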
Positional Embedding
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        # output: [batch, position, d_model]
        # Take the first seq_len rows of W_pos and broadcast them across the batch
        batch, seq_len = tokens.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)
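Note that the positional embedding only looks at the shape of tokens, not their values: every sequence of the same length gets the same positional vectors. A quick check, reusing the cfg sketch from above:

pos_embed = PosEmbed(cfg)
tokens = t.randint(0, cfg.d_vocab, (3, 5))  # arbitrary token ids; only the shape matters here
pos_embeds = pos_embed(tokens)
print(pos_embeds.shape)  # torch.Size([3, 5, 768])
print(t.allclose(pos_embeds[0], pos_embeds[1]))  # True: identical for every sequence in the batch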
Self-Attention
Attention is the trickiest block to implement. We will use einsum with named dimensions so that the tensor shapes stay explicit and the calculations are easier to follow.
class Attention(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.scale = cfg.d_head ** 0.5
        self.softmax = nn.Softmax(dim=-1)
        # Large negative value used to kill attention to masked positions.
        # Registered as a buffer with no hard-coded device, so it moves with the module under .to()/.cuda().
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32))

    def forward(self, normalized_resid_pre: t.Tensor):
        # normalized_resid_pre: [batch, position, d_model]
        # output: [batch, position, d_model]
        # Project the residual stream into per-head query, key and value vectors
        query_mat = einsum("batch position_q d_model, n_heads d_model d_head -> batch position_q n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        key_mat = einsum("batch position_k d_model, n_heads d_model d_head -> batch position_k n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K
        val_mat = einsum("batch position_v d_model, n_heads d_model d_head -> batch position_v n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V
        # Attention scores: dot product of every query with every key, per head
        attn_scores = einsum("batch position_q n_heads d_head, batch position_k n_heads d_head -> batch n_heads position_q position_k", query_mat, key_mat)
        # Scale by sqrt(d_head), mask out future positions, then softmax over the key dimension
        attn_pattern = self.softmax(self.apply_causal_mask(attn_scores / self.scale))
        # Weighted sum of value vectors per query position, then project each head back to the residual stream
        val_mat_res = einsum("batch position_v n_heads d_head, batch n_heads position_q position_v -> batch position_q n_heads d_head", val_mat, attn_pattern)
        attn_out = einsum("batch position_q n_heads d_head, n_heads d_head d_model -> batch position_q d_model", val_mat_res, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores: t.Tensor):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        # output: [batch, n_heads, query_pos, key_pos]
        # The mask is True wherever the key position comes after the query position,
        # i.e. everywhere we want the attention probability to be (effectively) zero
        mask = t.triu(t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        # Fill those positions with a large negative value so the softmax sends them to ~0
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores
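A simple way to sanity-check the causal mask, again using the cfg sketch from above: perturbing a later token should leave the attention output at all earlier positions unchanged, because scores on future key positions are filled with -1e5 and effectively zeroed by the softmax.

attn = Attention(cfg)
ln_pre = LayerNorm(cfg)
x = t.randn(1, 6, cfg.d_model)
y = x.clone()
y[0, -1] += 1.0  # perturb only the final position
out_x = attn(ln_pre(x))
out_y = attn(ln_pre(y))
print(out_x.shape)  # torch.Size([1, 6, 768])
print(t.allclose(out_x[:, :-1], out_y[:, :-1], atol=1e-5))  # True: earlier positions don't see the change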
MLP
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        # output: [batch, position, d_model]
        # Project up to d_mlp, apply the GELU nonlinearity, then project back down to d_model.
        # gelu_new is assumed to be defined/imported elsewhere (GPT-2's approximate GELU).
        ll1 = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        act1 = gelu_new(ll1)
        ll2 = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", act1, self.W_out) + self.b_out
        return ll2
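Unlike attention, the MLP acts on each position independently, with no mixing across the sequence. A small check with the cfg sketch from above; it includes a hedged stand-in for gelu_new (the GPT-2-style tanh approximation) in case you have not defined or imported it yet:

import math

def gelu_new(x):
    # Stand-in only: GPT-2-style approximate GELU; skip this if gelu_new is already defined
    return 0.5 * x * (1.0 + t.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

mlp = MLP(cfg)
x = t.randn(2, 4, cfg.d_model)
out = mlp(x)
print(out.shape)  # torch.Size([2, 4, 768])
# Positions don't interact: running one position on its own matches its slice of the full output
print(t.allclose(out[:, :1], mlp(x[:, :1]), atol=1e-5))  # True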