diff --git a/notebook/decoder-only.svg b/notebook/decoder-only.svg new file mode 100644 index 0000000..e11e94c --- /dev/null +++ b/notebook/decoder-only.svg @@ -0,0 +1,21 @@ + + + + + + + + SoftmaxLinearAdd & NormFeedForwardEmbeddingInput contextPositionalEncodingNxAdd & NormMaskedMulti-HeadAttentionTokenizerSamplingOutput tokenResidualResidual \ No newline at end of file diff --git a/notebook/interactive-transformer.ipynb b/notebook/interactive-transformer.ipynb new file mode 100644 index 0000000..1a6bedf --- /dev/null +++ b/notebook/interactive-transformer.ipynb @@ -0,0 +1,2895 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5aeadcc-f3a0-451f-bdcb-3074df5f2363", + "metadata": {}, + "source": [ + "# The Interactive Transformer" + ] + }, + { + "cell_type": "markdown", + "id": "23bb2ecc-2822-4143-bfb3-a93de5d5b435", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9f1a32da-97bf-4ea4-9c9c-0709a500f382", + "metadata": {}, + "outputs": [], + "source": [ + "// Setup\n", + "const GPT2_EOT int32 = 50256\n", + "const delta = 1e-5" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "150162b5-1556-4663-9a66-9d4d8a28231a", + "metadata": {}, + "outputs": [], + "source": [ + "import \"math\"\n", + "\n", + "func Abs(x float32) float32 {\n", + "\tif x > 0 {\n", + "\t\treturn x\n", + "\t}\n", + "\treturn -x\n", + "}\n", + "\n", + "func Cosh(x float32) float32 {\n", + "\treturn float32(math.Cosh(float64(x)))\n", + "}\n", + "\n", + "func Exp(x float32) float32 {\n", + "\treturn float32(math.Exp(float64(x)))\n", + "}\n", + "\n", + "func Inf(sign int) float32 {\n", + "\treturn float32(math.Inf(sign))\n", + "}\n", + "\n", + "func Log(x float32) float32 {\n", + "\treturn float32(math.Log(float64(x)))\n", + "}\n", + "\n", + "func IsNaN(f float32) bool {\n", + "\treturn math.IsNaN(float64(f))\n", + "}\n", + "\n", + "func Pow(x, y float32) float32 {\n", + "\treturn float32(math.Pow(float64(x), float64(y)))\n", + "}\n", + "\n", + "func Sqrt(x float32) float32 {\n", + "\treturn float32(math.Sqrt(float64(x)))\n", + "}\n", + "\n", + "func Tanh(x float32) float32 {\n", + "\treturn float32(math.Tanh(float64(x)))\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "af8bb417-4277-44c2-b58d-b0c782d0af57", + "metadata": {}, + "source": [ + "# Data loading" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a3f465c6-a248-443b-a1d8-f1160e35446d", + "metadata": {}, + "outputs": [], + "source": [ + "import (\n", + "\t\"bytes\"\n", + "\t\"encoding/binary\"\n", + "\t\"errors\"\n", + "\t\"io\"\n", + ")\n", + "\n", + "const Int32ByteLen = 4\n", + "\n", + "type DataLoader struct {\n", + "\tfilename string\n", + "\tbatchSize int\n", + "\tseqLength int\n", + "\tcurrentPosition int64\n", + "\tfileSize int64\n", + "\tNumBatches int\n", + "\tdata []int32\n", + "\tdataAll []int32\n", + "}\n", + "\n", + "func NewDataLoader(filename string, batchSize, seqLength int) (*DataLoader, error) {\n", + "\tfile, err := os.Open(filename)\n", + "\tif err != nil {\n", + "\t\treturn nil, err\n", + "\t}\n", + "\treturn newDataLoader(file, batchSize, seqLength)\n", + "}\n", + "\n", + "func newDataLoader(file io.Reader, batchSize, seqLength int) (*DataLoader, error) {\n", + "\tdata, err := io.ReadAll(file)\n", + "\tif err != nil {\n", + "\t\treturn nil, err\n", + "\t}\n", + "\tsize := len(data)\n", + "\tif size < (batchSize*seqLength+1)*Int32ByteLen {\n", + "\t\treturn nil, errors.New(\"error: file size is too small for the batch size and sequence length\")\n", + "\t}\n", + "\tloader := &DataLoader{\n", + "\t\tbatchSize: batchSize,\n", + "\t\tseqLength: seqLength,\n", + "\t\tNumBatches: size / (batchSize * seqLength * Int32ByteLen),\n", + "\t\tdata: make([]int32, size/Int32ByteLen),\n", + "\t\tfileSize: int64(size / Int32ByteLen),\n", + "\t}\n", + "\tif err := binary.Read(bytes.NewReader(data), binary.LittleEndian, loader.data); err != nil {\n", + "\t\treturn nil, err\n", + "\t}\n", + "\treturn loader, nil\n", + "}\n", + "\n", + "func newDataLoaderFromInts(data []int32, batchSize, seqLength int) (*DataLoader, error) {\n", + "\tsize := len(data)\n", + "\tif size < (batchSize*seqLength + 1) {\n", + "\t\treturn nil, errors.New(\"error: file size is too small for the batch size and sequence length\")\n", + "\t}\n", + "\tloader := &DataLoader{\n", + "\t\tbatchSize: batchSize,\n", + "\t\tseqLength: seqLength,\n", + "\t\tNumBatches: size / (batchSize * seqLength),\n", + "\t\tdata: data,\n", + "\t\tfileSize: int64(size),\n", + "\t}\n", + "\treturn loader, nil\n", + "}\n", + "\n", + "func (loader *DataLoader) Reset() {\n", + "\tloader.currentPosition = 0\n", + "}\n", + "\n", + "func (loader *DataLoader) NextBatch() ([]int32, []int32, error) {\n", + "\tnextPos := loader.currentPosition + int64(loader.batchSize*loader.seqLength)\n", + "\tif nextPos+1 > loader.fileSize {\n", + "\t\tloader.Reset()\n", + "\t\tnextPos = loader.currentPosition + int64(loader.batchSize*loader.seqLength)\n", + "\t}\n", + "\t// don't x4 because we're indexing int32 not byte\n", + "\tinputs := loader.data[loader.currentPosition:nextPos]\n", + "\ttargets := loader.data[loader.currentPosition+1 : nextPos+1]\n", + "\tloader.currentPosition = nextPos\n", + "\treturn inputs, targets, nil\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7ad42eaa-817f-420d-a215-da5546f12879", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestDataLoader_NextBatch\n", + "=== RUN TestDataLoader_NextBatch/1char\n", + "=== RUN TestDataLoader_NextBatch/endOfFile\n", + "=== RUN TestDataLoader_NextBatch/seqLen4\n", + "=== RUN TestDataLoader_NextBatch/seqLen!=batchSize\n", + "--- PASS: TestDataLoader_NextBatch (0.00s)\n", + " --- PASS: TestDataLoader_NextBatch/1char (0.00s)\n", + " --- PASS: TestDataLoader_NextBatch/endOfFile (0.00s)\n", + " --- PASS: TestDataLoader_NextBatch/seqLen4 (0.00s)\n", + " --- PASS: TestDataLoader_NextBatch/seqLen!=batchSize (0.00s)\n", + "PASS\n" + ] + } + ], + "source": [ + "%test\n", + "func TestDataLoader_NextBatch(t *testing.T) {\n", + "\tzeroTo100 := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99}\n", + "\ttype want struct {\n", + "\t\treset bool\n", + "\t\tinput []int32\n", + "\t\ttarget []int32\n", + "\t\tcurrentPosition int64\n", + "\t}\n", + "\ttests := []struct {\n", + "\t\tname string\n", + "\t\tcontents []int32\n", + "\t\tfilename string\n", + "\t\tbatchSize, seqLen int\n", + "\t\twant []want\n", + "\t\twantNumBatches int\n", + "\t}{\n", + "\t\t{\n", + "\t\t\tname: \"1char\",\n", + "\t\t\tcontents: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},\n", + "\t\t\tbatchSize: 1,\n", + "\t\t\tseqLen: 1,\n", + "\t\t\twantNumBatches: 10,\n", + "\t\t\twant: []want{\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{0},\n", + "\t\t\t\t\ttarget: []int32{1},\n", + "\t\t\t\t\tcurrentPosition: 1,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{1},\n", + "\t\t\t\t\ttarget: []int32{2},\n", + "\t\t\t\t\tcurrentPosition: 2,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\treset: true,\n", + "\t\t\t\t\tinput: []int32{0},\n", + "\t\t\t\t\ttarget: []int32{1},\n", + "\t\t\t\t\tcurrentPosition: 1,\n", + "\t\t\t\t},\n", + "\t\t\t},\n", + "\t\t},\n", + "\t\t{\n", + "\t\t\tname: \"endOfFile\",\n", + "\t\t\tcontents: []int32{0, 1, 2},\n", + "\t\t\tbatchSize: 1,\n", + "\t\t\tseqLen: 1,\n", + "\t\t\twantNumBatches: 3,\n", + "\t\t\twant: []want{\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{0},\n", + "\t\t\t\t\ttarget: []int32{1},\n", + "\t\t\t\t\tcurrentPosition: 1,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{1},\n", + "\t\t\t\t\ttarget: []int32{2},\n", + "\t\t\t\t\tcurrentPosition: 2,\n", + "\t\t\t\t},\n", + "\t\t\t\t{ // should loop back\n", + "\t\t\t\t\tinput: []int32{0},\n", + "\t\t\t\t\ttarget: []int32{1},\n", + "\t\t\t\t\tcurrentPosition: 1,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\treset: true,\n", + "\t\t\t\t\tinput: []int32{0},\n", + "\t\t\t\t\ttarget: []int32{1},\n", + "\t\t\t\t\tcurrentPosition: 1,\n", + "\t\t\t\t},\n", + "\t\t\t},\n", + "\t\t},\n", + "\t\t{\n", + "\t\t\tname: \"seqLen4\",\n", + "\t\t\tcontents: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},\n", + "\t\t\tbatchSize: 1,\n", + "\t\t\tseqLen: 4,\n", + "\t\t\twantNumBatches: 2,\n", + "\t\t\twant: []want{\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{0, 1, 2, 3},\n", + "\t\t\t\t\ttarget: []int32{1, 2, 3, 4},\n", + "\t\t\t\t\tcurrentPosition: 4,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{4, 5, 6, 7},\n", + "\t\t\t\t\ttarget: []int32{5, 6, 7, 8},\n", + "\t\t\t\t\tcurrentPosition: 8,\n", + "\t\t\t\t},\n", + "\t\t\t},\n", + "\t\t},\n", + "\t\t{\n", + "\t\t\tname: \"seqLen!=batchSize\",\n", + "\t\t\tcontents: zeroTo100,\n", + "\t\t\tbatchSize: 2,\n", + "\t\t\tseqLen: 4,\n", + "\t\t\twantNumBatches: 12,\n", + "\t\t\twant: []want{\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{0, 1, 2, 3, 4, 5, 6, 7},\n", + "\t\t\t\t\ttarget: []int32{1, 2, 3, 4, 5, 6, 7, 8},\n", + "\t\t\t\t\tcurrentPosition: 8,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\tinput: []int32{8, 9, 10, 11, 12, 13, 14, 15},\n", + "\t\t\t\t\ttarget: []int32{9, 10, 11, 12, 13, 14, 15, 16},\n", + "\t\t\t\t\tcurrentPosition: 16,\n", + "\t\t\t\t},\n", + "\t\t\t\t{\n", + "\t\t\t\t\treset: true,\n", + "\t\t\t\t\tinput: []int32{0, 1, 2, 3, 4, 5, 6, 7},\n", + "\t\t\t\t\ttarget: []int32{1, 2, 3, 4, 5, 6, 7, 8},\n", + "\t\t\t\t\tcurrentPosition: 8,\n", + "\t\t\t\t},\n", + "\t\t\t},\n", + "\t\t},\n", + "\t}\n", + "\tnewInt32Reader := func(data []int32) (io.Reader, int) {\n", + "\t\tvar b bytes.Buffer\n", + "\t\trequire.NoError(t, binary.Write(&b, binary.LittleEndian, data))\n", + "\t\treturn &b, b.Len()\n", + "\t}\n", + "\tfor _, tt := range tests {\n", + "\t\tt.Run(tt.name, func(t *testing.T) {\n", + "\t\t\treader, _ := newInt32Reader(tt.contents)\n", + "\t\t\tif tt.filename != \"\" {\n", + "\t\t\t\t_, err := os.Stat(tt.filename)\n", + "\t\t\t\tassert.NoError(t, err)\n", + "\t\t\t\tfile, err := os.Open(tt.filename)\n", + "\t\t\t\tassert.NoError(t, err)\n", + "\t\t\t\tdefer file.Close()\n", + "\t\t\t\treader = file\n", + "\t\t\t}\n", + "\t\t\tloader, err := newDataLoader(reader, tt.batchSize, tt.seqLen)\n", + "\t\t\tassert.NoError(t, err)\n", + "\t\t\tassert.Equal(t, tt.wantNumBatches, loader.NumBatches)\n", + "\t\t\tfor _, want := range tt.want {\n", + "\t\t\t\tif want.reset {\n", + "\t\t\t\t\tloader.Reset()\n", + "\t\t\t\t}\n", + "\t\t\t\tinput, target, err := loader.NextBatch()\n", + "\t\t\t\tassert.NoError(t, err)\n", + "\t\t\t\tassert.Equal(t, want.input, input)\n", + "\t\t\t\tassert.Equal(t, want.target, target)\n", + "\t\t\t\tassert.Equal(t, want.currentPosition, loader.currentPosition)\n", + "\t\t\t}\n", + "\t\t})\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "975f58c5-a793-412a-aca6-1bc7579d300b", + "metadata": {}, + "source": [ + "# Tensors\n", + "\n", + "What is a tensor?\n", + "A tensor is a multi-dimensional array. A regular slice is one-dimensional, holding elements in a sequence. A tensor can have multiple dimensions, like a 2D array (grid) or even a 3D array (cube).\n", + "\n", + "[Computerphile video](https://www.youtube.com/watch?v=DfK83xEtJ_k)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "307f09fa-b67f-4e10-a8b0-ebdc401ad45e", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "type tensor struct {\n", + "\tdata []float32\n", + "\tdims []int\n", + "}\n", + "\n", + "// TODO: make this better\n", + "func (t tensor) Data() []float32 {\n", + "\treturn t.data\n", + "}\n", + "\n", + "func newTensor(data []float32, dims ...int) (tensor, int) {\n", + "\ts := 1\n", + "\tfor _, d := range dims {\n", + "\t\ts *= d\n", + "\t}\n", + "\tif s > len(data) {\n", + "\t\tpanic(\"dimensions larger than supplied data\")\n", + "\t}\n", + "\tss := min(s, len(data))\n", + "\treturn tensor{\n", + "\t\tdata: data[:ss],\n", + "\t\tdims: dims,\n", + "\t}, ss\n", + "}\n", + "\n", + "func (t tensor) size() int {\n", + "\tsize := 1\n", + "\tfor _, dim := range t.dims {\n", + "\t\tsize *= dim\n", + "\t}\n", + "\treturn size\n", + "}\n", + "\n", + "func (t tensor) index(idx ...int) tensor {\n", + "\t// 1. Error Handling (Partially Adjusted)\n", + "\tif len(idx) > len(t.dims) {\n", + "\t\tpanic(\"Too many indices for tensor dimensions\")\n", + "\t}\n", + "\tfor i, dim := range idx {\n", + "\t\tif dim < 0 || dim >= t.dims[i] {\n", + "\t\t\tpanic(\"Index out of bounds\")\n", + "\t\t}\n", + "\t}\n", + "\t// 2. Calculate Linear Index (Partially Adjusted)\n", + "\tlinearIndex := idx[0]\n", + "\tstride := t.size()\n", + "\tfor i := 1; i < len(idx); i++ {\n", + "\t\tstride /= t.dims[i]\n", + "\t\tlinearIndex += idx[i] * stride\n", + "\t}\n", + "\t// 3. Adjust Dimensions and Return Sub-Tensor\n", + "\tnewDims := t.dims[len(idx):] // Keep remaining dimensions\n", + "\tend := linearIndex + t.subTensorSize(newDims) // Size based on remaining dimensions\n", + "\n", + "\treturn tensor{\n", + "\t\tdata: t.data[linearIndex:end],\n", + "\t\tdims: newDims,\n", + "\t}\n", + "}\n", + "\n", + "// Helper function to calculate the size of a sub-tensor\n", + "func (t tensor) subTensorSize(idx []int) int {\n", + "\tsubTensorSize := 1\n", + "\tfor _, dim := range t.dims[len(idx):] {\n", + "\t\tsubTensorSize *= dim\n", + "\t}\n", + "\treturn subTensorSize\n", + "}\n", + "\n", + "// ParameterTensors are the parameters of the model\n", + "type ParameterTensors struct {\n", + "\tMemory []float32\n", + "\tWordTokEmbed tensor // (V, C) - Word/Token Embedding weights (Vocabulary size, Embedding dimension)\n", + "\tWordPosEmbed tensor // (maxT, C) - Positional Embedding weights (Maximum Sequence length, Embedding dimension)\n", + "\tLayerNorm1W tensor // (L, C) - Weights for Layer Normalization 1 (Number of layers, Embedding dimension)\n", + "\tLayerNorm1B tensor // (L, C) - Biases for Layer Normalization 1\n", + "\tQueryKeyValW tensor // (L, 3*C, C) - Attention QKV weights (Layers, 3 * Embedding dimension, Embedding dimension)\n", + "\tQueryKeyValB tensor // (L, 3*C) - Attention QKV biases\n", + "\tAttProjW tensor // (L, C, C) - Attention projection weights (Layers, Embedding dimension, Embedding dimension)\n", + "\tAttProjB tensor // (L, C) - Attention projection biases\n", + "\tLayer2NormW tensor // (L, C) - Weights for Layer Normalization 2\n", + "\tLayer2NormB tensor // (L, C) - Biases for Layer Normalization 2\n", + "\tFeedFwdW tensor // (L, 4*C, C) - Feed-forward layer weights (Layers, 4 * Embedding Dimension, Embedding Dimension)\n", + "\tFeedFwdB tensor // (L, 4*C) - Feed-forward layer biases\n", + "\tFeedFwdProjW tensor // (L, C, 4*C) - Feed-forward projection weights\n", + "\tFeedFwdProjB tensor // (L, C)- Feed-forward projection biases\n", + "\tLayerFinNormW tensor // (C) - Final layer normalization weights\n", + "\tLayerFinNormB tensor // (C) - Final layer normalization biases\n", + "}\n", + "\n", + "func newParameterTensors(V, C, maxSeqLen, L int) ParameterTensors {\n", + "\tvar tensor ParameterTensors\n", + "\ttensor.Init(V, C, maxSeqLen, L)\n", + "\treturn tensor\n", + "}\n", + "\n", + "func (tensor *ParameterTensors) Len() int {\n", + "\treturn len(tensor.Memory)\n", + "}\n", + "\n", + "// Init initialises the ParameterTensors with specific sizes for each tensor based on the model architecture.\n", + "func (tensor *ParameterTensors) Init(V, C, maxSeqLen, L int) {\n", + "\ttensor.Memory = make([]float32,\n", + "\t\tV*C+ // WordTokEmbed\n", + "\t\t\tmaxSeqLen*C+ // WordPosEmbed\n", + "\t\t\tL*C+ // LayerNorm1W\n", + "\t\t\tL*C+ // LayerNorm1B\n", + "\t\t\tL*3*C*C+ // QueryKeyValW\n", + "\t\t\tL*3*C+ // QueryKeyValB\n", + "\t\t\tL*C*C+ // AttProjW\n", + "\t\t\tL*C+ // AttProjB\n", + "\t\t\tL*C+ // Layer2NormW\n", + "\t\t\tL*C+ // Layer2NormB\n", + "\t\t\tL*4*C*C+ // FeedFwdW\n", + "\t\t\tL*4*C+ // FeedFwdB\n", + "\t\t\tL*C*4*C+ // FeedFwdProjW\n", + "\t\t\tL*C+ // FeedFwdProjB\n", + "\t\t\tC+ // LayerFinNormW\n", + "\t\t\tC, // LayerFinNormB\n", + "\t)\n", + "\tvar ptr int\n", + "\tmemPtr := tensor.Memory\n", + "\ttensor.WordTokEmbed, ptr = newTensor(memPtr, V, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.WordPosEmbed, ptr = newTensor(memPtr, maxSeqLen, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm1W, ptr = newTensor(memPtr, L, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm1B, ptr = newTensor(memPtr, L, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.QueryKeyValW, ptr = newTensor(memPtr, L, 3*C, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.QueryKeyValB, ptr = newTensor(memPtr, L, 3*C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.AttProjW, ptr = newTensor(memPtr, L, C, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.AttProjB, ptr = newTensor(memPtr, L, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Layer2NormW, ptr = newTensor(memPtr, L, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Layer2NormB, ptr = newTensor(memPtr, L, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedFwdW, ptr = newTensor(memPtr, L, 4*C, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedFwdB, ptr = newTensor(memPtr, L, 4*C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedFwdProjW, ptr = newTensor(memPtr, L, C, 4*C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedFwdProjB, ptr = newTensor(memPtr, L, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerFinNormW, ptr = newTensor(memPtr, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerFinNormB, ptr = newTensor(memPtr, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\tif len(memPtr) != 0 {\n", + "\t\tpanic(\"something went real bad here\")\n", + "\t}\n", + "}\n", + "\n", + "// ActivationTensors\n", + "type ActivationTensors struct {\n", + "\tMemory []float32\n", + "\tEncoded tensor // (B, T, C) - Initial encoded input representations (Batch size, Sequence length, Embedding dimension)\n", + "\tLayer1Act tensor // (L, B, T, C) - Activations after Layer Normalization 1\n", + "\tLayerNorm1Mean tensor // (L, B, T) - Mean values for Layer Normalization 1\n", + "\tLayerNorm1Rstd tensor // (L, B, T) - Reciprocal of standard deviation for Layer Normalization 1\n", + "\tQueryKeyVal tensor // (L, B, T, 3*C) - Combined Query, Key, Value representations for attention\n", + "\tAttentionInter tensor // (L, B, T, C) - Intermediate attention-like result\n", + "\tPreAttention tensor // (L, B, NH, T, T) - Pre-attention scores\n", + "\tAttention tensor // (L, B, NH, T, T) - Normalized attention weights (Number of layers, Batch size, Number of Attention Heads, Sequence length, Sequence length)\n", + "\tAttentionProj tensor // (L, B, T, C) - Projected attention outputs\n", + "\tResidual2 tensor // (L, B, T, C) - Residual connection after attention\n", + "\tLayerNorm2Act tensor // (L, B, T, C) - Activations after Layer Normalization 2\n", + "\tLayerNorm2Mean tensor // (L, B, T) - Mean values for Layer Normalization 2\n", + "\tLayerNorm2Rstd tensor // (L, B, T) - Reciprocal of standard deviation for Layer Normalization 2\n", + "\tFeedForward tensor // (L, B, T, 4*C) - Intermediate Feed-Forward Network activations\n", + "\tFeedForwardGelu tensor // (L, B, T, 4*C) - FeedForward activations after applying GELU (non-linearity)\n", + "\tFeedForwardProj tensor // (L, B, T, C) - Projected output of the Feed-Forward Network\n", + "\tResidual3 tensor // (L, B, T, C) - Residual connection after Feed-Forward Network\n", + "\tLayerNormFinal tensor // (B, T, C) - Final activations after Layer Normalization\n", + "\tLayerNormFinalMean tensor // (B, T) - Mean values for final Layer Normalization\n", + "\tLayerNormFinalStd tensor // (B, T) - Reciprocal of standard deviation for final Layer Normalization\n", + "\tLogits tensor // (B, T, V) - Raw output scores (before softmax)\n", + "\tProbabilities tensor // (B, T, V) - Softmax probabilities over the vocabulary\n", + "\tLosses tensor // (B, T) - Loss values per token in the batch\n", + "}\n", + "\n", + "func (tensor *ActivationTensors) Init(B, C, T, L, NH, V int) {\n", + "\ttensor.Memory = make([]float32,\n", + "\t\tB*T*C+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tL*B*T+\n", + "\t\t\tL*B*T+\n", + "\t\t\tL*B*T*C*3+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tL*B*NH*T*T+\n", + "\t\t\tL*B*NH*T*T+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tL*B*T+\n", + "\t\t\tL*B*T+\n", + "\t\t\tL*B*T*C*4+\n", + "\t\t\tL*B*T*C*4+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tL*B*T*C+\n", + "\t\t\tB*T*C+\n", + "\t\t\tB*T+\n", + "\t\t\tB*T+\n", + "\t\t\tB*T*V+\n", + "\t\t\tB*T*V+\n", + "\t\t\tB*T)\n", + "\tvar ptr int\n", + "\tmemPtr := tensor.Memory\n", + "\ttensor.Encoded, ptr = newTensor(memPtr, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Layer1Act, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm1Mean, ptr = newTensor(memPtr, L, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm1Rstd, ptr = newTensor(memPtr, L, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.QueryKeyVal, ptr = newTensor(memPtr, L, B, T, C*3)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.AttentionInter, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.PreAttention, ptr = newTensor(memPtr, L, B, NH, T, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Attention, ptr = newTensor(memPtr, L, B, NH, T, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.AttentionProj, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Residual2, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm2Act, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm2Mean, ptr = newTensor(memPtr, L, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNorm2Rstd, ptr = newTensor(memPtr, L, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedForward, ptr = newTensor(memPtr, L, B, T, C*4)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedForwardGelu, ptr = newTensor(memPtr, L, B, T, C*4)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.FeedForwardProj, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Residual3, ptr = newTensor(memPtr, L, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNormFinal, ptr = newTensor(memPtr, B, T, C)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNormFinalMean, ptr = newTensor(memPtr, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.LayerNormFinalStd, ptr = newTensor(memPtr, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Logits, ptr = newTensor(memPtr, B, T, V)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Probabilities, ptr = newTensor(memPtr, B, T, V)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\ttensor.Losses, ptr = newTensor(memPtr, B, T)\n", + "\tmemPtr = memPtr[ptr:]\n", + "\tif len(memPtr) != 0 {\n", + "\t\tpanic(\"something went real bad here\")\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "4d0ed719-c21a-4046-8c63-464818414fc6", + "metadata": {}, + "source": [ + "# Table of contents\n", + "- Tokenization\n", + "- Architecture\n", + "- Training\n", + " - Forward pass\n", + " - Backward pass" + ] + }, + { + "cell_type": "markdown", + "id": "d86a9c14-c28d-4dc4-b69e-7d316ecf9b7c", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Tokenization" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f6e52b05-a25c-403f-a870-0bb7363da62e", + "metadata": {}, + "outputs": [], + "source": [ + "import (\n", + "\t\"encoding/binary\"\n", + "\t\"errors\"\n", + " \"os\"\n", + ")\n", + "\n", + "type Tokenizer struct {\n", + "\tvocabSize uint32\n", + "\ttokenTable []string\n", + "\ttokenMap map[string]int32\n", + "\tinit bool\n", + "}\n", + "\n", + "func newTokenizer(vocab []string) Tokenizer {\n", + "\ttokenizer := Tokenizer{\n", + "\t\tvocabSize: uint32(len(vocab)),\n", + "\t\ttokenTable: vocab,\n", + "\t\ttokenMap: make(map[string]int32),\n", + "\t\tinit: true,\n", + "\t}\n", + "\tfor i, token := range vocab {\n", + "\t\ttokenizer.tokenMap[token] = int32(i)\n", + "\t}\n", + "\treturn tokenizer\n", + "}\n", + "\n", + "func NewTokenizer(filename string) (Tokenizer, error) {\n", + "\tf, err := os.Open(filename)\n", + "\tif err != nil {\n", + "\t\treturn Tokenizer{}, err\n", + "\t}\n", + "\tdefer f.Close()\n", + "\theader := make([]uint32, 256)\n", + "\tif err := binary.Read(f, binary.LittleEndian, header); err != nil {\n", + "\t\treturn Tokenizer{}, err\n", + "\t}\n", + "\tif header[0] != 20240328 || header[1] != 1 {\n", + "\t\treturn Tokenizer{}, errors.New(\"incorrect header for tokenizer\")\n", + "\t}\n", + "\ttok := Tokenizer{\n", + "\t\tvocabSize: header[2],\n", + "\t\ttokenTable: make([]string, header[2]),\n", + "\t\ttokenMap: make(map[string]int32),\n", + "\t\tinit: true,\n", + "\t}\n", + "\tvar length byte\n", + "\tfor i := range tok.tokenTable {\n", + "\t\tif err := binary.Read(f, binary.LittleEndian, &length); err != nil {\n", + "\t\t\treturn tok, err\n", + "\t\t}\n", + "\t\tif length <= 0 {\n", + "\t\t\treturn tok, errors.New(\"tokenizer failure\")\n", + "\t\t}\n", + "\t\ttokenBytes := make([]byte, length)\n", + "\t\tif err := binary.Read(f, binary.LittleEndian, tokenBytes); err != nil {\n", + "\t\t\treturn tok, err\n", + "\t\t}\n", + "\t\ttok.tokenTable[i] = string(tokenBytes)\n", + "\t\ttok.tokenMap[tok.tokenTable[i]] = int32(i)\n", + "\t}\n", + "\treturn tok, nil\n", + "}\n", + "\n", + "func (t Tokenizer) Decode(tokens []int32) (string, error) {\n", + "\ts := \"\"\n", + "\tfor _, token := range tokens {\n", + "\t\tif token >= int32(len(t.tokenTable)) {\n", + "\t\t\treturn \"\", errors.New(\"not valid token\")\n", + "\t\t}\n", + "\t\tif token != GPT2_EOT {\n", + "\t\t\ts += t.tokenTable[token]\n", + "\t\t}\n", + "\t}\n", + "\treturn s, nil\n", + "}\n", + "\n", + "func (t Tokenizer) Encode(text string) ([]int32, error) {\n", + "\ttokens := []int32{}\n", + "\tfor len(text) > 0 {\n", + "\t\tlongestMatch := \"\"\n", + "\t\tlongestMatchToken := int32(GPT2_EOT)\n", + "\t\tfor i := len(text); i > 0; i-- {\n", + "\t\t\tsubStr := text[:i]\n", + "\t\t\tif token, exists := t.tokenMap[subStr]; exists {\n", + "\t\t\t\tlongestMatch = subStr\n", + "\t\t\t\tlongestMatchToken = token\n", + "\t\t\t\tbreak\n", + "\t\t\t}\n", + "\t\t}\n", + "\t\tif longestMatch == \"\" {\n", + "\t\t\t// If no match found, treat the first character as an unknown token\n", + "\t\t\ttokens = append(tokens, GPT2_EOT)\n", + "\t\t\ttext = text[1:]\n", + "\t\t} else {\n", + "\t\t\ttokens = append(tokens, longestMatchToken)\n", + "\t\t\ttext = text[len(longestMatch):]\n", + "\t\t}\n", + "\t}\n", + "\treturn tokens, nil\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f79b8806-0380-4925-a6b6-54415a449dfa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestTokenizer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "000000000000000000000000000000\n", + "0000000000000000, 00000000, 000000, " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[25645 8269 10535]\n", + "--- PASS: TestTokenizer (0.08s)\n", + "PASS\n" + ] + } + ], + "source": [ + "%test\n", + "func TestTokenizer(t *testing.T) {\n", + "\ttext := \"000000000000000000000000000000\"\n", + "\tprintln(text)\n", + "\ttokenizer, err := NewTokenizer(\"./gpt2_tokenizer.bin\")\n", + "\tassert.NoError(t, err)\n", + "\tencoded, err := tokenizer.Encode(text)\n", + "\tfmt.Println(encoded)\n", + "\tfor _, tok := range encoded {\n", + "\t\tdecoded, err := tokenizer.Decode([]int32{tok})\n", + "\t\tassert.NoError(t, err)\n", + "\t\tprint(decoded)\n", + "\t\tprint(\", \")\n", + "\t}\n", + "\tassert.NoError(t, err)\n", + "\tdecoded, err := tokenizer.Decode(encoded)\n", + "\tassert.NoError(t, err)\n", + "\tassert.Equal(t, text, decoded)\n", + "\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "168a356a-a504-4bbd-95d9-5df1daef5e8b", + "metadata": {}, + "source": [ + "# Architecture" + ] + }, + { + "cell_type": "markdown", + "id": "eaa9b73f-cdcf-457f-b02f-24f6b17c6bed", + "metadata": {}, + "source": [ + "\"decoder-architecture\"" + ] + }, + { + "cell_type": "markdown", + "id": "70117143-70e8-47bf-b3c5-71e70c6a2eb9", + "metadata": {}, + "source": [ + "- [Tokenizer](#tokenizer)\n", + "- [Embedding](#embedding)\n", + "- [Masked Multi-Head Attention](#masked-multi-head-attention)\n", + "- [Add and Norm](#add-norm)\n", + "- [Feed Forward](#feed-forward)\n", + "- [Linear](#linear)\n", + "- [Softmax](#softmax)\n", + "- [Sampling](#sampling)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1d9fa878-667c-49d7-b8b6-b7589d53b4bf", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "973f678c-3961-44dd-b70f-2eafbbfdc51e", + "metadata": {}, + "outputs": [], + "source": [ + "type GPT2Config struct {\n", + "\tMaxSeqLen int `json:\"max_seq_len\"`\n", + "\tV int `json:\"vocab_size\"`\n", + "\tL int `json:\"num_layers\"`\n", + "\tNH int `json:\"num_heads\"`\n", + "\tC int `json:\"channels\"`\n", + "\tEOT int32\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "0c705437-6396-4a96-8004-9894541187c0", + "metadata": {}, + "outputs": [], + "source": [ + "type GPT2 struct {\n", + "\tTokenizer Tokenizer\n", + "\tConfig GPT2Config // Hyper-parameters of the model\n", + "\t// Params has the actual weights of the model. Params.Memory is for convenience to be able to set/reset parameters simply\n", + "\tParams ParameterTensors // Weights of the model\n", + "\t// Grads contains the delta/gradient that will eventually be applied to the params in the model\n", + "\tGrads ParameterTensors // Gradients of the weights\n", + "\t// Fields for AdamW optimizer\n", + "\tMMemory []float32 // First moment estimates (for AdamW)\n", + "\tVMemory []float32 // Second moment estimates (for AdamW)\n", + "\tActs ActivationTensors // Activations of the model\n", + "\t// gradients of the activations\n", + "\tGradsActs ActivationTensors\n", + "\tB int // Current batch size (B)\n", + "\tT int // Current sequence length (T)\n", + "\tInputs []int32 // Input tokens\n", + "\tTargets []int32 // Target tokens\n", + "\tMeanLoss float32 // Mean loss after a forward pass\n", + "\tRand *rand.Rand\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "a6135550-0a1a-4d92-94ae-d807fffd9f3f", + "metadata": {}, + "outputs": [], + "source": [ + "func loadFromReader(f io.Reader) (*GPT2, error) {\n", + "\theader := make([]int32, 256)\n", + "\terr := binary.Read(f, binary.LittleEndian, header)\n", + "\tif err != nil {\n", + "\t\treturn nil, fmt.Errorf(\"error reading model header: %v\", err)\n", + "\t}\n", + "\tif header[0] != 20240326 || header[1] != 1 {\n", + "\t\treturn nil, fmt.Errorf(\"bad model file format\")\n", + "\t}\n", + "\tmodel := &GPT2{\n", + "\t\tConfig: GPT2Config{\n", + "\t\t\tMaxSeqLen: int(header[2]),\n", + "\t\t\tV: int(header[3]),\n", + "\t\t\tL: int(header[4]),\n", + "\t\t\tNH: int(header[5]),\n", + "\t\t\tC: int(header[6]),\n", + "\t\t\tEOT: GPT2_EOT,\n", + "\t\t},\n", + "\t\tRand: rand.New(rand.NewSource(21)),\n", + "\t}\n", + "\tmodel.Params.Init(model.Config.V, model.Config.C, model.Config.MaxSeqLen, model.Config.L)\n", + "\tif err := binary.Read(f, binary.LittleEndian, model.Params.Memory); err != nil {\n", + "\t\treturn nil, fmt.Errorf(\"error reading model: %v\", err)\n", + "\t}\n", + "\treturn model, nil\n", + "}\n", + "// LoadGPT2Model loads the GPT-2 model from a checkpoint file.\n", + "func LoadGPT2Model(checkpointPath, tokenizerFile string) (*GPT2, error) {\n", + "\t// File Reading\n", + "\tf, err := os.Open(checkpointPath)\n", + "\tif err != nil {\n", + "\t\treturn nil, fmt.Errorf(\"Error opening model file: %v\", err)\n", + "\t}\n", + "\tdefer f.Close()\n", + "\t// Read Model Header\n", + "\tmodel, err := loadFromReader(f)\n", + "\tif err != nil {\n", + "\t\treturn nil, err\n", + "\t}\n", + "\tif tokenizerFile == \"\" {\n", + "\t\treturn model, err\n", + "\t}\n", + "\ttok, err := NewTokenizer(tokenizerFile)\n", + "\tif err != nil {\n", + "\t\treturn nil, err\n", + "\t}\n", + "\tmodel.Tokenizer = tok\n", + "\treturn model, nil\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "def9dfc8-906a-4785-96f6-b41fd97dc1cd", + "metadata": {}, + "source": [ + "# Encoder forward" + ] + }, + { + "cell_type": "markdown", + "id": "44322cf3-2af6-4a44-846e-fd85f6564912", + "metadata": {}, + "source": [ + "encoderForward iterates through the batch/sequence and combines the word token embeddings\n", + "with the word position embeddings. This allows out vector to encode tokens and positions in one." + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "098e1e8c-fd88-469c-aa63-30c7b31377f5", + "metadata": {}, + "outputs": [], + "source": [ + "func encoderForward(out []float32, inp []int32, wte []float32, wpe []float32, B, T, C int) {\n", + "\t// Iterate over each batch\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\t// Iterate over each time step in the sequence\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\t// Calculate the index in the output slice. Each vector is C elements long.\n", + "\t\t\tstartOutIndex := b*T*C + t*C\n", + "\t\t\t// Calculate the token ID index in the input\n", + "\t\t\t// inp is the tokenized input, each number in inp char is an index within wte (word token embeddings)\n", + "\t\t\tix := inp[b*T+t]\n", + "\t\t\t// Calculate the index in the token embeddings slice\n", + "\t\t\t// inp -> id -> wte[id]\n", + "\t\t\tstartWteIndex := ix * int32(C)\n", + "\t\t\t// Calculate the index in the position embeddings slice\n", + "\t\t\t// Wpe starts at 0 (when t is zero) which is basically mapping directly to index\n", + "\t\t\tstartWpeIndex := t * C\n", + "\t\t\t// Add the vectors from `wte` and `wpe` and store the result in `out`\n", + "\t\t\t// here we combine the vectors in the C dimensions.\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\tout[startOutIndex+i] = wte[startWteIndex+int32(i)] + wpe[startWpeIndex+i]\n", + "\t\t\t}\n", + "\t\t}\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "ce9d4fc8-f565-4e16-a5b7-f4118d6f92b4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestEncoderForward\n", + "=== RUN TestEncoderForward/#00\n", + "--- PASS: TestEncoderForward (0.00s)\n", + " --- PASS: TestEncoderForward/#00 (0.00s)\n", + "PASS\n" + ] + } + ], + "source": [ + "%test\n", + "func TestEncoderForward(t *testing.T) {\n", + "\ttype args struct {\n", + "\t\tout []float32\n", + "\t\tinp []int32\n", + "\t\twte []float32\n", + "\t\twpe []float32\n", + "\t\tB int\n", + "\t\tT int\n", + "\t\tC int\n", + "\t}\n", + "\ttests := []struct {\n", + "\t\tname string\n", + "\t\targs args\n", + "\t\twantOut []float32\n", + "\t}{\n", + "\t\t{\n", + "\t\t\tname: \"\",\n", + "\t\t\targs: args{\n", + "\t\t\t\tinp: []int32{1, 0}, // [1 -> wte (2, 3), wpe(4, 5)] [0 -> wte (0, 1), wpe(6, 7)]\n", + "\t\t\t\twte: []float32{0, 1, 2, 3},\n", + "\t\t\t\twpe: []float32{4, 5, 6, 7},\n", + "\t\t\t\tB: 1, // Batch size\n", + "\t\t\t\tT: 1, // Sequence Len\n", + "\t\t\t\tC: 2, // Dimensions\n", + "\t\t\t},\n", + "\t\t\twantOut: []float32{6, 8},\n", + "\t\t},\n", + "\t}\n", + "\tfor _, tt := range tests {\n", + "\t\tt.Run(tt.name, func(t *testing.T) {\n", + "\t\t\tout := make([]float32, len(tt.args.inp))\n", + "\t\t\tencoderForward(out, tt.args.inp, tt.args.wte, tt.args.wpe, tt.args.B, tt.args.T, tt.args.C)\n", + "\t\t\tassert.Equal(t, tt.wantOut, out)\n", + "\t\t})\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "d45024b9-9bc0-46e1-b43f-4300df11e607", + "metadata": {}, + "source": [ + "# Layernorm forward\n", + "\n", + "layernormForward normalizes the activations in each layer.\n", + "It improves convergence in training and reduces sensitivity to initial parameters.\n", + "For each vector, the mean and variance are calculated.\n", + "\n", + "Reference: https://arxiv.org/abs/1607.06450\n", + "\n", + "\n", + "Parameters:\n", + " - out: output activations (B,T,C)\n", + " - mean: mean values (B,T) for each position (b,t)\n", + " - rstd: reciprocal standard deviations (B,T) for each position (b,t)\n", + " - inp: input activations (B,T,C)\n", + " - weight: learnable weight (C) for scaling\n", + " - bias: learnable bias (C) for shifting\n", + " - B: batch size\n", + " - T: sequence length (number of time steps)\n", + " - C: embedding dimension (number of features)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8b35d061-c965-42b4-b6ce-ea7d42a36593", + "metadata": {}, + "outputs": [], + "source": [ + "func layernormForward(out, mean, rstd, inp, weight, bias []float32, B, T, C int) {\n", + "\tvar eps float32 = 1e-5\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\tx := inp[b*T*C+t*C:]\n", + "\t\t\t// Calculate mean\n", + "\t\t\tvar m float32 = 0.0\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\tm += x[i]\n", + "\t\t\t}\n", + "\t\t\tm /= float32(C)\n", + "\t\t\t// Calculate variance\n", + "\t\t\tvar v float32 = 0.0\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\txshift := x[i] - m\n", + "\t\t\t\tv += xshift * xshift\n", + "\t\t\t}\n", + "\t\t\tv /= float32(C)\n", + "\t\t\t// Calculate rstd (reciprocal standard deviation)\n", + "\t\t\ts := 1.0 / Sqrt((v)+eps)\n", + "\t\t\t// Normalize, scale, shift, and store output\n", + "\t\t\toutBT := out[b*T*C+t*C:]\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t// subtract mean to center data\n", + "\t\t\t\t// divide by std to scale variance\n", + "\t\t\t\t// (val - mean) / std\n", + "\t\t\t\tn := s * (x[i] - m)\n", + "\t\t\t\t// Multiply the weight\n", + "\t\t\t\to := n*weight[i] + bias[i]\n", + "\t\t\t\toutBT[i] = o\n", + "\t\t\t}\n", + "\t\t\t// Store mean and rstd for backward pass\n", + "\t\t\tmean[b*T+t] = m\n", + "\t\t\trstd[b*T+t] = s\n", + "\t\t}\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c2508379-75cd-4642-9144-19e1c18c4ed7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestLayernormForward\n", + "=== RUN TestLayernormForward/#00\n", + "--- PASS: TestLayernormForward (0.00s)\n", + " --- PASS: TestLayernormForward/#00 (0.00s)\n", + "PASS\n" + ] + } + ], + "source": [ + "%test\n", + "func TestLayernormForward(t *testing.T) {\n", + "\ttype args struct {\n", + "\t\tinp []float32\n", + "\t\tweight []float32\n", + "\t\tbias []float32\n", + "\t\tB int\n", + "\t\tT int\n", + "\t\tC int\n", + "\t}\n", + "\ttests := []struct {\n", + "\t\tname string\n", + "\t\targs args\n", + "\t\twantOut []float32\n", + "\t\twantMean []float32\n", + "\t\twantRstd []float32\n", + "\t}{\n", + "\t\t{\n", + "\t\t\tname: \"\",\n", + "\t\t\targs: args{\n", + "\t\t\t\tinp: []float32{0.2, 0.1, 0.3, 0.5, 0.1, 0.1},\n", + "\t\t\t\tweight: []float32{1, 1, 1, 1, 1, 1},\n", + "\t\t\t\tbias: []float32{0, 0, 0, 0, 0, 0},\n", + "\t\t\t\tB: 2,\n", + "\t\t\t\tT: 1,\n", + "\t\t\t\tC: 3,\n", + "\t\t\t},\n", + "\t\t\twantOut: []float32{0, -1.2238272, 1.2238274, 1.4140146, -0.70700747, -0.70700747},\n", + "\t\t\twantMean: []float32{0.2, 0.23333335},\n", + "\t\t\twantRstd: []float32{12.238273, 5.302555},\n", + "\t\t},\n", + "\t}\n", + "\tfor _, tt := range tests {\n", + "\t\tt.Run(tt.name, func(t *testing.T) {\n", + "\t\t\tout, mean, rstd := make([]float32, len(tt.args.inp)), make([]float32, tt.args.B*tt.args.T), make([]float32, tt.args.B*tt.args.T)\n", + "\t\t\tlayernormForward(out, mean, rstd, tt.args.inp, tt.args.weight, tt.args.bias, tt.args.B, tt.args.T, tt.args.C)\n", + "\t\t\trequire.InDeltaSlice(t, tt.wantOut, out, delta)\n", + "\t\t\trequire.InDeltaSlice(t, tt.wantMean, mean, delta)\n", + "\t\t\trequire.InDeltaSlice(t, tt.wantRstd, rstd, delta)\n", + "\t\t})\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "2cf70f39-3cb7-4314-9318-92623c35e773", + "metadata": {}, + "source": [ + "# Matmul forward\n", + "\n", + "matmulForward performs matrix multiplication and adds bias.\n", + "Parameters:\n", + " - out: output matrix\n", + " - inp: input matrix\n", + " - weight: weight matrix\n", + " - bias: bias vector\n", + " - B: batch size\n", + " - T: sequence length (number of time steps)\n", + " - C: input dimension (number of features)\n", + " - OC: number of output channels" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "b95fe909-e5c1-4159-82a8-a2d969978a0b", + "metadata": {}, + "outputs": [], + "source": [ + "func matmulForward(out, inp, weight, bias []float32, B, T, C, OC int) {\n", + "\t// Iterate over each batch\n", + "\tvar wg sync.WaitGroup\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\t// Iterate over each time step in the sequence\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\twg.Add(1)\n", + "\t\t\tgo func(b, t int) {\n", + "\t\t\t\tdefer wg.Done()\n", + "\t\t\t\t// Calculate the index in the output slice\n", + "\t\t\t\tinp_bt := inp[b*T*C+t*C:]\n", + "\t\t\t\tout_bt := out[b*T*OC+t*OC:]\n", + "\t\t\t\tfor o := 0; o < OC; o++ {\n", + "\t\t\t\t\tvar val float32\n", + "\t\t\t\t\tif bias != nil {\n", + "\t\t\t\t\t\tval = bias[o]\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t// Calculate the index in the weight slice\n", + "\t\t\t\t\twrow := weight[o*C:]\n", + "\t\t\t\t\t// Perform the dot product between the input and weight row\n", + "\t\t\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t\t\tval += inp_bt[i] * wrow[i]\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t// Store the output value in the output slice\n", + "\t\t\t\t\tout_bt[o] = val\n", + "\t\t\t\t}\n", + "\t\t\t}(b, t)\n", + "\t\t}\n", + "\t}\n", + "\twg.Wait()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3bf1229d-cb85-41f3-b425-4db84940a49b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestMatmulForward\n", + "=== RUN TestMatmulForward/simple\n", + "--- PASS: TestMatmulForward (0.00s)\n", + " --- PASS: TestMatmulForward/simple (0.00s)\n", + "PASS\n" + ] + } + ], + "source": [ + "% test\n", + "func TestMatmulForward(t *testing.T) {\n", + "\ttype args struct {\n", + "\t\tinp []float32\n", + "\t\tweight []float32\n", + "\t\tbias []float32\n", + "\t\tB int\n", + "\t\tT int\n", + "\t\tC int\n", + "\t\tOC int\n", + "\t}\n", + "\ttests := []struct {\n", + "\t\tname string\n", + "\t\targs args\n", + "\t\twantOut []float32\n", + "\t}{\n", + "\t\t{\n", + "\t\t\tname: \"simple\",\n", + "\t\t\targs: args{\n", + "\t\t\t\tweight: []float32{ // OC (3) * C(2)\n", + "\t\t\t\t\t1, 2,\n", + "\t\t\t\t\t3, 4,\n", + "\t\t\t\t\t5, 6,\n", + "\t\t\t\t},\n", + "\t\t\t\tinp: []float32{ // B(1) * T(1) * T(1) * C(2)\n", + "\t\t\t\t\t1,\n", + "\t\t\t\t\t2,\n", + "\t\t\t\t},\n", + "\t\t\t\tbias: []float32{1, 2, 3}, // OC\n", + "\t\t\t\t// WEIGHT * INP + BIAS\n", + "\t\t\t\tB: 1,\n", + "\t\t\t\tT: 1,\n", + "\t\t\t\tC: 2,\n", + "\t\t\t\tOC: 3,\n", + "\t\t\t},\n", + "\t\t\twantOut: []float32{\n", + "\t\t\t\t6,\n", + "\t\t\t\t13,\n", + "\t\t\t\t20,\n", + "\t\t\t},\n", + "\t\t},\n", + "\t}\n", + "\tfor _, tt := range tests {\n", + "\t\tt.Run(tt.name, func(t *testing.T) {\n", + "\t\t\tout := make([]float32, tt.args.OC)\n", + "\t\t\tmatmulForward(out, tt.args.inp, tt.args.weight, tt.args.bias, tt.args.B, tt.args.T, tt.args.C, tt.args.OC)\n", + "\t\t\tassert.Equal(t, tt.wantOut, out)\n", + "\t\t})\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "53439744-7221-40ce-8a89-3815e9962453", + "metadata": {}, + "source": [ + "# AttentionForward\n", + "\n", + "attentionForward performs the attention forward pass.\n", + "\n", + "attention is the only layer that mixes information across time\n", + "every other operation is applied at every (b,t) position independently\n", + "(and of course, no layer mixes information across batch)\n", + "\n", + "Parameters:\n", + " - out: output matrix (B,T,C)\n", + " - preatt: pre-attention scores (B,NH,T,T)\n", + " - att: post-attention scores (B,NH,T,T)\n", + " - inp: input matrix (B,T,3C) holding Query, Key, Value vectors\n", + " - B: batch size\n", + " - T: sequence length (number of time steps)\n", + " - C: input dimension (number of features)\n", + " - NH: number of attention heads" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "b5b4dc98-1c0c-45c6-bb16-f7167fb97584", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "func attentionForward(out, preatt, att, inp []float32, B, T, C, NH int) {\n", + "\tC3 := C * 3 // This is the dimensions for the key, query and values\n", + "\ths := C / NH // head size\n", + "\tscale := 1.0 / Sqrt(float32(hs))\n", + "\t// Iterate over batch, sequence length, and number of heads\n", + "\tvar wg sync.WaitGroup\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\t// Sequence length\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\tfor h := 0; h < NH; h++ {\n", + "\t\t\t\twg.Add(1)\n", + "\t\t\t\tgo func(b, t, h int) {\n", + "\t\t\t\t\tdefer wg.Done()\n", + "\t\t\t\t\t// Calculate indices for query, pre-attention, and attention arrays\n", + "\t\t\t\t\t// query is any particular input asking for information from other inputs\n", + "\t\t\t\t\tqueryT := inp[b*T*C3+t*C3+h*hs:] // inp[B][T][C3]\n", + "\t\t\t\t\tpreattBth := preatt[b*NH*T*T+h*T*T+t*T:]\n", + "\t\t\t\t\tattBth := att[b*NH*T*T+h*T*T+t*T:]\n", + "\t\t\t\t\t// Pass 1: Calculate query dot key and max value\n", + "\t\t\t\t\t// The dot product is described in the paper as being better because\n", + "\t\t\t\t\t// it can be optimized with matrix multiplication\n", + "\t\t\t\t\tvar maxval float32 = -10000.0\n", + "\t\t\t\t\t// range from 0 to the current inp\n", + "\t\t\t\t\tfor t2 := 0; t2 <= t; t2++ {\n", + "\t\t\t\t\t\t// Calculate key index for t2\n", + "\t\t\t\t\t\tkey_t2 := inp[b*T*C3+t2*C3+h*hs+C:] // +C because it's key\n", + "\t\t\t\t\t\t// Compute dot product and update max value\n", + "\t\t\t\t\t\tvar val float32\n", + "\t\t\t\t\t\tfor i := 0; i < hs; i++ {\n", + "\t\t\t\t\t\t\tval += queryT[i] * key_t2[i]\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t\tval *= scale\n", + "\t\t\t\t\t\tif val > maxval {\n", + "\t\t\t\t\t\t\tmaxval = val\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t\t// preatt[b][h][t1][t2] == dot product (similarity) between query vector at position t1 and\n", + "\t\t\t\t\t\t// key vector at t2.\n", + "\t\t\t\t\t\tpreattBth[t2] = val\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t// Pass 2: Calculate the exp and keep track of sum\n", + "\t\t\t\t\t// Calculate exponential sum and update preatt and att arrays\n", + "\t\t\t\t\t// maps the max value to zero,\n", + "\t\t\t\t\t// and everything else negative.\n", + "\t\t\t\t\t// when the exp function is called then the range of numbers will be\n", + "\t\t\t\t\t// between 0 and e.\n", + "\t\t\t\t\tvar expsum float32\n", + "\t\t\t\t\tfor t2 := 0; t2 <= t; t2++ {\n", + "\t\t\t\t\t\texpv := Exp((preattBth[t2]) - maxval)\n", + "\t\t\t\t\t\t// expsum is a sum of all the exp'd pre_att values\n", + "\t\t\t\t\t\texpsum += expv\n", + "\t\t\t\t\t\t// att_bth[t2] is the exp'd preatt_bth[t2]\n", + "\t\t\t\t\t\tattBth[t2] = expv\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tvar expsum_inv float32\n", + "\t\t\t\t\tif expsum != 0.0 {\n", + "\t\t\t\t\t\texpsum_inv = 1.0 / expsum\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\t// Pass 3: Normalize to get softmax\n", + "\t\t\t\t\t// from 0 -> t2: att_bth[t2] = exp(preatt[t2]) / sum(exp(preatt[:]))\n", + "\t\t\t\t\t// for everything else it's zero\n", + "\t\t\t\t\tfor t2 := 0; t2 < T; t2++ {\n", + "\t\t\t\t\t\tif t2 <= t {\n", + "\t\t\t\t\t\t\tattBth[t2] *= expsum_inv\n", + "\t\t\t\t\t\t} else {\n", + "\t\t\t\t\t\t\t// Causal attention mask (optional; used for debugging and comparison)\n", + "\t\t\t\t\t\t\tattBth[t2] = 0.0\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\n", + "\t\t\t\t\t// Pass 4: Accumulate weighted values into the output of attention\n", + "\t\t\t\t\t// out = attention * values\n", + "\t\t\t\t\t// The values in this instance are the initial token/position embeddings that have gone through many linear\n", + "\t\t\t\t\t// transformations at this point.\n", + "\t\t\t\t\t// This is simply applying the learned attention \"weights\" to the lkqv values.\n", + "\t\t\t\t\t// These weights must change a whole bunch after back propagation.\n", + "\t\t\t\t\tout_bth := out[b*T*C+t*C+h*hs:]\n", + "\t\t\t\t\tfor i := 0; i < hs; i++ {\n", + "\t\t\t\t\t\tout_bth[i] = 0.0\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tfor t2 := 0; t2 <= t; t2++ {\n", + "\t\t\t\t\t\tvalue_t2 := inp[b*T*C3+t2*C3+h*hs+C*2:] // +C*2 because it's value\n", + "\t\t\t\t\t\tatt_btht2 := attBth[t2]\n", + "\t\t\t\t\t\tfor i := 0; i < hs; i++ {\n", + "\t\t\t\t\t\t\tout_bth[i] += att_btht2 * value_t2[i]\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}(b, t, h)\n", + "\t\t\t}\n", + "\t\t}\n", + "\t}\n", + "\twg.Wait()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "5e8663da-22f7-42a6-831e-c2fc954a470d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestAttentionForward\n", + "=== RUN TestAttentionForward/Small_Input_Test\n", + "=== RUN TestAttentionForward/Larger_Input_Test\n", + "--- PASS: TestAttentionForward (0.00s)\n", + " --- PASS: TestAttentionForward/Small_Input_Test (0.00s)\n", + " --- PASS: TestAttentionForward/Larger_Input_Test (0.00s)\n", + "PASS\n" + ] + } + ], + "source": [ + "%test\n", + "func TestAttentionForward(t *testing.T) {\n", + "\ttype args struct {\n", + "\t\tinp []float32\n", + "\t\tB int\n", + "\t\tT int\n", + "\t\tC int\n", + "\t\tNH int\n", + "\t}\n", + "\ttests := []struct {\n", + "\t\tname string\n", + "\t\targs args\n", + "\t\twantOut []float32\n", + "\t\twantPreatt []float32\n", + "\t\twantAtt []float32\n", + "\t}{\n", + "\t\t{\n", + "\t\t\tname: \"Small Input Test\",\n", + "\t\t\targs: args{\n", + "\t\t\t\tinp: []float32{1, 2, 3, 4, 5, 6},\n", + "\t\t\t\tB: 1,\n", + "\t\t\t\tT: 1,\n", + "\t\t\t\tC: 2,\n", + "\t\t\t\tNH: 1,\n", + "\t\t\t},\n", + "\t\t\twantOut: []float32{5, 6},\n", + "\t\t\twantPreatt: []float32{7.7781744},\n", + "\t\t\twantAtt: []float32{1},\n", + "\t\t},\n", + "\t\t{\n", + "\t\t\tname: \"Larger Input Test\",\n", + "\t\t\targs: args{\n", + "\t\t\t\tinp: []float32{ // (B, T, C3)\n", + "\t\t\t\t\t/* B = 1 */\n", + "\t\t\t\t\t/* T = 0 */\n", + "\t\t\t\t\t/*qry*/ 1, 2, 3, // query compared against (4, 5, 6) but not (13, 14, 15) because it's in the future (t=1)\n", + "\t\t\t\t\t/*key*/ 4, 5, 6,\n", + "\t\t\t\t\t/*val*/ 7, 8, 9,\n", + "\t\t\t\t\t/* T = 1 */\n", + "\t\t\t\t\t/*qry*/ 10, 11, 12, // will be compared against (4, 5, 6) (t-1) and (13, 14, 15)\n", + "\t\t\t\t\t/*key*/ 13, 14, 15,\n", + "\t\t\t\t\t/*val*/ 16, 17, 18, // vals are updated to\n", + "\t\t\t\t},\n", + "\t\t\t\tB: 1,\n", + "\t\t\t\tT: 2,\n", + "\t\t\t\tC: 3,\n", + "\t\t\t\tNH: 1,\n", + "\t\t\t},\n", + "\t\t\twantOut: []float32{ // (B, T, C)\n", + "\t\t\t\t/* B = 0 */\n", + "\t\t\t\t/* T = 0 */\n", + "\t\t\t\t/* C = 0 1 2 */\n", + "\t\t\t\t/* */ 7, 8, 9,\n", + "\t\t\t\t/* T = 1 */\n", + "\t\t\t\t/* C = 0 1 2 */\n", + "\t\t\t\t/* */ 16, 17, 18,\n", + "\t\t\t},\n", + "\t\t\twantPreatt: []float32{ // (B, NH, T, T)\n", + "\t\t\t\t/* B = 0 */\n", + "\t\t\t\t/* NH = 0 */\n", + "\t\t\t\t/*T = 1 2 */\n", + "\t\t\t\t/*T=1*/ 18.475208, 0, // preatt: 18 -> 1, 0 -> 0\n", + "\t\t\t\t/*T=2*/ 96.417496, 267.89053, // 96 -> 9, 267 -> 1\n", + "\t\t\t},\n", + "\t\t\twantAtt: []float32{ // (B, NH, T, T)\n", + "\t\t\t\t/* B = 0 */\n", + "\t\t\t\t/* NH = 0 */\n", + "\t\t\t\t/*T = 1 2 */\n", + "\t\t\t\t/*T=1*/ 1, 0,\n", + "\t\t\t\t/*T=2*/ 0, 1,\n", + "\t\t\t},\n", + "\t\t},\n", + "\t}\n", + "\tfor _, tt := range tests {\n", + "\t\tt.Run(tt.name, func(t *testing.T) {\n", + "\t\t\tout, preatt, att := make([]float32, len(tt.wantOut)), make([]float32, len(tt.wantPreatt)), make([]float32, len(tt.wantAtt))\n", + "\t\t\tattentionForward(out, preatt, att, tt.args.inp, tt.args.B, tt.args.T, tt.args.C, tt.args.NH)\n", + "\t\t\tassert.InDeltaSlice(t, tt.wantOut, out, 1e-4, fmt.Sprintf(\"want: %v got: %v\", tt.wantOut, out))\n", + "\t\t\tassert.InDeltaSlice(t, tt.wantPreatt, preatt, 1e-4, fmt.Sprintf(\"want: %v got: %v\", tt.wantPreatt, preatt))\n", + "\t\t\tassert.InDeltaSlice(t, tt.wantAtt, att, 1e-4, fmt.Sprintf(\"want: %v got: %v\", tt.wantAtt, att))\n", + "\t\t})\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "1465efe7-4987-474c-bcd8-fd65dcca2474", + "metadata": {}, + "source": [ + "# Residual forward\n", + "https://arxiv.org/abs/1512.03385\n", + "\n", + "residualForward implements a simple residual connection, a common technique used in deep neural networks to improve training and performance." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "6b7e0789-1f0b-4d32-a922-029bfb109c3b", + "metadata": {}, + "outputs": [], + "source": [ + "func residualForward(out, inp1, inp2 []float32, N int) {\n", + "\tfor i := 0; i < N; i++ {\n", + "\t\tout[i] = inp1[i] + inp2[i]\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "224410fb-f88a-4b45-ae04-061a3c83541d", + "metadata": {}, + "source": [ + "# geluForward\n", + "The geluForward function applies the GELU activation to the input values stored in the inp slice and writes the activated values to the out slice.\n", + "\n", + "geluForward is the Gaussian Error Linear Units activation function.\n", + "It leaves positive values mostly unchanged but\n", + "maps negative value close to zero.\n", + "\n", + "https://arxiv.org/abs/1606.08415v5" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "dc544479-7bf5-4bdf-8913-f11718c06626", + "metadata": {}, + "outputs": [], + "source": [ + "var GELUSCALEFACTOR = Sqrt(2.0 / math.Pi)\n", + "func geluForward(out, inp []float32, n int) {\n", + "\tfor i := 0; i < n; i++ {\n", + "\t\tx := inp[i]\n", + "\t\tcube := 0.044715 * x * x * x\n", + "\t\tout[i] = 0.5 * x * (1.0 + Tanh(GELUSCALEFACTOR*(x+cube)))\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "cb28d168-d8c7-4dff-9b85-3a134cbcc6b7", + "metadata": {}, + "source": [ + "# Softmax" + ] + }, + { + "cell_type": "markdown", + "id": "2024f02c-a671-45e2-b22a-e552d839cdb5", + "metadata": {}, + "source": [ + "softmaxForward calculates the softmax probabilities for a batch of input logits, converting them into a probability distribution over multiple classes. It's a common operation in neural networks, especially for classification tasks." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "357b6249-95b3-4ee0-a59d-e659636e2b1a", + "metadata": {}, + "outputs": [], + "source": [ + "func softmaxForward(probs, logits []float32, B, T, V int) {\n", + "\tvar wg sync.WaitGroup\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\twg.Add(1)\n", + "\t\t\tgo func(b, t int) {\n", + "\t\t\t\tdefer wg.Done()\n", + "\t\t\t\tbaseIndex := b*T*V + t*V\n", + "\t\t\t\tlogitsBT := logits[baseIndex : baseIndex+V]\n", + "\t\t\t\tprobsBT := probs[baseIndex : baseIndex+V]\n", + "\t\t\t\t// Numerical Stability\n", + "\t\t\t\tvar maxval float32 = -10000.0\n", + "\t\t\t\tfor i := 0; i < V; i++ {\n", + "\t\t\t\t\tif logitsBT[i] > maxval {\n", + "\t\t\t\t\t\tmaxval = logitsBT[i]\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t// Calculate exponentials and sum\n", + "\t\t\t\tvar sum float32\n", + "\t\t\t\tfor i := 0; i < V; i++ {\n", + "\t\t\t\t\tprobsBT[i] = Exp((logitsBT[i] - maxval))\n", + "\t\t\t\t\tsum += probsBT[i] // Using float32 for potential precision gain\n", + "\t\t\t\t}\n", + "\t\t\t\t// Normalize\n", + "\t\t\t\tfor i := 0; i < V; i++ {\n", + "\t\t\t\t\tprobsBT[i] /= sum\n", + "\t\t\t\t}\n", + "\t\t\t}(b, t)\n", + "\t\t}\n", + "\t}\n", + "\twg.Wait()\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "7b8838c0-9f3a-4703-9bcd-6043f415b34d", + "metadata": {}, + "source": [ + "# crossEntropyForward\n", + "The function crossEntropyForward calculates the cross-entropy loss for a batch of predicted probability distributions and their corresponding target labels." + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "584ffc62-c932-40be-9aa0-0da83d866d4d", + "metadata": {}, + "outputs": [], + "source": [ + "// crossEntropyForward\n", + "func crossEntropyForward(losses []float32, probs []float32, targets []int32, B, T, V int) {\n", + "\t// Iterate over each batch\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\t// Iterate over each time step in the sequence\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\t// Calculate the index in the probability slice\n", + "\t\t\tstartIndex := int32(b*T*V + t*V)\n", + "\t\t\t// Get the correct index in the logits for the current batch and time step\n", + "\t\t\tix := targets[b*T+t]\n", + "\t\t\t// Calculate the cross-entropy loss\n", + "\t\t\tprob := probs[startIndex+ix]\n", + "\t\t\t// Calculate the negative log of the probability for the correct target index\n", + "\t\t\tlosses[b*T+t] = -Log((prob))\n", + "\t\t}\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "2f35d169-7ae0-4063-bc2f-0b1afb7a366e", + "metadata": {}, + "source": [ + "# Forward\n", + "The function Forward implements the forward pass of a GPT-2 language model. It takes a sequence of input tokens and a sequence of target tokens (if available) as input, and it calculates the model's output probabilities for the next token in the sequence." + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "deac87d4-e901-401e-b109-9642ea02bc8d", + "metadata": {}, + "outputs": [], + "source": [ + "func (model *GPT2) Forward(input, target []int32, B, T int) {\n", + "\tV, L, NH, C := model.Config.V, model.Config.L, model.Config.NH, model.Config.C\n", + "\tif model.Acts.Memory == nil {\n", + "\t\tmodel.B, model.T = B, T\n", + "\t\tmodel.Acts.Init(B, C, T, L, NH, V)\n", + "\t\tmodel.Inputs = make([]int32, B*T)\n", + "\t\tmodel.Targets = make([]int32, B*T)\n", + "\t}\n", + "\tcopy(model.Inputs, input)\n", + "\tcopy(model.Targets, target)\n", + "\tparams, acts := model.Params, model.Acts\n", + "\t// This encodes the word token embeddings with the positional embeddings\n", + "\t// so that those vectors have spacial information and aren't just purely made up of the\n", + "\t// token embeddings. The result of this is stored in acts.Encoded.\n", + "\t// Input is a slice of ids/tokens that correspond to the vectors in WTE and their index is the \"position\"\n", + "\tencoderForward(acts.Encoded.data, input, params.WordTokEmbed.data, params.WordPosEmbed.data, B, T, C)\n", + "\tvar residual []float32\n", + "\tfor l := 0; l < L; l++ {\n", + "\t\t// residual is a connection between the last layers output, or the initial token/pos embedding (as applied above)\n", + "\t\tif l == 0 {\n", + "\t\t\tresidual = acts.Encoded.data\n", + "\t\t} else {\n", + "\t\t\tresidual = acts.Residual3.data[(l-1)*B*T*C:]\n", + "\t\t}\n", + "\t\t// Parameters\n", + "\t\tl_ln1w := params.LayerNorm1W.data[l*C:]\n", + "\t\tl_ln1b := params.LayerNorm1B.data[l*C:]\n", + "\t\tl_qkvw := params.QueryKeyValW.data[l*3*C*C:]\n", + "\t\tl_qkvb := params.QueryKeyValB.data[l*3*C:]\n", + "\t\tl_attprojw := params.AttProjW.data[l*C*C:]\n", + "\t\tl_attprojb := params.AttProjB.data[l*C:]\n", + "\t\tl_ln2w := params.Layer2NormW.data[l*C:]\n", + "\t\tl_ln2b := params.Layer2NormB.data[l*C:]\n", + "\t\tl_fcw := params.FeedFwdW.data[l*4*C*C:]\n", + "\t\tl_fcb := params.FeedFwdB.data[l*4*C:]\n", + "\t\tl_fcprojw := params.FeedFwdProjW.data[l*C*4*C:]\n", + "\t\tl_fcprojb := params.FeedFwdProjB.data[l*C:]\n", + "\t\t// Activations\n", + "\t\tl_ln1 := acts.Layer1Act.data[l*B*T*C:]\n", + "\t\tl_ln1_mean := acts.LayerNorm1Mean.data[l*B*T:]\n", + "\t\tl_ln1_rstd := acts.LayerNorm1Rstd.data[l*B*T:]\n", + "\t\tl_qkv := acts.QueryKeyVal.data[l*B*T*3*C:]\n", + "\t\tl_atty := acts.AttentionInter.data[l*B*T*C:]\n", + "\t\tl_preatt := acts.PreAttention.data[l*B*NH*T*T:]\n", + "\t\tl_att := acts.Attention.data[l*B*NH*T*T:]\n", + "\t\tl_attproj := acts.AttentionProj.data[l*B*T*C:]\n", + "\t\tl_residual2 := acts.Residual2.data[l*B*T*C:]\n", + "\t\tl_ln2 := acts.LayerNorm2Act.data[l*B*T*C:]\n", + "\t\tl_ln2_mean := acts.LayerNorm2Mean.data[l*B*T:]\n", + "\t\tl_ln2_rstd := acts.LayerNorm2Rstd.data[l*B*T:]\n", + "\t\tl_fch := acts.FeedForward.data[l*B*T*4*C:]\n", + "\t\tl_fch_gelu := acts.FeedForwardGelu.data[l*B*T*4*C:]\n", + "\t\tl_fcproj := acts.FeedForwardProj.data[l*B*T*C:]\n", + "\t\tl_residual3 := acts.Residual3.data[l*B*T*C:]\n", + "\t\t// Here we normalise the layer so that the mean is 0 and the standard deviation is ~1.\n", + "\t\t// residual contains the un-edited activations\n", + "\t\tlayernormForward(l_ln1, l_ln1_mean, l_ln1_rstd, residual /*inp*/, l_ln1w /*weight*/, l_ln1b /*bias*/, B, T, C)\n", + "\t\t/*\n", + "\t\t\t\t\tl_qkvw = weight = Query Key Val Weights (C * 3C)\n", + "\t\t\t\t\tl_ln1 = inp = layer activations\n", + "\t\t\t\t\tl_qkvb = bias = Query Key Val Bias\n", + "\t\t\t\t\tl_qkv = out = key/query/value matrix\n", + "\t\t\t\tHere we're matrix multiplying l_ln1(inp)*l_qkvw(weight) + l_qkvb(bias)\n", + "\t\t\t\tThis matrix multiplication essentially gets a layer activation for the model inputs (activations) which are multiplied\n", + "\t\t\t\tby the model weights.\n", + "\t\t\tThis does the input \"projection\" via linear transformations via the model query/key/value weights into higher dimensionality.\n", + "\t\t*/\n", + "\t\tmatmulForward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C)\n", + "\t\t/*\n", + "\t\t\tThe attention forward pass takes these query/key/value vectors, along with the model attention weights\n", + "\t\t\tThe model pre-attention scores, after the forward pass, have the un-normalised attention scores\n", + "\t\t\tatt has the attention acores and l_atty has the attention scores + the query/key/value scores\n", + "\t\t\tl_qkv has the projection of the activations into a higher dimension.\n", + "\t\t\tl_preatt: has the projection qkv vectors dot product(similarity), between an input's query and another input's key.\n", + "\t\t\t\tThis basically goes like this:\n", + "\t\t\t\tword a: has a query vector \"what am i looking for\"\n", + "\t\t\t\tword b: has a query vector \"what do i need\"\n", + "\t\t\t\tif they're similar, these vectors will be similar, therefore the scores will be high and be stored in l_preatt\n", + "\t\t\tthe v in the qkv is the original token/position embeddings which have been through a number of linear transformations at this point.\n", + "\t\t*/\n", + "\t\tattentionForward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH)\n", + "\n", + "\t\t/*\n", + "\t\t\tHere we do another matrix multiplication of attention weights and biases\n", + "\t\t\tThis projects the l_atty into another dimension. These will probably also get back propagated.\n", + "\t\t*/\n", + "\t\tmatmulForward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C)\n", + "\t\t/*\n", + "\t\t\tThe residual forward simply adds the attention projection and the residual layer, which is the\n", + "\t\t\tweights(or activations?) before any of the previous transformations. This allows a stronger signal and\n", + "\t\t\tprevents weight dropout and i think makes back propagation more efficient.\n", + "\t\t*/\n", + "\t\tresidualForward(l_residual2, residual, l_attproj, B*T*C)\n", + "\t\t/*\n", + "\t\t\tThe weights in this level are the layer 2 activations, which are multiplied with the residual through the above sections\n", + "\t\t\tThis is normalised and everything into layernorm2\n", + "\t\t*/\n", + "\t\tlayernormForward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C)\n", + "\t\t/*\n", + "\t\t\tFeedforward is just another layer of a multi layer perceptron to make the \"higher level\" connections.\n", + "\t\t*/\n", + "\t\tmatmulForward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C)\n", + "\t\t/*\n", + "\t\t\tThis is an acitvation function which maps large values to close to one and smaller values to zero.\n", + "\t\t*/\n", + "\t\tgeluForward(l_fch_gelu, l_fch, B*T*4*C)\n", + "\t\t/*\n", + "\t\t\tThis now squishes the last layer into a smaller dimension so it can be added to the next layer.\n", + "\t\t*/\n", + "\t\tmatmulForward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C)\n", + "\t\t/*\n", + "\t\t\tNow we set the next residual layer as the output of this layer. This is the l_fcproj + the current layer residual\n", + "\t\t*/\n", + "\t\tresidualForward(l_residual3, l_residual2, l_fcproj, B*T*C)\n", + "\t}\n", + "\tresidual = acts.Residual3.data[(L-1)*B*T*C:]\n", + "\n", + "\t/*\n", + "\t\tNow this is the last thing. We're layer norming the final layer activations so that the logits can be calculated\n", + "\n", + "\t*/\n", + "\tlayernormForward(acts.LayerNormFinal.data, acts.LayerNormFinalMean.data, acts.LayerNormFinalStd.data, residual, params.LayerFinNormW.data, params.LayerFinNormB.data, B, T, C)\n", + "\t/*\n", + "\t\t\tMatrix multiplying the Word Token embedding gives us the logits.\n", + "\t\tThis is calculating a weighted sum. More likely tokens will be blown up and less likely will be zero or negative.\n", + "\t*/\n", + "\tmatmulForward(acts.Logits.data, acts.LayerNormFinal.data, params.WordTokEmbed.data, nil, B, T, C, V)\n", + "\t/*\n", + "\t\tAfter all of this we can softmax the logits to get probabilities over the entire vocabulary\n", + "\t*/\n", + "\tsoftmaxForward(acts.Probabilities.data, acts.Logits.data, B, T, V)\n", + "\t// also forward the cross-entropy loss function if we have the targets\n", + "\tif len(target) > 0 {\n", + "\t\t/*\n", + "\t\t\tThis compares the probabilities for each token and compares it to the target to calculate a loss.\n", + "\t\t*/\n", + "\t\tcrossEntropyForward(model.Acts.Losses.data, model.Acts.Probabilities.data, target, B, T, V)\n", + "\t\t// for convenience also evaluate the mean loss\n", + "\t\tvar meanLoss float32\n", + "\t\tfor i := range model.Acts.Losses.data {\n", + "\t\t\tmeanLoss += model.Acts.Losses.data[i]\n", + "\t\t}\n", + "\t\tmeanLoss /= float32(B * T)\n", + "\t\tmodel.MeanLoss = meanLoss\n", + "\n", + "\t} else {\n", + "\t\tmodel.MeanLoss = -1.0\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "cb2fe3e3-d8ff-40a0-98e5-a9d8666f380d", + "metadata": {}, + "source": [ + "# sampleMult" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "1b21fd77-d192-4d1a-9a9d-14f36ba3d812", + "metadata": {}, + "outputs": [], + "source": [ + "func sampleMult(probabilities []float32, coin float32) int {\n", + "\tvar cdf float32\n", + "\tfor i, prob := range probabilities {\n", + "\t\tcdf += prob\n", + "\t\tif coin < cdf {\n", + "\t\t\treturn i\n", + "\t\t}\n", + "\t}\n", + "\treturn len(probabilities) - 1\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "b539b57e-025a-4f73-94cf-02ac09860f01", + "metadata": {}, + "outputs": [], + "source": [ + "func (model *GPT2) Inference(input string, B, T int) (string, error) {\n", + "\ttokens, err := model.Tokenizer.Encode(input)\n", + "\tif err != nil {\n", + "\t\treturn \"\", err\n", + "\t}\n", + "\tif len(tokens) < T {\n", + "\t\tfor i := len(tokens); i <= T; i++ {\n", + "\t\t\ttokens = append(tokens, model.Config.EOT)\n", + "\t\t}\n", + "\t}\n", + "\tfmt.Printf(\"input is %d tokens long\\n\", len(tokens))\n", + "\tmodel.Forward(tokens, tokens[1:], B, T)\n", + "\tgenTokens := make([]int32, B*T)\n", + "\tfor i := 0; i < B*T; i++ {\n", + "\t\tgenTokens[i] = model.Config.EOT\n", + "\t}\n", + "\tfor t := 0; t < B*T; t++ {\n", + "\t\tfmt.Printf(\"generating token: %d\\n\", t)\n", + "\t\t// for each t, we re-compute all activations between 0 and t\n", + "\t\t// leaving this alone because you want separate code for inference anyway\n", + "\t\t// the inference here is just for sanity checking purposes\n", + "\t\tmodel.Forward(genTokens, nil, B, t)\n", + "\t\tprobabilities := model.Acts.Probabilities.data[(t-1)*model.Config.V:]\n", + "\t\tcoin := model.Rand.Float32()\n", + "\t\tnextToken2 := sampleMult(probabilities, coin)\n", + "\t\tgenTokens[t] = rune(nextToken2)\n", + "\t}\n", + "\tif model.Tokenizer.init {\n", + "\t\treturn model.Tokenizer.Decode(genTokens)\n", + "\t}\n", + "\treturn \"\", errors.New(\"tokenizer not initialised\")\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "ea5451cb-3790-4935-a7cb-7238d0b739b3", + "metadata": {}, + "outputs": [], + "source": [ + "func newGPT2(MaxSeqLen, V, L, NH, C int, vocab []string) GPT2 {\n", + "\tmodel := GPT2{\n", + "\t\tConfig: GPT2Config{\n", + "\t\t\tMaxSeqLen: MaxSeqLen,\n", + "\t\t\tV: V,\n", + "\t\t\tL: L,\n", + "\t\t\tNH: NH,\n", + "\t\t\tC: C,\n", + "\t\t},\n", + "\t\tParams: newParameterTensors(V, C, MaxSeqLen, L),\n", + "\t\tTokenizer: newTokenizer(vocab),\n", + "\t\tRand: rand.New(rand.NewSource(21)),\n", + "\t}\n", + "\treturn model\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "92b53304-1961-48ee-93d3-b679c68526fc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestLoadGPT2Model\n", + "=== RUN TestLoadGPT2Model/#00\n", + "input is 4 tokens long\n", + "generating token: 0\n", + "inference time took: 31.292µs\n", + "--- FAIL: TestLoadGPT2Model (0.00s)\n", + " --- FAIL: TestLoadGPT2Model/#00 (0.00s)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "panic: runtime error: slice bounds out of range [-3:] [recovered]\n", + "\tpanic: runtime error: slice bounds out of range [-3:]\n", + "\n", + "goroutine 7 [running]:\n", + "testing.tRunner.func1.2({0x100e1cfa0, 0x14000184018})\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1631 +0x1c4\n", + "testing.tRunner.func1()\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1634 +0x33c\n", + "panic({0x100e1cfa0?, 0x14000184018?})\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/runtime/panic.go:770 +0x124\n", + "gonb_706a570e.(*GPT2).Inference(0x14000172f70, {0x100d7984e, 0x4}, 0x1, 0x2)\n", + "\t \u001b[7m[[ Cell [25] Line 28 ]]\u001b[0m /var/folders/b_/lv3cnbp904q9_0ndh5mkmp2h0000gn/T/gonb_706a570e/main_test.go:507 +0x4e4\n", + "gonb_706a570e.TestLoadGPT2Model.func1(0x140001169c0)\n", + "\t \u001b[7m[[ Cell [49] Line 29 ]]\u001b[0m /var/folders/b_/lv3cnbp904q9_0ndh5mkmp2h0000gn/T/gonb_706a570e/main_test.go:1119 +0x9c\n", + "testing.tRunner(0x140001169c0, 0x14000138380)\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1689 +0xec\n", + "created by testing.(*T).Run in goroutine 6\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1742 +0x318\n", + "exit status 2\n" + ] + } + ], + "source": [ + "%test\n", + "func TestLoadGPT2Model(t *testing.T) {\n", + "\ttests := []struct {\n", + "\t\tname string\n", + "\t\tmaxSeqLen int\n", + "\t\tv int\n", + "\t\tl int\n", + "\t\tnh int\n", + "\t\tc int\n", + "\t\tvocab []string\n", + "\t\tinput string\n", + "\t\toutput string\n", + "\t}{\n", + "\t\t{\n", + "\t\t\tname: \"\",\n", + "\t\t\tmaxSeqLen: 3,\n", + "\t\t\tv: 3,\n", + "\t\t\tl: 2,\n", + "\t\t\tnh: 1,\n", + "\t\t\tc: 1,\n", + "\t\t\tvocab: []string{\"a\", \"b\", \"c\"},\n", + "\t\t\tinput: \"abcd\",\n", + "\t\t\toutput: \"acc\",\n", + "\t\t},\n", + "\t}\n", + "\tfor _, tt := range tests {\n", + "\t\tt.Run(tt.name, func(t *testing.T) {\n", + "\t\t\tmodel := newGPT2(tt.maxSeqLen, tt.v, tt.l, tt.nh, tt.c, tt.vocab)\n", + "\t\t\toutput, err := model.Inference(tt.input, 1, 2)\n", + "\t\t\tassert.NoError(t, err)\n", + "\t\t\tprintln(output)\n", + "\t\t})\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "639850ea-2c6f-4939-8893-41cbb91e677e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== RUN TestInference\n", + "input is 10 tokens long\n", + "generating token: 0\n", + "inference time took: 226.473625ms\n", + "--- FAIL: TestInference (0.72s)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "panic: runtime error: slice bounds out of range [-50257:] [recovered]\n", + "\tpanic: runtime error: slice bounds out of range [-50257:]\n", + "\n", + "goroutine 6 [running]:\n", + "testing.tRunner.func1.2({0x104eccf40, 0x1403c6dc018})\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1631 +0x1c4\n", + "testing.tRunner.func1()\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1634 +0x33c\n", + "panic({0x104eccf40?, 0x1403c6dc018?})\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/runtime/panic.go:770 +0x124\n", + "gonb_706a570e.(*GPT2).Inference(0x1400011f008, {0x0, 0x0}, 0x1, 0x9)\n", + "\t \u001b[7m[[ Cell [25] Line 28 ]]\u001b[0m /var/folders/b_/lv3cnbp904q9_0ndh5mkmp2h0000gn/T/gonb_706a570e/main_test.go:411 +0x4e4\n", + "gonb_706a570e.TestInference(0x14000116b60)\n", + "\t \u001b[7m[[ Cell [28] Line 6 ]]\u001b[0m /var/folders/b_/lv3cnbp904q9_0ndh5mkmp2h0000gn/T/gonb_706a570e/main_test.go:849 +0x78\n", + "testing.tRunner(0x14000116b60, 0x104ee2138)\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1689 +0xec\n", + "created by testing.(*T).Run in goroutine 1\n", + "\t/opt/homebrew/Cellar/go/1.22.2/libexec/src/testing/testing.go:1742 +0x318\n", + "exit status 2\n" + ] + } + ], + "source": [ + "%test\n", + "func TestInference(b *testing.T) {\n", + "\trandomText := \"\"\n", + "\tmodel, err := LoadGPT2Model(\"./gpt2_124M.bin\", \"./gpt2_tokenizer.bin\")\n", + "\trequire.NoError(b, err)\n", + " output, err := model.Inference(randomText, 1, 9)\n", + " require.NoError(b, err)\n", + " b.Log(output)\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "b4e65fbf-a631-457c-876d-450c4fe5af70", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "hello world" + ] + } + ], + "source": [ + "%main\n", + "\n", + "print(\"hello world\")" + ] + }, + { + "cell_type": "markdown", + "id": "d5cc94c7-a0e0-4e06-adff-4ed170b3fef1", + "metadata": {}, + "source": [ + "# Backward Pass" + ] + }, + { + "cell_type": "markdown", + "id": "32464a08-e511-4e6b-9e6b-9a500f8c4810", + "metadata": {}, + "source": [ + "# crossentropySoftmaxBackward\n", + "The function computes the gradients of the logits (dlogits) with respect to the loss, given the probabilities (probs) and target labels (targets).\n", + "This gradient information is used during backpropagation to update the weights and biases of the network to minimize the cross-entropy loss." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "9e966913-1023-4307-b419-51a0642c618e", + "metadata": {}, + "outputs": [], + "source": [ + "// crossentropySoftmaxBackward calculates the cross entropy\n", + "func crossentropySoftmaxBackward(dlogits, dlosses, probs []float32, targets []int32, B, T, V int) {\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\tbaseIndex := b*T*V + t*V\n", + "\t\t\tdlogitsBT := dlogits[baseIndex : baseIndex+V]\n", + "\t\t\tprobsBT := probs[baseIndex : baseIndex+V]\n", + "\t\t\tdloss := dlosses[b*T+t]\n", + "\t\t\tix := targets[b*T+t]\n", + "\t\t\tfor i := 0; i < V; i++ {\n", + "\t\t\t\tp := probsBT[i]\n", + "\t\t\t\tvar indicator float32\n", + "\t\t\t\tif int32(i) == ix {\n", + "\t\t\t\t\tindicator = 1.0\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tindicator = 0.0\n", + "\t\t\t\t}\n", + "\t\t\t\tdlogitsBT[i] += (p - indicator) * dloss\n", + "\t\t\t}\n", + "\t\t}\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "a77e28fd-34bd-40ac-aa0b-accff0f91e3b", + "metadata": {}, + "source": [ + "# matmulBackward\n", + "\n", + "The function computes the gradients of the inputs (dinp), weights (dweight), and biases (dbias) for a matrix multiplication operation. These gradients are necessary for adjusting the model parameters during training to minimize the error." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "7529a3c7-6a19-4efe-93c2-41fa750cb3b8", + "metadata": {}, + "outputs": [], + "source": [ + "func matmulBackward(dinp, dweight, dbias, dout, inp, weight []float32, B, T, C, OC int) {\n", + "\tvar wg sync.WaitGroup\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\twg.Add(1)\n", + "\t\t\tgo func(b, t int) {\n", + "\t\t\t\tdefer wg.Done()\n", + "\t\t\t\tdoutBt := dout[b*T*OC+t*OC:]\n", + "\t\t\t\tdinpBt := dinp[b*T*C+t*C:]\n", + "\t\t\t\tfor o := 0; o < OC; o++ {\n", + "\t\t\t\t\twrow := weight[o*C:]\n", + "\t\t\t\t\td := doutBt[o]\n", + "\t\t\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t\t\tdinpBt[i] += wrow[i] * d\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}(b, t)\n", + "\t\t}\n", + "\t}\n", + "\twg.Wait()\n", + "\tfor o := 0; o < OC; o++ {\n", + "\t\twg.Add(1)\n", + "\t\tgo func(o int) {\n", + "\t\t\tdefer wg.Done()\n", + "\t\t\tfor b := 0; b < B; b++ {\n", + "\t\t\t\tfor t := 0; t < T; t++ {\n", + "\t\t\t\t\tdoutBt := dout[b*T*OC+t*OC:]\n", + "\t\t\t\t\tinpBt := inp[b*T*C+t*C:]\n", + "\t\t\t\t\tdwrow := dweight[o*C:]\n", + "\t\t\t\t\td := doutBt[o]\n", + "\t\t\t\t\tif dbias != nil {\n", + "\t\t\t\t\t\tdbias[o] += d\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t\t\tdwrow[i] += inpBt[i] * d\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t}(o)\n", + "\t}\n", + "\twg.Wait()\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "4286368c-93e7-48e0-8a6b-41800301c081", + "metadata": {}, + "source": [ + "# layernormBackward\n", + "The function layernormBackward calculates the gradients for the backward pass of a Layer Normalization (LayerNorm) operation in a neural network. Here's a breakdown of what it does:\n", + "\n", + "Layer Normalization is a technique used to normalize the activations of a layer across its features, improving the training stability and performance of deep neural networks. It involves normalizing the input to have zero mean and unit variance. This function calculates the gradients needed to update the weights and biases of the LayerNorm operation during backpropagation." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "b13930e2-bdfd-4724-8213-af1087ae32a6", + "metadata": {}, + "outputs": [], + "source": [ + "func layernormBackward(dinp, dweight, dbias, dout, inp, weight, mean, rstd []float32, B, T, C int) {\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\tbaseIndex := b*T*C + t*C\n", + "\t\t\tdoutBT := dout[baseIndex : baseIndex+C]\n", + "\t\t\tinpBT := inp[baseIndex : baseIndex+C]\n", + "\t\t\tdinpBT := dinp[baseIndex : baseIndex+C]\n", + "\t\t\tmeanBT := mean[b*T+t]\n", + "\t\t\trstdBT := rstd[b*T+t]\n", + "\n", + "\t\t\t// Reduce operations\n", + "\t\t\tvar dnormMean float32 = 0.0\n", + "\t\t\tvar dnormNormMean float32 = 0.0\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\tnormBTI := (inpBT[i] - meanBT) * rstdBT\n", + "\t\t\t\tdnormI := weight[i] * doutBT[i]\n", + "\t\t\t\tdnormMean += dnormI\n", + "\t\t\t\tdnormNormMean += dnormI * normBTI\n", + "\t\t\t}\n", + "\t\t\tdnormMean /= float32(C)\n", + "\t\t\tdnormNormMean /= float32(C)\n", + "\n", + "\t\t\t// Accumulation loop\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\tnormBTI := (inpBT[i] - meanBT) * rstdBT\n", + "\t\t\t\tdnormI := weight[i] * doutBT[i]\n", + "\t\t\t\tdbias[i] += doutBT[i]\n", + "\t\t\t\tdweight[i] += normBTI * doutBT[i]\n", + "\n", + "\t\t\t\tvar dval float32\n", + "\t\t\t\tdval += dnormI // Term 1\n", + "\t\t\t\tdval -= dnormMean // Term 2\n", + "\t\t\t\tdval -= normBTI * dnormNormMean // Term 3\n", + "\t\t\t\tdval *= rstdBT // Final scale\n", + "\t\t\t\tdinpBT[i] += dval\n", + "\t\t\t}\n", + "\t\t}\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "ed355458-db9b-4d6a-8521-c3a6870ebfa8", + "metadata": {}, + "source": [ + "# residualBackward\n", + "The function residualBackward calculates the gradients for the backward pass of a residual connection in a neural network. Here's a breakdown of what it does:" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "c7868ba4-d181-4ee1-80e6-6cedd23944a9", + "metadata": {}, + "outputs": [], + "source": [ + "func residualBackward(dinp1, dinp2, dout []float32, N int) {\n", + "\tfor i := 0; i < N; i++ {\n", + "\t\tdinp1[i] += dout[i]\n", + "\t\tdinp2[i] += dout[i]\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "5ccf94ba-a185-4f5e-be9f-befb7c326ffa", + "metadata": {}, + "source": [ + "# geluBackward\n", + "\n", + " computes the gradient of the Gaussian Error Linear Unit (GELU) activation function for backpropagation in a neural network." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "037bb57f-2a32-473e-be3e-986e7aa12853", + "metadata": {}, + "outputs": [], + "source": [ + "// geluBackward computes the backward pass of the GeLU non-linearity\n", + "func geluBackward(dinp, inp, dout []float32, n int) {\n", + "\tfor i := 0; i < n; i++ {\n", + "\t\tx := inp[i]\n", + "\t\tcube := 0.044715 * x * x * x\n", + "\t\ttanhArg := GELUSCALEFACTOR * (x + cube)\n", + "\t\ttanhOut := Tanh(tanhArg)\n", + "\t\tcoshfOut := Cosh(tanhArg)\n", + "\t\tsechOut := 1.0 / (coshfOut * coshfOut)\n", + "\t\tlocalGrad := 0.5*(1.0+tanhOut) + x*0.5*sechOut*GELUSCALEFACTOR*(1.0+3.0*0.044715*x*x)\n", + "\t\tdinp[i] += localGrad * dout[i]\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "2188cdbd-518c-4576-b0e2-9c528ec22d60", + "metadata": {}, + "source": [ + "# attentionBackward\n", + "The attentionBackward function implements the backward pass for a self-attention mechanism in a neural network. This is a crucial part of training attention-based models, like transformers. It calculates the gradients of the attention weights, queries, keys, and values with respect to the outputs of the attention layer, allowing the model to adjust its parameters to improve performance." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "f833485b-57c3-4f6a-bfab-c1ed0cd6e449", + "metadata": {}, + "outputs": [], + "source": [ + "// attentionBackward performs the backward pass for an attention mechanism\n", + "func attentionBackward(dinp, dpreatt, datt, dout, inp, att []float32, B, T, C, NH int) {\n", + "\t// C3 is 3 times C, representing the size of Q, K, and V combined\n", + "\tC3 := C * 3\n", + "\t// hs is the size of each head\n", + "\ths := C / NH\n", + "\t// scale is the factor used in the forward pass to scale the dot product\n", + "\tscale := 1.0 / Sqrt(float32(hs))\n", + "\t// Iterate through batch, time, and heads\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\tfor h := 0; h < NH; h++ {\n", + "\t\t\t\t// Calculate the indices for the arrays in this specific iteration\n", + "\t\t\t\tattBTH := att[b*NH*T*T+h*T*T+t*T:]\n", + "\t\t\t\tdattBTH := datt[b*NH*T*T+h*T*T+t*T:]\n", + "\t\t\t\tdpreattBTH := dpreatt[b*NH*T*T+h*T*T+t*T:]\n", + "\t\t\t\tdqueryT := dinp[b*T*C3+t*C3+h*hs:]\n", + "\t\t\t\tqueryT := inp[b*T*C3+t*C3+h*hs:]\n", + "\t\t\t\t// Backward pass 4: value accumulation\n", + "\t\t\t\tdoutBTH := dout[b*T*C+t*C+h*hs:]\n", + "\t\t\t\tfor t2 := 0; t2 <= t; t2++ {\n", + "\t\t\t\t\tvalueT2 := inp[b*T*C3+t2*C3+h*hs+C*2:]\n", + "\t\t\t\t\tdvalueT2 := dinp[b*T*C3+t2*C3+h*hs+C*2:]\n", + "\t\t\t\t\tfor i := 0; i < hs; i++ {\n", + "\t\t\t\t\t\t// Compute gradients for attention and value accumulation\n", + "\t\t\t\t\t\tdattBTH[t2] += valueT2[i] * doutBTH[i]\n", + "\t\t\t\t\t\tdvalueT2[i] += attBTH[t2] * doutBTH[i]\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t// Backward pass 2 & 3: softmax backward\n", + "\t\t\t\t// Softmax does not require input (preatt) to backward\n", + "\t\t\t\tfor t2 := 0; t2 <= t; t2++ {\n", + "\t\t\t\t\tfor t3 := 0; t3 <= t; t3++ {\n", + "\t\t\t\t\t\tvar indicator float32\n", + "\t\t\t\t\t\tif t2 == t3 {\n", + "\t\t\t\t\t\t\tindicator = 1.0\n", + "\t\t\t\t\t\t}\n", + "\t\t\t\t\t\tlocalDerivative := attBTH[t2] * (indicator - attBTH[t3])\n", + "\t\t\t\t\t\tdpreattBTH[t3] += localDerivative * dattBTH[t2]\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t\t// Backward pass 1: query @ key matmul\n", + "\t\t\t\tfor t2 := 0; t2 <= t; t2++ {\n", + "\t\t\t\t\tkeyT2 := inp[b*T*C3+t2*C3+h*hs+C:]\n", + "\t\t\t\t\tdkeyT2 := dinp[b*T*C3+t2*C3+h*hs+C:]\n", + "\t\t\t\t\tfor i := 0; i < hs; i++ {\n", + "\t\t\t\t\t\t// Compute gradients for query and key\n", + "\t\t\t\t\t\tdqueryT[i] += keyT2[i] * dpreattBTH[t2] * scale\n", + "\t\t\t\t\t\tdkeyT2[i] += queryT[i] * dpreattBTH[t2] * scale\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t}\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "4745fcbc-d3be-4c4b-9679-83a326f0c912", + "metadata": {}, + "source": [ + "# matmulBackward\n", + "The function computes the gradients of the inputs (dinp), weights (dweight), and biases (dbias) for a matrix multiplication operation. These gradients are necessary for adjusting the model parameters during training to minimize the error.\n", + "\n", + "\n", + "dinp: A slice of floats representing the gradients of the outputs with respect to the inputs of the matrix multiplication. This is often calculated by the subsequent layer in the network.\n", + "dweight: A slice of floats representing the gradients of the outputs with respect to the weights. Initially, this slice is filled with zeros.\n", + "dbias: A slice of floats representing the gradients of the outputs with respect to the biases. Initially, this slice is filled with zeros.\n", + "dout: A slice of floats representing the outputs of the matrix multiplication.\n", + "inp: A slice of floats representing the inputs to the matrix multiplication.\n", + "weight: A slice of floats representing the weights of the matrix multiplication.\n", + "B: The batch size (number of samples).\n", + "T: The time steps or sequence length.\n", + "C: The number of input features.\n", + "OC: The number of output features." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "d6ddc99a-0b48-43aa-a3c6-76b824a84dc2", + "metadata": {}, + "outputs": [], + "source": [ + "func matmulBackward(dinp, dweight, dbias, dout, inp, weight []float32, B, T, C, OC int) {\n", + "\tvar wg sync.WaitGroup\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\twg.Add(1)\n", + "\t\t\tgo func(b, t int) {\n", + "\t\t\t\tdefer wg.Done()\n", + "\t\t\t\tdoutBt := dout[b*T*OC+t*OC:]\n", + "\t\t\t\tdinpBt := dinp[b*T*C+t*C:]\n", + "\t\t\t\tfor o := 0; o < OC; o++ {\n", + "\t\t\t\t\twrow := weight[o*C:]\n", + "\t\t\t\t\td := doutBt[o]\n", + "\t\t\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t\t\tdinpBt[i] += wrow[i] * d\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}(b, t)\n", + "\t\t}\n", + "\t}\n", + "\twg.Wait()\n", + "\tfor o := 0; o < OC; o++ {\n", + "\t\twg.Add(1)\n", + "\t\tgo func(o int) {\n", + "\t\t\tdefer wg.Done()\n", + "\t\t\tfor b := 0; b < B; b++ {\n", + "\t\t\t\tfor t := 0; t < T; t++ {\n", + "\t\t\t\t\tdoutBt := dout[b*T*OC+t*OC:]\n", + "\t\t\t\t\tinpBt := inp[b*T*C+t*C:]\n", + "\t\t\t\t\tdwrow := dweight[o*C:]\n", + "\t\t\t\t\td := doutBt[o]\n", + "\t\t\t\t\tif dbias != nil {\n", + "\t\t\t\t\t\tdbias[o] += d\n", + "\t\t\t\t\t}\n", + "\t\t\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t\t\tdwrow[i] += inpBt[i] * d\n", + "\t\t\t\t\t}\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t}(o)\n", + "\t}\n", + "\twg.Wait()\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "80081fb0-0698-4190-bcba-6a68aaa43477", + "metadata": {}, + "source": [ + "# encoderBackward\n", + "encoderBackward calculates gradients during backpropagation\n", + "Parameters:\n", + " - dwte: gradients with respect to word embeddings (wte)\n", + " - dwpe: gradients with respect to positional embeddings (wpe)\n", + " - dout: the gradient to apply to dwte and dwpe\n", + " - inp: input tokens (ids that refer to indexes within wte)\n", + " - B: batch size\n", + " - T: sequence length (number of time steps)\n", + " - C: embedding dimension (number of features)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "e38b3ac4-ff9f-439b-ac9a-9fbedd9b1bbe", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "func encoderBackward(dwte, dwpe []float32, dout []float32, inp []int32, B, T, C int) {\n", + "\t// Iterate over the batch and time steps\n", + "\tfor b := 0; b < B; b++ {\n", + "\t\tfor t := 0; t < T; t++ {\n", + "\t\t\t// Calculate offsets for indexing\n", + "\t\t\tdoutBTOffset := b*T*C + t*C\n", + "\t\t\tix := inp[b*T+t] // Get the input token id\n", + "\t\t\tdwteIxOffset := ix * int32(C) // Calculate the offset for dwte\n", + "\t\t\tdwpeTOffset := t * C // Calculate the offset for dwpe\n", + "\n", + "\t\t\t// Iterate over the embedding dimension and apply computations\n", + "\t\t\tfor i := 0; i < C; i++ {\n", + "\t\t\t\t// Get the gradient value from dout\n", + "\t\t\t\td := dout[doutBTOffset+i]\n", + "\t\t\t\t// Update the gradients for word embeddings (dwte) and positional embeddings (dwpe)\n", + "\t\t\t\tdwte[dwteIxOffset+int32(i)] += d\n", + "\t\t\t\tdwpe[dwpeTOffset+i] += d\n", + "\t\t\t}\n", + "\t\t}\n", + "\t}\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "8580eaa8-af05-4f51-b83b-94d9c44c5de1", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "func (model *GPT2) ZeroGradient() {\n", + "\tfor i := range model.GradsActs.Memory {\n", + "\t\tmodel.GradsActs.Memory[i] = 0.0\n", + "\t}\n", + "\tfor i := range model.Grads.Memory {\n", + "\t\tmodel.Grads.Memory[i] = 0.0\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "275503ff-b344-4f1c-9619-b3c9c1dc9956", + "metadata": {}, + "outputs": [], + "source": [ + "func (model *GPT2) Update(learningRate, beta1, beta2, eps, weightDecay float32, t int) {\n", + "\t// Lazy memory allocation\n", + "\tif model.MMemory == nil {\n", + "\t\tmodel.MMemory = make([]float32, model.Params.Len())\n", + "\t\tmodel.VMemory = make([]float32, model.Params.Len())\n", + "\t}\n", + "\t// Parameter updates\n", + "\tfor i := 0; i < model.Params.Len(); i++ {\n", + "\t\tparameter := model.Params.Memory[i]\n", + "\t\tgradient := model.Grads.Memory[i]\n", + "\t\t// Momentum update\n", + "\t\tm := beta1*model.MMemory[i] + (1.0-beta1)*gradient\n", + "\t\t// RMSprop update\n", + "\t\tv := beta2*model.VMemory[i] + (1.0-beta2)*gradient*gradient\n", + "\t\t// Bias correction\n", + "\t\tmHat := m / (1.0 - Pow(beta1, float32(t)))\n", + "\t\tvHat := v / (1.0 - Pow(beta2, float32(t)))\n", + "\t\t// Parameter update\n", + "\t\tmodel.MMemory[i] = m\n", + "\t\tmodel.VMemory[i] = v\n", + "\t\tmodel.Params.Memory[i] -= learningRate * (mHat/(Sqrt(vHat)+eps) + weightDecay*parameter)\n", + "\t}\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "ada64c1b-6e14-45cd-99c7-b04139e07726", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "func (model *GPT2) Backward() error {\n", + "\t//// double check we forwarded previously, with targets\n", + "\tif model.MeanLoss == -1.0 {\n", + "\t\treturn errors.New(\"error: must forward with targets before backward\")\n", + "\t}\n", + "\t// lazily allocate the memory for gradients of the weights and activations, if needed\n", + "\t// convenience shortcuts\n", + "\tB, T, V, L, NH, C := model.B, model.T, model.Config.V, model.Config.L, model.Config.NH, model.Config.C\n", + "\tif len(model.Grads.Memory) == 0 {\n", + "\t\tmodel.Grads.Init(V, C, model.Config.MaxSeqLen, L)\n", + "\t\tmodel.GradsActs.Init(B, C, T, L, NH, V)\n", + "\t\tmodel.ZeroGradient()\n", + "\t}\n", + "\t// backward pass\n", + "\tparams, grads, acts, gradsActs := model.Params, model.Grads, model.Acts, model.GradsActs\n", + "\t// we kick off the chain by filling in dlosses with 1.0f/(B*T), to get the mean loss\n", + "\tdlossMean := 1.0 / float32(B*T)\n", + "\tfor i := range gradsActs.Losses.data {\n", + "\t\tgradsActs.Losses.data[i] = dlossMean\n", + "\t}\n", + "\tcrossentropySoftmaxBackward(gradsActs.Logits.data, gradsActs.Losses.data, acts.Probabilities.data, model.Targets, B, T, V)\n", + "\tmatmulBackward(gradsActs.LayerNormFinal.data, grads.WordTokEmbed.data, nil, gradsActs.Logits.data, acts.LayerNormFinal.data, params.WordTokEmbed.data, B, T, C, V)\n", + "\tresidual := acts.Residual3.data[(L-1)*B*T*C:] // last layer's residual\n", + "\tdresidual := gradsActs.Residual3.data[(L-1)*B*T*C:] // write to last layer's residual\n", + "\tlayernormBackward(dresidual, grads.LayerFinNormW.data, grads.LayerFinNormB.data, gradsActs.LayerNormFinal.data, residual, params.LayerFinNormW.data, acts.LayerNormFinalMean.data, acts.LayerNormFinalStd.data, B, T, C)\n", + "\tfor l := L - 1; l >= 0; l-- {\n", + "\t\tif l == 0 {\n", + "\t\t\tresidual = acts.Encoded.data\n", + "\t\t\tdresidual = gradsActs.Encoded.data\n", + "\t\t} else {\n", + "\t\t\tresidual = acts.Residual3.data[(l-1)*B*T*C:]\n", + "\t\t\tdresidual = gradsActs.Residual3.data[(l-1)*B*T*C:]\n", + "\t\t}\n", + "\n", + "\t\t// Assuming you have a 'params' variable of your ParameterTensors type\n", + "\t\tl_ln1w := params.LayerNorm1W.data[l*C:]\n", + "\t\tl_qkvw := params.QueryKeyValW.data[l*3*C*C:]\n", + "\t\tl_attprojw := params.AttProjW.data[l*C*C:]\n", + "\t\tl_ln2w := params.Layer2NormW.data[l*C:]\n", + "\t\tl_fcw := params.FeedFwdW.data[l*4*C*C:]\n", + "\t\tl_fcprojw := params.FeedFwdProjW.data[l*C*4*C:]\n", + "\t\t// Gradients of weights\n", + "\t\tdl_ln1w := grads.LayerNorm1W.data[l*C:]\n", + "\t\tdl_ln1b := grads.LayerNorm1B.data[l*C:]\n", + "\t\tdl_qkvw := grads.QueryKeyValW.data[l*3*C*C:]\n", + "\t\tdl_qkvb := grads.QueryKeyValB.data[l*3*C:]\n", + "\t\tdl_attprojw := grads.AttProjW.data[l*C*C:]\n", + "\t\tdl_attprojb := grads.AttProjB.data[l*C:]\n", + "\t\tdl_ln2w := grads.Layer2NormW.data[l*C:]\n", + "\t\tdl_ln2b := grads.Layer2NormB.data[l*C:]\n", + "\t\tdl_fcw := grads.FeedFwdW.data[l*4*C*C:]\n", + "\t\tdl_fcb := grads.FeedFwdB.data[l*4*C:]\n", + "\t\tdl_fcprojw := grads.FeedFwdProjW.data[l*C*4*C:]\n", + "\t\tdl_fcprojb := grads.FeedFwdProjB.data[l*C:]\n", + "\t\t// Activations\n", + "\t\tl_ln1 := acts.Layer1Act.data[l*B*T*C:]\n", + "\t\tl_ln1_mean := acts.LayerNorm1Mean.data[l*B*T:]\n", + "\t\tl_ln1_rstd := acts.LayerNorm1Rstd.data[l*B*T:]\n", + "\t\tl_qkv := acts.QueryKeyVal.data[l*B*T*3*C:]\n", + "\t\tl_atty := acts.AttentionInter.data[l*B*T*C:]\n", + "\t\tl_att := acts.Attention.data[l*B*NH*T*T:]\n", + "\t\tl_residual2 := acts.Residual2.data[l*B*T*C:]\n", + "\t\tl_ln2 := acts.LayerNorm2Act.data[l*B*T*C:]\n", + "\t\tl_ln2_mean := acts.LayerNorm2Mean.data[l*B*T:]\n", + "\t\tl_ln2_rstd := acts.LayerNorm2Rstd.data[l*B*T:]\n", + "\t\tl_fch := acts.FeedForward.data[l*B*T*4*C:]\n", + "\t\tl_fch_gelu := acts.FeedForwardGelu.data[l*B*T*4*C:]\n", + "\n", + "\t\tdl_ln1 := gradsActs.Layer1Act.data[l*B*T*C:]\n", + "\t\tdl_qkv := gradsActs.QueryKeyVal.data[l*B*T*3*C:]\n", + "\t\tdl_atty := gradsActs.AttentionInter.data[l*B*T*C:]\n", + "\t\tdl_preatt := gradsActs.PreAttention.data[l*B*NH*T*T:]\n", + "\t\tdl_att := gradsActs.Attention.data[l*B*NH*T*T:]\n", + "\t\tdl_attproj := gradsActs.AttentionProj.data[l*B*T*C:]\n", + "\t\tdl_residual2 := gradsActs.Residual2.data[l*B*T*C:]\n", + "\t\tdl_ln2 := gradsActs.LayerNorm2Act.data[l*B*T*C:]\n", + "\t\tdl_fch := gradsActs.FeedForward.data[l*B*T*4*C:]\n", + "\t\tdl_fch_gelu := gradsActs.FeedForwardGelu.data[l*B*T*4*C:]\n", + "\t\tdl_fcproj := gradsActs.FeedForwardProj.data[l*B*T*C:]\n", + "\t\tdl_residual3 := gradsActs.Residual3.data[l*B*T*C:]\n", + "\t\tresidualBackward(dl_residual2, dl_fcproj, dl_residual3, B*T*C)\n", + "\t\tmatmulBackward(dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C)\n", + "\t\tgeluBackward(dl_fch, l_fch, dl_fch_gelu, B*T*4*C)\n", + "\t\tmatmulBackward(dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C)\n", + "\t\tlayernormBackward(dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C)\n", + "\t\tresidualBackward(dresidual, dl_attproj, dl_residual2, B*T*C)\n", + "\t\tmatmulBackward(dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C)\n", + "\t\tattentionBackward(dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH)\n", + "\t\tmatmulBackward(dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C)\n", + "\t\tlayernormBackward(dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C)\n", + "\t}\n", + "\t// Here we want to apply our gradients to our encoded data.\n", + "\tencoderBackward(grads.WordTokEmbed.data, grads.WordPosEmbed.data, gradsActs.Encoded.data, model.Inputs, B, T, C)\n", + "\treturn nil\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "8db1366e-a3de-42ff-88d9-aecc8559f7b6", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "func (model *GPT2) Train(valDataloader, trainDataloader *DataLoader, B, T int) error {\n", + "\tfmt.Printf(\"train dataset num_batches: %d\\n\", valDataloader.NumBatches)\n", + "\tconst genMaxLength, valNumBatches = 64, 10\n", + "\tgenTokens := make([]int32, B*T)\n", + "\tfor step := 0; step <= 40; step++ {\n", + "\t\tif step%10 == 0 {\n", + "\t\t\tvar valLoss float32\n", + "\t\t\tvalDataloader.Reset()\n", + "\t\t\tfor i := 0; i < valNumBatches; i++ {\n", + "\t\t\t\tinput, target, err := valDataloader.NextBatch()\n", + "\t\t\t\tif err != nil {\n", + "\t\t\t\t\treturn err\n", + "\t\t\t\t}\n", + "\t\t\t\tmodel.Forward(input, target, B, T)\n", + "\t\t\t\tvalLoss += model.MeanLoss\n", + "\t\t\t}\n", + "\t\t\tvalLoss /= float32(valNumBatches)\n", + "\t\t\tfmt.Printf(\"val loss %f\\n\", valLoss)\n", + "\t\t}\n", + "\t\tif step > 0 && step%20 == 0 {\n", + "\t\t\tfor i := 0; i < B*T; i++ {\n", + "\t\t\t\tgenTokens[i] = model.Config.EOT\n", + "\t\t\t}\n", + "\t\t\tfor t := 1; t < len(genTokens); t++ {\n", + "\t\t\t\t// for each t, we re-compute all activations between 0 and t\n", + "\t\t\t\t// leaving this alone because you want separate code for inference anyway\n", + "\t\t\t\t// the inference here is just for sanity checking purposes\n", + "\t\t\t\tmodel.Forward(genTokens, nil, B, t)\n", + "\t\t\t\tprobabilities := model.Acts.Probabilities.data[(t-1)*model.Config.V:]\n", + "\t\t\t\tcoin := rand.Float32()\n", + "\t\t\t\tnextToken2 := sampleMult(probabilities, coin)\n", + "\t\t\t\tgenTokens[t] = rune(nextToken2)\n", + "\t\t\t}\n", + "\t\t\tfmt.Print(\"generated: \")\n", + "\t\t\tif model.Tokenizer.init {\n", + "\t\t\t\tstr, err := model.Tokenizer.Decode(genTokens)\n", + "\t\t\t\tif err != nil {\n", + "\t\t\t\t\treturn err\n", + "\t\t\t\t}\n", + "\t\t\t\tfmt.Println(str)\n", + "\t\t\t} else {\n", + "\t\t\t\tfmt.Println(genTokens)\n", + "\t\t\t}\n", + "\t\t\tfor t := 0; t < genMaxLength; t++ {\n", + "\t\t\t\tif model.Tokenizer.init {\n", + "\n", + "\t\t\t\t} else {\n", + "\t\t\t\t\tfmt.Printf(\"%d \", genTokens[t])\n", + "\t\t\t\t}\n", + "\t\t\t}\n", + "\t\t\tfmt.Println()\n", + "\t\t}\n", + "\t\t// do a training step\n", + "\t\tstart := time.Now()\n", + "\t\tinput, targets, err := trainDataloader.NextBatch()\n", + "\t\tif err != nil {\n", + "\t\t\treturn err\n", + "\t\t}\n", + "\t\tmodel.Forward(input, targets, B, T)\n", + "\t\tmodel.ZeroGradient()\n", + "\t\tmodel.Backward()\n", + "\t\tmodel.Update(1e-4, 0.9, 0.999, 1e-8, 0.0, step+1)\n", + "\t\tfmt.Printf(\"step %d: train loss %f (took %v ms)\\n\", step, model.MeanLoss, time.Since(start))\n", + "\t}\n", + "\treturn nil\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "71cded43-3dbe-4730-a6a5-62c7352c5ddf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train dataset num_batches: 3615833\n", + "train dataset num_batches: 74389\n", + "val loss 2.929664\n", + "step 0: train loss 3.149784 (took 10.937132667s ms)\n", + "step 1: train loss 3.005568 (took 12.967602333s ms)\n", + "step 2: train loss 2.609987 (took 12.78746925s ms)\n", + "step 3: train loss 2.554043 (took 12.881469875s ms)\n", + "step 4: train loss 2.782657 (took 12.977106375s ms)\n", + "step 5: train loss 2.478727 (took 12.807558625s ms)\n", + "step 6: train loss 2.788728 (took 12.958770958s ms)\n", + "step 7: train loss 2.532619 (took 13.202439292s ms)\n", + "step 8: train loss 2.785676 (took 13.262732s ms)\n", + "step 9: train loss 2.638746 (took 13.39796775s ms)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "signal: interrupt\n" + ] + } + ], + "source": [ + "%main\n", + "model, err := LoadGPT2Model(\"./gpt2_124M.bin\", \"./gpt2_tokenizer.bin\")\n", + "if err != nil {\n", + " log.Fatal(err)\n", + "}\n", + "B, T := 4, 64\n", + "trainDataloader, err := NewDataLoader(\"./TinyStories_train.bin\", B, T)\n", + "if err != nil {\n", + " log.Fatal(err)\n", + "}\n", + "fmt.Printf(\"train dataset num_batches: %d\\n\", trainDataloader.NumBatches)\n", + "valDataloader, err := NewDataLoader(\"./TinyStories_val.bin\", B, T)\n", + "if err != nil {\n", + " log.Fatal(err)\n", + "}\n", + "if err := model.Train(valDataloader, trainDataloader, B, T); err != nil {\n", + " log.Fatal(err)\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "7ce1401a-ba28-4dc2-a396-9b0a8d79307b", + "metadata": {}, + "source": [ + "# References\n", + "\n", + "- [**Attention Is All You Need**](https://arxiv.org/abs/1706.03762v7) - The OG introduced the idea of self-attention and the encoder/decoder architecture for language translation tasks (the encoder later got dropped because it was only used for translation). Another breakthrough from this paper was the training; “The Transformer allows for significantly more parallelisation and can reach a new state of the art in translation quality after being trained for as little as twelve hours on eight P100 GPUs.” - This fact here was what let it: overtake RNNs (which weren’t parallelisable), and lead NVIDIA to be worth more than 2.7 Trillion token credits.\n", + "\n", + "- [**Improving Language Understanding by Generative Pre-Training**](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf) -** This paper introduced the “GPT” which was a breakthrough at the time. It introduced the idea of using next token prediction as a way to do self-supervised learning, which meant that we can put all of the internet into it and with a simple loss function over the vocabulary adjust the weights via backpropagation.\n", + "\n", + "- [**Language Models are Unsupervised Multitask Learners**](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) - This is the GPT-2 paper\n", + "\n", + "- [**Fast Transformer Decoding: One Write-Head is All\n", + "You Need**](https://arxiv.org/pdf/1911.02150) - People always point to the original Attention is all you need paper or the GPT paper that introduced the *decoder only* model, but this one was the first one that actually used it in practice.\n", + "\n", + "- [**Layer Normalization**](https://arxiv.org/abs/1607.06450) - Layernorm happens in each layer to make sure that the values don’t explode. I’ve implemented this in the [layernormForward function in llm.go](https://github.com/joshcarp/llm.go/blob/56de2430b95ff3f89657637a4c97794653a994ec/math.go#L81) and it’s pretty neat\n", + "\n", + "- [**Gaussian Error Linear Units (GELUs)**](https://arxiv.org/abs/1606.08415v5) - Activation function that leaves positive values unchanged but maps negative numbers to near zero. Implemented here in [llm.go](https://github.com/joshcarp/llm.go/blob/56de2430b95ff3f89657637a4c97794653a994ec/math.go#L414). Other architectures use different activation functions. For example, OpenElm uses SwiGLU FFN which I don’t exactly understand. Should probably add that to the reading list.\n", + "\n", + "- [**Deep Residual Learning for Image Recognition**](https://arxiv.org/pdf/1512.03385) - The introduction of residuals allowed for deeper networks. Before this paper the depth of a neural network was limited because it would diverge enough and back propagation was really, really difficult to do because of vanishing gradients. Residuals essentially have a “short circuit” past a block which limits how much the neural networks can influence." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41a26412-a4e6-41e6-af41-3c12401669e6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Go (gonb)", + "language": "go", + "name": "gonb" + }, + "language_info": { + "codemirror_mode": "", + "file_extension": ".go", + "mimetype": "", + "name": "go", + "nbconvert_exporter": "", + "pygments_lexer": "", + "version": "go1.22.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}