
Add ray multiple host support #63

Merged
FanhaiLu1 merged 9 commits into AI-Hypercomputer:main from multiple-host on Apr 30, 2024

Conversation

FanhaiLu1
Collaborator

@FanhaiLu1 FanhaiLu1 commented Apr 26, 2024

This PR enables the PyTorch engine to run on multiple hosts across TPU Pod slices.

MVP Goals:

This PR is the MVP version of multi-host support, with two goals:

  1. Load and shard the weights across multiple hosts

  2. Compute meaningful decode results

Result validation:

Weights and sharding

With this PR, the weights and sharding on v5e-16 (4 hosts, 16 chips total) use 50% of the per-chip memory compared with the 8-chip run, as expected:

memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_10(process=3,(2,2,0,0))
memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_11(process=3,(3,2,0,0))
memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_14(process=3,(2,3,0,0))
memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_15(process=3,(3,3,0,0))

The weights and sharding on v5e-8 (8 chips):

memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_4(process=0,(0,2,0,0))
memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_5(process=0,(1,2,0,0))
memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_6(process=0,(0,3,0,0))
memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_7(process=0,(1,3,0,0))

Meaningful results

With this PR, here are the first two lines of the output on v5e-16 (4 hosts, 16 chips total). The results are meaningful and read well to a human:


to find purpose and fulfillment.

I believe that everyone has a unique purpose and that it is up to each individual to discover and pursue theirs.

The first two lines of the output on v5e-8 (8 chips):

to find purpose, happiness, and fulfillment. Here are some reasons why:

1. Purpose: Having a sense of purpose gives life meaning and direction. It helps individuals set goals and work towards achieving them, which can lead to a sense of accomplishment and fulfillment.

Caveats

As this is an MVP, there may still be some limitations in performance and accuracy.

@FanhaiLu1 FanhaiLu1 requested review from qihqi, gangji and wang2yn84 April 26, 2024 23:48
@FanhaiLu1 FanhaiLu1 requested a review from vipannalla April 27, 2024 00:03
prefix=prefix, decode_state=decode_state, slot=slot
)
all_outputs.append(output)
_ = ray.get(all_outputs)
Collaborator

Do we need to return anything here? And for prefill?

Collaborator Author

In interleaved serving with Ray on multiple hosts, prefill doesn't return anything.

Collaborator

got it, can we return None with that comment?

Collaborator Author

Added.

Comment on lines 140 to 141
pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()
Collaborator

maybe we can add an assertion check here that the pod_name is not None and num_hosts >= 1 ?

(None, 0) would be the situation if we're not running on a TPU VM at all

Collaborator Author

Good point! Added asserts for pod_name and num_hosts.
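A minimal sketch of that guard, assuming `tpu` here is Ray's TPU helper module (the PR's actual import and messages may differ):

# Assumed import; the real code may already have `tpu` in scope.
from ray.util.accelerators import tpu

pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()

# (None, 0) is what we get when not running on a TPU VM at all, so fail fast
# with a readable message instead of a confusing Ray scheduling error later.
assert pod_name is not None, "Not running on a TPU pod: pod name is None"
assert num_hosts >= 1, f"Expected at least one TPU host, got {num_hosts}"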

Comment on lines 144 to 146
engine_worker_with_tpu_resource = PyTorchEngineRayWorker.options(
resources={"TPU": 4}
)
Collaborator

nit - we could remove this line altogether if we did @ray.remote(resources={"TPU": 4}) in the PyTorchEngineRayWorker decorator

Collaborator Author

I intentionally kept the worker without any assigned resources and moved resource allocation into the master, so it can dynamically assign a different number of chips for different TPU types. Right now I still use a fixed number, but the value is either passed in as a parameter or calculated from (pod_worker_count / num_hosts) when assigning resources to the worker.
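A rough sketch of that dynamic assignment, continuing the lines quoted above (the chip counts are illustrative, and PyTorchEngineRayWorker is the actor class from the diff):

# The master decides the per-worker chip count instead of hard-coding it in
# the worker's @ray.remote decorator.
total_chips = 16                        # e.g. a v5e-16 slice, passed in or discovered
chips_per_host = total_chips // num_hosts

engine_worker_with_tpu_resource = PyTorchEngineRayWorker.options(
    resources={"TPU": chips_per_host}
)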

Collaborator
@allenwang28 allenwang28 left a comment

This looks great!! Mostly nits, but IIUC this same approach would work for JAX/JetStream right?

f"memory using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}"
)

def init_decode_state_ray(
Collaborator

nit: maybe this can be init_decode_state and the init_decode_state from above can become _init_decode_state?

Collaborator Author

To make it easy to debug and reproduce the code on a single host (a lot of JAX debug info is lost in master/worker mode compared with single-host mode), I made ray_worker able to run on a single host without Ray with only a one-line change (commenting out @ray.remote). So I kept the init_decode_state function name the same as in the engine without Ray.
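An illustrative toy version of that single-host debug trick (the real worker class has a much larger API; this only shows the one-line switch):

import ray  # only needed when the decorator below is active

# @ray.remote   # comment this out to run the worker in-process, without Ray
class PyTorchRayWorker:

  def init_decode_state(self):
    return {"step": 0}  # stand-in for the real decode state


worker = PyTorchRayWorker()         # direct, single-host call path
state = worker.init_decode_state()
print(state)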

Collaborator

Ah okay that makes a lot of sense!

@FanhaiLu1
Collaborator Author

This looks great!! Mostly nits, but IIUC this same approach would work for JAX/JetStream right?

Thanks Allen for reviewing it! Yes, JAX/JetStream should use the same approach in general. We could extract the common parts of the code (for example, the master code could be exactly the same for both JAX and PyTorch) so that the JAX and PyTorch sides share the same code base.

@allenwang28
Collaborator

Could we also add a basic unit test for Ray? One way is to take advantage of Ray's "fake" cluster; see this file, for instance.
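For reference, one possible shape of such a test, assuming the "fake" cluster refers to Ray's ray.cluster_utils.Cluster helper; the actor and the custom "TPU" resource labels below are illustrative, not the PR's actual test:

import unittest

import ray
from ray.cluster_utils import Cluster


@ray.remote
class EchoWorker:
  """Stand-in for the engine worker: only proves scheduling on fake TPU hosts."""

  def ping(self):
    return "pong"


class RayEngineSmokeTest(unittest.TestCase):

  def test_workers_schedule_on_fake_cluster(self):
    # Simulate a 2-host pod; each "host" advertises 4 TPU chips as a resource.
    cluster = Cluster()
    for _ in range(2):
      cluster.add_node(num_cpus=1, resources={"TPU": 4})
    ray.init(address=cluster.address)
    try:
      workers = [
          EchoWorker.options(resources={"TPU": 4}).remote() for _ in range(2)
      ]
      self.assertEqual(
          ray.get([w.ping.remote() for w in workers]), ["pong", "pong"]
      )
    finally:
      ray.shutdown()
      cluster.shutdown()


if __name__ == "__main__":
  unittest.main()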

output = worker.init_decode_state_ray.remote()
all_outputs.append(output)
_ = ray.get(all_outputs)
return None
Collaborator

Instead of returning None, would it be meaningful to return some sort of "handle" here?

Basically, the reason we don't need to return anything is that the worker holds the state in a local variable. However, we could also have the worker store the state in a list and return its index to the master.

The master then keeps a list of ints as the "handles".

One nice thing about this is that you don't need to worry about overwrites if prefill / insert is called multiple times, and it helps with a disaggregated setup.
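A toy sketch of that handle idea (illustrative names only, not the PR's API): the worker keeps per-request states in a list and hands the index back, so repeated prefill/insert calls never overwrite each other.

import ray

ray.init()


@ray.remote
class WorkerWithHandles:
  """Illustrative only: stores per-request state and returns its index."""

  def __init__(self):
    self._decode_states = []

  def prefill(self, tokens):
    state = list(tokens)                 # stand-in for the real prefill result
    self._decode_states.append(state)
    return len(self._decode_states) - 1  # the "handle" the master keeps

  def generate(self, handle):
    return len(self._decode_states[handle])  # stand-in for a decode step


# Master side: the ints are opaque handles into the worker's state list.
worker = WorkerWithHandles.remote()
handle = ray.get(worker.prefill.remote([1, 2, 3]))
print(ray.get(worker.generate.remote(handle)))  # -> 3

ray.shutdown()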

Collaborator Author

Good point! It would be very important for disaggregated serving. The API might change a lot for disaggregated serving, so can we hold off on this for now? We will have a clearer API when we do disaggregated serving.

Collaborator
@allenwang28 allenwang28 left a comment

Added a few more style-focused comments

@@ -22,7 +22,8 @@ pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly
pip3 install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# torch cpu
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage safetensors colorama coverage
pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
pip3 install safetensors colorama coverage ray[serve] humanize
Collaborator

The ray[serve] dependency can be simplified to just ray[default], which should lower the total package size.

Collaborator Author

Good catch, updated.


params=params, decode_state=decode_state
)
all_outputs.append(output)
state, result_tokens = ray.get(all_outputs)[0]
Collaborator

I would add a comment here saying that we're assuming the worker does an all_gather; otherwise it could be confusing for future readers why we toss the other results.

Collaborator Author

Makes sense; right now all the workers do an all_gather.
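A sketch of the comment the reviewer suggested, in the context of the lines quoted above:

all_outputs.append(output)
# Every worker all_gathers the tokens before returning, so all entries in
# all_outputs are identical; reading index 0 does not drop any results.
state, result_tokens = ray.get(all_outputs)[0]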

DecodeState = Any


class PyTorchEngineRayMaster(engine_api.Engine):
Collaborator

I'm wondering if PyTorchRayEngine could be a more descriptive name here (although we should still mention that it is the leader of PyTorchInterleavedRayWorkers)

Collaborator Author

Updated.

"""Ray engine master to orchestrate requests and collect token response"""

def __init__(
self, engine_workers, tokenizer_path, context_length, batch_size
Collaborator

Could we add typing here in the init, that way readers can tell that engine_workers should be Iterable[PyTorchEngineRayWorker]?

Collaborator Author

Agreed, that would be clearer; updated.
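A short sketch of the typed constructor being discussed (the class and parameter names follow the diff above; the merged signature may differ slightly, and the engine_api.Engine base class is omitted here for brevity):

from typing import Iterable


class PyTorchRayEngine:
  """Ray engine master to orchestrate requests and collect token responses."""

  def __init__(
      self,
      engine_workers: Iterable["PyTorchRayWorker"],
      tokenizer_path: str,
      context_length: int,
      batch_size: int,
  ):
    self.engine_workers = list(engine_workers)
    self.tokenizer_path = tokenizer_path
    self.context_length = context_length
    self.batch_size = batch_size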

@ray.remote
# pylint: disable-next=all
class PyTorchEngineRayWorker:
"""Wraps functions to the Jet Engine API format."""
Collaborator

It could be helpful to add more documentation here, for instance:

class PyTorchEngineRayWorker:
  """Ray actor representation for a PyTorch engine worker.

  PyTorchEngineRayWorker enables multi-host serving for models that exceed the memory
  capabilities of single-host TPU VMs with Ray. PyTorchEngineRayWorker should be used with a "leader" engine,
  e.g. `PyTorchRayEngine`.

  Note: For `PyTorchRayEngine` to return consistent results, it's important that `PyTorchEngineRayWorker` is able
  to maintain its state within the process and only return results once they have been transferred to the CPU device.

  """

(as I write this out, if we're renaming the above to PyTorchRayEngine then this should be PyTorchRayWorker)

Collaborator Author

Agree on renaming, updated.

@FanhaiLu1
Collaborator Author

Could we also add a basic unit test for Ray? One way is to take advantage of Ray's "fake" cluster; see this file, for instance.

Ray's "fake" cluster test looks great. Can I add the test in another PR? (I tested the current PR end-to-end manually.) I also plan to add unit tests for the RayWorkers in the next PR.

@allenwang28
Collaborator

Ray's "fake" cluster test looks great. Can I add the test in another PR? (I tested the current PR end-to-end manually.) I also plan to add unit tests for the RayWorkers in the next PR.

Sounds good to me! Thanks for patiently addressing all of the comments, LGTM!

@FanhaiLu1 FanhaiLu1 requested review from qihqi and removed request for gangji, wang2yn84 and vipannalla April 30, 2024 17:51
@FanhaiLu1 FanhaiLu1 merged commit a58051d into AI-Hypercomputer:main Apr 30, 2024
3 checks passed
@FanhaiLu1 FanhaiLu1 deleted the multiple-host branch May 8, 2024 21:20