End-to-end tests for torch_xla llama and mixtral #15

Workflow file for this run

name: E2E tests
on:
  push:
    branches:
      - main
  pull_request:
  schedule:  # Schedule the job to run at 12AM PST (8AM UTC) daily.
    - cron: "0 8 * * *"
# Note:
# v6e cluster used: http://shortn/_oswg6PyPo6
# If we migrate to another project, these configs should be updated accordingly.
env:
  XPK_CLUSTER_NAME: tpu-v6e-ci
  GCP_PROJECT: tpu-pytorch
  GCP_ZONE: us-central2
  TPU_TYPE: v6e-4
jobs:
  tp-run:
    name: Train models via 'tp run'
    runs-on: [ubuntu-22.04]
    env:
      ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
      RUN_ID: ${{ github.run_id }}-${{ github.run_attempt }}
    outputs:
      llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
    steps:
      - name: Maximize build space
        uses: AdityaGarg8/[email protected]
        with:
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
          remove-codeql: 'true'
      - name: Use Docker in rootless mode
        uses: ScribeMD/[email protected]
      - name: Add user to docker group
        run: |
          sudo usermod -aG docker $USER
          newgrp docker
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - uses: actions/checkout@v4
      - name: Install dev dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[dev]'
      - uses: 'google-github-actions/auth@v2'
        with:
          # Googlers: if this fails, follow http://shortn/_61iSj31q1b to debug.
          credentials_json: '${{ secrets.GCP_SA_KEY }}'
      - uses: google-github-actions/setup-gcloud@v2
        with:
          version: '>= 363.0.0'
          install_components: 'beta,gke-gcloud-auth-plugin'
      - name: Verify gcp setup
        run: gcloud info
      - name: Authenticate Docker
        run: gcloud auth configure-docker --quiet
      - name: Activate SA credentials
        run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS
      - name: tp doctor
        run: tp doctor
      - name: tp use
        run: >
          tp use
          --project $GCP_PROJECT
          --zone $GCP_ZONE
          --cluster $XPK_CLUSTER_NAME
          --num-slices 1
          --artifact-dir $ARTIFACT_DIR
          --tpu-type $TPU_TYPE
      - name: Run Llama 3.0 8B
        id: run-llama-3-8b
        env:
          # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          random_id=$(cat /dev/urandom | tr -dc A-Za-z0-9 | head -c 8)
          name="llama-3-8b-$random_id"
          tp run \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-103-raw-v1 \
            profile_step=3 \
            max_steps=15
          echo "name=$name" >> "$GITHUB_OUTPUT"

  llama-3-8b:
    name: Check Llama 3.0 8B
    needs: tp-run
    runs-on: [ubuntu-22.04]
    env:
      JOBSET_NAME: ${{ needs.tp-run.outputs.llama-3-8b-name }}
    steps:
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - uses: 'google-github-actions/auth@v2'
        with:
          credentials_json: '${{ secrets.GCP_SA_KEY }}'
      - uses: google-github-actions/setup-gcloud@v2
        with:
          version: '>= 363.0.0'
          install_components: 'beta,gke-gcloud-auth-plugin'
      - name: Verify gcp setup
        run: gcloud info
      - name: Activate SA credentials
        run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS
      - name: Get GKE credentials
        run: |
          gcloud container clusters get-credentials $XPK_CLUSTER_NAME --region=$GCP_ZONE --project=$GCP_PROJECT
          kubectl config view
          kubectl config set-context --current --namespace=default
      - name: Stream logs
        run: |
          # Find the first pod created by the jobset and follow the training container's logs.
          pod_name=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=$JOBSET_NAME -o json | jq --raw-output '.items[0].metadata.name')
          kubectl logs -c jax-tpu -f "$pod_name"
      # TODO: stream logs
      # TODO: wait for completion
      # TODO: check that there is a 'Step duration: ... s'
      # TODO: check that there is a 'Finished training run'
      # TODO: check that there is profile
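
      # The step below is a minimal sketch of the TODO checks above; it was not
      # part of the original run. It assumes the training logs contain the
      # literal strings 'Step duration:' and 'Finished training run', and that a
      # profile is uploaded under the same ARTIFACT_DIR used by the tp-run job
      # (which would also need to be exported to this job). Both are assumptions
      # about torchprime's output, not verified behavior.
      - name: Wait for completion and check logs (sketch)
        run: |
          set -o pipefail
          pod_name=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=$JOBSET_NAME -o json | jq --raw-output '.items[0].metadata.name')
          # Following the logs blocks until the container exits, so this also waits for completion.
          kubectl logs -c jax-tpu -f "$pod_name" | tee training.log
          # grep exits non-zero (failing the step) if an expected line is missing.
          grep -E 'Step duration: .+ s' training.log
          grep 'Finished training run' training.log
          # Hypothetical profile check; requires ARTIFACT_DIR in this job's env.
          gsutil ls "$ARTIFACT_DIR/**" | grep -i profile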