
jax-nanoGPT

A reimplementation of nanoGPT in JAX.

Install

Install the dependencies:

pip install -r requirements.txt

If you want to use this code with TPUs, install:

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
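
Either way, a quick sanity check that JAX sees your accelerators:

import jax

print(jax.devices())       # e.g. [TpuDevice(id=0), ...] on a TPU VM
print(jax.device_count())  # number of devices visible to this process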

Train on a single node

To create a dataset, run:

cd data/shakespeare
python prepare.py

This creates a train.bin and val.bin holding the GPT-2 BPE token ids as one long sequence. Now you can train. Go back to the folder with the training script and run:

python train.py --config shakespeare
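
The .bin files are flat arrays of token ids, so batches can be sampled with cheap random offsets. A minimal sketch of that pattern, assuming the tokens are stored as uint16 as in the original nanoGPT (the real batching lives in train.py):

import numpy as np

# Sample a random batch of (input, target) blocks from the token stream.
# Assumes uint16 token ids, as written by the original nanoGPT prepare.py.
data = np.memmap('data/shakespeare/train.bin', dtype=np.uint16, mode='r')

block_size, batch_size = 256, 8
ix = np.random.randint(0, len(data) - block_size - 1, size=batch_size)
x = np.stack([data[i : i + block_size] for i in ix]).astype(np.int32)
y = np.stack([data[i + 1 : i + 1 + block_size] for i in ix]).astype(np.int32)  # next-token targets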

Train multi-node on GCP

We can scale training by using TPU pod slices and TPU VMs. In short, we deploy multiple workers, run the same training job on each worker, and let pmap handle the scaling.
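
As a rough illustration of the pattern (a toy example, not the repo's actual train step), pmap replicates a step function across the local devices, each device works on its own slice of the batch, and pmean all-reduces the gradients:

from functools import partial

import jax
import jax.numpy as jnp

# Toy data-parallel step; loss_fn, params, and the data are stand-ins.
def loss_fn(params, x, y):
    pred = x @ params                                 # toy linear "model"
    return jnp.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name='batch')
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    grads = jax.lax.pmean(grads, axis_name='batch')   # average grads across devices
    return params - 0.1 * grads, loss

n = jax.local_device_count()
params = jnp.zeros((n, 4))                            # one replica of the params per device
x, y = jnp.ones((n, 8, 4)), jnp.ones((n, 8))          # leading axis = device axis
params, loss = train_step(params, x, y)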

  1. We'll be using TPU v4, which requires a subnet in the zone us-central2-b. Follow the instructions for Set up and prepare a Google Cloud project.

  2. Create an instance. Replace <your_project_id> with your project id.

    export TPU_NAME=tpu-v4
    export ZONE=us-central2-b
    export RUNTIME_VERSION=tpu-vm-v4-base
    export PROJECT_ID=<your_project_id>
    export ACCELERATOR_TYPE=v4-16
    
    gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --version ${RUNTIME_VERSION} \
    --subnetwork=tpusubnet \
    --network=tpu-network
  3. In order to ssh into the machine, you might need to modify ~/.ssh/config. Replace <your_user_name> with your local user name (check it with echo ~) and add the following, using your TPU VM's external IP as the HostName:

    Host tpu-v4
    HostName <your_tpu_external_ip>
    IdentityFile /Users/<your_user_name>/.ssh/google_compute_engine
  4. As a test, try to ssh in. If that works, you're ready to move on to the next steps.

    gcloud compute tpus tpu-vm ssh tpu-v4 --worker=0 --zone us-central2-b --project $PROJECT_ID
  5. Now we'll run the training job on multiple machines. First install jax[tpu], clone the repository, and install the dependencies on all machines (a quick multi-host sanity check is sketched after this list):

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
    
    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="git clone https://github.com/entrpn/jax-nanoGPT.git"
    
    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="pip install -r jax-nanoGPT/requirements.txt"
  6. Generate the dataset on all workers. (TODO: generate the data on a single drive and mount it on all instances.)

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="python3 jax-nanoGPT/data/openwebtext-10k/prepare.py"
  7. Kick off training.

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="cd jax-nanoGPT; python3 train.py --config openwebtext-10k"
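
The multi-host sanity check mentioned in step 5: on a pod slice each process sees only its own chips, while jax.device_count() reports the global total. Run it on any worker (or on all of them via --worker=all):

import jax

# On a v4-16 slice this prints a distinct process index per worker, and
# device_count() (global) is larger than local_device_count() (this host).
print("process:", jax.process_index(), "of", jax.process_count())
print("local devices:", jax.local_device_count(), "global devices:", jax.device_count())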

Generate

To generate text, use the generate.py script with the config that was used for training and the last checkpoint step that was saved:

python generate.py --config shakespeare --checkpoint-step 7500
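
Under the hood this is ordinary autoregressive sampling: run the model on the context, sample the next token from the logits, append it, and repeat. A generic sketch of the loop (model_apply, params, and the temperature default are illustrative placeholders, not generate.py's actual interface):

import jax
import jax.numpy as jnp

def sample(params, model_apply, prompt_ids, steps, key, temperature=0.8):
    # model_apply(params, ids) is assumed to return logits of shape
    # (len(ids), vocab_size); all names here are placeholders.
    ids = jnp.asarray(prompt_ids)
    for _ in range(steps):
        logits = model_apply(params, ids)[-1] / temperature
        key, subkey = jax.random.split(key)
        next_id = jax.random.categorical(subkey, logits)  # sample from softmax(logits)
        ids = jnp.concatenate([ids, next_id[None]])
    return ids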

TensorBoard logs are stored in out-{dataset-name} and include train/eval loss, the learning rate, and sampled generations.
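
To view them, point TensorBoard at the output directory (here assuming the shakespeare config, so out-shakespeare):

tensorboard --logdir out-shakespeare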

Examples

Sample output from training with the openwebtext-10k dataset for 25k steps; the last 50 characters of the text are generated by the model.


