Add init command to load base .bin files (#5)
albertpurnama authored May 10, 2024
1 parent 9e96932 commit 56de243
Showing 2 changed files with 162 additions and 6 deletions.
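For orientation, below is a minimal sketch of how a main package could drive the reworked command tree after this change. The module import path and the wrapper itself are assumptions; only InitializeCommand, the new init subcommand, and the ~/.cache/llmgo/base layout come from the diff that follows.

package main

// Hypothetical entry point (not part of this commit); the import path is assumed.
import llmgo "github.com/joshcarp/llm.go"

func main() {
	// InitializeCommand (see command.go below) creates ~/.cache/llmgo, registers
	// the new `init` subcommand alongside `run gpt2`, and executes the root cobra
	// command. Running `init` downloads gpt2_tokenizer.bin, gpt2_124M.bin, and
	// gpt2_124M_debug_state.bin into ~/.cache/llmgo/base.
	llmgo.InitializeCommand()
}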
114 changes: 108 additions & 6 deletions command.go
@@ -5,13 +5,36 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"time"

"github.com/spf13/cobra"
)

var tok Tokenizer
var err error
// CLI global variables
var (
cacheDir string
err error

// this is the base folder that contains the basic tokenizer
// and the basic model plus its weights
DEFAULT_BASE_FOLDER_NAME = "base"

// basic tokenizer and model
basicTokenizerURL string = "https://huggingface.co/joshcarp/llm.go/resolve/main/gpt2_tokenizer.bin"
basicModelURL string = "https://huggingface.co/joshcarp/llm.go/resolve/main/gpt2_124M.bin"
basicModelDebugURL string = "https://huggingface.co/joshcarp/llm.go/resolve/main/gpt2_124M_debug_state.bin"
)

// initializeCacheDir initializes the cache directory
func initializeCacheDir() {
homeDir, err := os.UserHomeDir()
if err != nil {
panic(err)
}
cacheDir = homeDir + "/.cache/llmgo"
os.MkdirAll(cacheDir, os.ModePerm) // Ensure the directory exists
}

// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
@@ -22,6 +45,78 @@ var rootCmd = &cobra.Command{
`,
}

var initCmd = &cobra.Command{
Use: "init",
Short: "Initialize the GPT-2 base files",
Long: `This command loads the tokenizer and base model from Hugging Face and initializes internal variables for use in inference/training`,
Run: func(cmd *cobra.Command, args []string) {
// check whether the base configuration directory exists;
// if not, create it
baseConfigPath := filepath.Join(cacheDir, DEFAULT_BASE_FOLDER_NAME)
files, err := os.ReadDir(baseConfigPath)
if err != nil {
// create the base config file
if err := os.MkdirAll(baseConfigPath, os.ModePerm); err != nil {
fmt.Printf("failed to create base configuration directory: %v\n", err)
return
}
}

// Even if the base configuration directory already exists,
// download any base file that is missing or invalid.
tokenizerExistAndValid := false
modelExistAndValid := false
modelDebugStateExistAndValid := false

// list files in the cache directory
// if the tokenizer/model is not in the cache
// download it from Huggingface
for _, file := range files {
if file.Name() == filepath.Base(basicTokenizerURL) {
_, err = NewTokenizer(filepath.Join(cacheDir, DEFAULT_BASE_FOLDER_NAME, file.Name()))
if err != nil {
continue
}
tokenizerExistAndValid = true
}

if file.Name() == filepath.Base(basicModelURL) {
// TODO: validate the model
modelExistAndValid = true
}

if file.Name() == filepath.Base(basicModelDebugURL) {
// TODO: validate the model debug state file
modelDebugStateExistAndValid = true
}
}

if !tokenizerExistAndValid {
fmt.Println("Tokenizer not found, downloading...")
if err := downloadFromHF(filepath.Join(baseConfigPath, filepath.Base(basicTokenizerURL)), basicTokenizerURL); err != nil {
fmt.Printf("failed to download tokenizer: %v\n", err)
return
}
}

if !modelExistAndValid {
fmt.Println("Model not found, downloading...")
if err := downloadFromHF(filepath.Join(baseConfigPath, filepath.Base(basicModelURL)), basicModelURL); err != nil {
fmt.Printf("failed to download model: %v\n", err)
return
}
}

if !modelDebugStateExistAndValid {
fmt.Println("Model debug state not found, downloading...")
if err := downloadFromHF(filepath.Join(baseConfigPath, filepath.Base(basicModelDebugURL)), basicModelDebugURL); err != nil {
fmt.Printf("failed to download model debug state: %v\n", err)
return
}
}
},
}

// runCmd represents the run command
var runCmd = &cobra.Command{
Use: "run",
@@ -38,6 +133,15 @@ var gpt2Cmd = &cobra.Command{
Short: "Run GPT-2 inference",
Long: `This command specifically initiates the GPT-2 inference process. It allows users to input text and receive AI-generated text continuations based on the GPT-2 model.`,
Run: func(cmd *cobra.Command, args []string) {
// load tokenizer
// TODO: custom configuration
tok, err := NewTokenizer(filepath.Join(cacheDir, DEFAULT_BASE_FOLDER_NAME, filepath.Base(basicTokenizerURL)))
if err != nil {
fmt.Printf("failed to load tokenizer: %v\n", err)
fmt.Println("did you forget to run `llmgo init`?")
return
}

for {
fmt.Printf(">>> ")
inputReader := bufio.NewReader(os.Stdin)
@@ -70,12 +174,10 @@ var gpt2Cmd = &cobra.Command{
}

func InitializeCommand() {
tok, err = NewTokenizer("./gpt2_tokenizer.bin")
if err != nil {
panic(err)
}
initializeCacheDir()

rootCmd.AddCommand(runCmd)
rootCmd.AddCommand(initCmd)
runCmd.AddCommand(gpt2Cmd)

err := rootCmd.Execute()
54 changes: 54 additions & 0 deletions hf_loader.go
@@ -0,0 +1,54 @@
package llmgo

import (
"fmt"
"io"
"net/http"
"os"
)

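// downloadFromHF streams the file at url into outputPath, printing a simple
// progress percentage to stdout as the download proceeds.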
func downloadFromHF(outputPath, url string) error {
fmt.Println("Downloading file from Huggingface...")

resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to get file: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

contentLength := resp.ContentLength
var totalRead int64 = 0

out, err := os.Create(outputPath)
if err != nil {
return fmt.Errorf("failed to create file %s: %w", outputPath, err)
}
defer out.Close()

buf := make([]byte, 4096) // Adjust buffer size to your needs
for {
n, err := resp.Body.Read(buf)
if n > 0 {
totalRead += int64(n)
if contentLength > 0 { // ContentLength is -1 when the server does not report a size
percentage := float64(totalRead) / float64(contentLength) * 100
fmt.Printf("\rDownloading... %.2f%% complete", percentage)
}

_, writeErr := out.Write(buf[:n])
if writeErr != nil {
return fmt.Errorf("failed to write to file %s: %w", outputPath, writeErr)
}
}
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read data: %w", err)
}
}
fmt.Println("\nDownload complete.")
return nil
}
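For completeness, a hedged usage sketch of the new helper, following the same pattern initCmd uses above. The wrapper function is hypothetical and assumes it lives in command.go, which already imports path/filepath and defines cacheDir and the URL variables.

// downloadBaseTokenizer is a hypothetical wrapper (not part of this commit)
// showing how downloadFromHF composes with the cache variables in command.go.
func downloadBaseTokenizer() error {
	dest := filepath.Join(cacheDir, DEFAULT_BASE_FOLDER_NAME, filepath.Base(basicTokenizerURL))
	return downloadFromHF(dest, basicTokenizerURL)
}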
