Add ray multiple host support #63
Conversation
jetstream_pt/ray_engine_master.py
Outdated
prefix=prefix, decode_state=decode_state, slot=slot
)
all_outputs.append(output)
_ = ray.get(all_outputs)
Do we need to return anything here? And for prefill?
In interleaved serving with Ray multi-host, prefill doesn't return anything.
got it, can we return None with that comment?
Added.
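For reference, a minimal sketch of what that could look like on the master side (the prefill_ray method name and argument names here are illustrative, following the pattern used elsewhere in the PR, not the exact code):

def prefill(self, *, params, padded_tokens, true_length):
  all_outputs = []
  for worker in self.engine_workers:
    output = worker.prefill_ray.remote(
        params=params, padded_tokens=padded_tokens, true_length=true_length
    )
    all_outputs.append(output)
  # Block until every worker has finished its prefill step.
  _ = ray.get(all_outputs)
  # In interleaved multi-host serving the prefill result stays on the workers,
  # so there is nothing meaningful to return to the caller.
  return None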
jetstream_pt/ray_engine_master.py
Outdated
pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()
Maybe we can add an assertion check here that pod_name is not None and num_hosts >= 1? (None, 0) would be the situation if we're not running on a TPU VM at all.
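Something along these lines, for example (a sketch; the assertion messages are just suggestions):

pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()
# These return (None, 0) when we are not on a TPU VM, so fail fast with a
# clear error instead of failing later during sharding.
assert pod_name is not None, "Not running on a TPU VM: no TPU pod detected"
assert num_hosts >= 1, f"Expected at least one TPU host, got {num_hosts}"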
Good point! Added asserts for pod_name not being None and for num_hosts.
jetstream_pt/ray_engine_master.py
Outdated
engine_worker_with_tpu_resource = PyTorchEngineRayWorker.options(
    resources={"TPU": 4}
)
Nit: we could remove this line altogether if we did @ray.remote(resources={"TPU": 4}) in the PyTorchEngineRayWorker decorator.
I intentionally keep the worker without assigning any resources, and moved resource allocation into the master so it can dynamically assign a different number of chips for different TPU types. Right now I still use a fixed number, but the value is either passed in as a parameter or calculated from (pod_worker_count / num_hosts) to assign resources for the worker.
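Roughly like this, for illustration (a sketch; tpu_chips_per_worker and pod_worker_count are illustrative names, not the exact PR variables):

# Assign TPU chips per worker dynamically in the master instead of hard-coding
# the value in the @ray.remote decorator on the worker class. The count either
# comes in as a parameter or is derived from the pod topology, as described above.
chips_per_worker = tpu_chips_per_worker or (pod_worker_count // num_hosts)
engine_worker_with_tpu_resource = PyTorchEngineRayWorker.options(
    resources={"TPU": chips_per_worker}
)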
This looks great!! Mostly nits, but IIUC this same approach would work for JAX/JetStream right?
jetstream_pt/ray_engine_worker.py
Outdated
f"memory using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}" | ||
) | ||
|
||
def init_decode_state_ray( |
Nit: maybe this can be init_decode_state, and the init_decode_state from above can become _init_decode_state?
To make it easy to debug and reproduce the code on a single host (a lot of JAX debug info is lost in master-worker mode compared with single-host mode), I let the Ray worker run on a single host without Ray with only a one-line change (commenting out the @ray.remote decorator). So I keep the init_decode_state function name the same as in the engine without Ray.
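In other words, the pattern is roughly this (a simplified sketch, not the exact file):

import ray

# Comment out the decorator below to run the worker directly on a single host,
# which keeps JAX diagnostics visible and makes debugging easier.
@ray.remote
# pylint: disable-next=all
class PyTorchEngineRayWorker:
  """Wraps functions to the Jet Engine API format."""

  def init_decode_state(self):
    # Kept with the same name as the non-Ray engine so the worker can be
    # debugged and reproduced on a single host without Ray.
    ...

  def init_decode_state_ray(self):
    # Thin Ray entry point called by the master via .remote().
    ...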
Ah okay that makes a lot of sense!
Thanks, Allen, for reviewing it! Yes, JAX/JetStream should use the same approach in general. We could extract the common parts of the code (for example, the master code could be exactly the same for both JAX and PyTorch), so the JAX and PyTorch sides can share the same code base.
Could we also add a basic unit test for Ray? One way is to take advantage of Ray's "fake" cluster, if you look at this file for instance.
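One lightweight way to do that, as a sketch (this uses a plain local Ray instance advertising a fake "TPU" resource rather than the full fake-cluster utility; the test name and FakeWorker actor are illustrative):

import ray


def test_actors_schedule_on_fake_tpu_resource():
  # Start a local Ray instance that pretends to have 4 TPU chips, so actors
  # requesting {"TPU": ...} can be placed without real TPU hardware.
  ray.init(num_cpus=4, resources={"TPU": 4})
  try:

    @ray.remote(resources={"TPU": 1})
    class FakeWorker:
      def ping(self):
        return "ok"

    workers = [FakeWorker.remote() for _ in range(4)]
    assert ray.get([w.ping.remote() for w in workers]) == ["ok"] * 4
  finally:
    ray.shutdown()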
jetstream_pt/ray_engine_master.py
Outdated
output = worker.init_decode_state_ray.remote()
all_outputs.append(output)
_ = ray.get(all_outputs)
return None
Instead of returning None, is it meaningful to return some sort of "handle" here?
Basically, the reason we don't need to return anything is that the worker holds on to the state in a local variable. However, we could also have the worker store the state in a list and return its index to the master; the master would then keep a list of ints as the "handle".
One nice thing about this is that you don't need to worry about overwriting state if prefill / insert is called multiple times, and it helps with a disaggregated setup.
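A rough sketch of that idea, with hypothetical method names, in case it helps later:

# Worker side: keep decode states in a list and hand back an integer handle.
class Worker:
  def __init__(self):
    self._decode_states = []

  def init_decode_state(self) -> int:
    state = self._build_decode_state()  # hypothetical helper
    self._decode_states.append(state)
    return len(self._decode_states) - 1  # the "handle" the master keeps

# Master side: remember per-worker handles instead of discarding the results.
# handles = ray.get([w.init_decode_state.remote() for w in workers])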
Good point! It would be very important for disaggregated serving. The API might change substantially once we add disaggregated serving, so can we hold off on this for now? We will have a clearer API when we do disaggregated serving.
Added a few more style-focused comments
install_everything.sh
Outdated
@@ -22,7 +22,8 @@ pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly
 pip3 install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 # torch cpu
 pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage safetensors colorama coverage
+pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
+pip3 install safetensors colorama coverage ray[serve] humanize
The ray[serve] dependency can be simplified to just ray[default], which should lower the total package size.
Good catch, updated.
jetstream_pt/ray_engine_master.py
Outdated
params=params, decode_state=decode_state
)
all_outputs.append(output)
state, result_tokens = ray.get(all_outputs)[0]
I would add a comment here saying that we're assuming the worker does an all_gather; otherwise it could be confusing for future readers why we toss the other results.
Makes sense. Right now all the workers do an all_gather.
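For example, the comment could read roughly like this (the generate_ray method name is an assumption, mirroring init_decode_state_ray):

all_outputs = []
for worker in self.engine_workers:
  output = worker.generate_ray.remote(
      params=params, decode_state=decode_state
  )
  all_outputs.append(output)
# Every worker runs an all_gather, so each entry in all_outputs holds the same
# (state, result_tokens); taking the first one is sufficient.
state, result_tokens = ray.get(all_outputs)[0]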
jetstream_pt/ray_engine_master.py
Outdated
DecodeState = Any


class PyTorchEngineRayMaster(engine_api.Engine):
I'm wondering if PyTorchRayEngine could be a more descriptive name here (although we should still mention that it is the leader of the PyTorchInterleavedRayWorkers).
Updated.
jetstream_pt/ray_engine_master.py
Outdated
"""Ray engine master to orchestrate requests and collect token response""" | ||
|
||
def __init__( | ||
self, engine_workers, tokenizer_path, context_length, batch_size |
Could we add typing here in the init? That way readers can tell that engine_workers should be Iterable[PyTorchEngineRayWorker].
Agree, it would be clearer; updated.
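For instance, something like this (a sketch; the types of context_length and batch_size are assumptions):

from typing import Iterable

class PyTorchEngineRayMaster(engine_api.Engine):
  """Ray engine master to orchestrate requests and collect token response"""

  def __init__(
      self,
      engine_workers: Iterable["PyTorchEngineRayWorker"],
      tokenizer_path: str,
      context_length: int,
      batch_size: int,
  ):
    self.engine_workers = engine_workers
    ...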
jetstream_pt/ray_engine_worker.py
Outdated
@ray.remote
# pylint: disable-next=all
class PyTorchEngineRayWorker:
  """Wraps functions to the Jet Engine API format."""
It could be helpful to add more documentation here, for instance:

class PyTorchEngineRayWorker:
  """Ray actor representation for a PyTorch engine worker.

  PyTorchEngineRayWorker enables multi-host serving with Ray for models that
  exceed the memory capabilities of single-host TPU VMs. PyTorchEngineRayWorker
  should be used with a "leader" engine, e.g. `PyTorchRayEngine`.

  Note: For `PyTorchRayEngine` to return consistent results, it's important
  that `PyTorchEngineRayWorker` is able to maintain its state within the
  process and only return results once they have been transferred to the CPU
  device.
  """

(As I write this out, if we're renaming the above to PyTorchRayEngine, then this should be PyTorchRayWorker.)
Agree on renaming, updated.
Looks great with Ray's "fake" cluster test. Can I add the test in another PR (I tested the current PR end-to-end manually)? I also plan to add unit tests for the RayWorkers in the next PR.
Sounds good to me! Thanks for patiently addressing all of the comments, LGTM!
This PR enables the PyTorch engine to run on multiple hosts on TPU Pod slices.
MVP Goal:
The current PR is the MVP version of multi-host support, with two goals:
Load weights and apply sharding across multiple hosts
Compute meaningful decode results
Result validation:
Weight and sharding
With this PR, the weights and sharding on v5e-16 (4 hosts, 16 chips total) use 50% of the memory compared with the 8-chip setup. It worked as expected.
The weights and sharding on v5e-8 (4 hosts, 8 chips total):
Meaningful Results
With this PR, here are the first two lines of the output on v5e-16 (4 hosts, 16 chips total). The results are meaningful and read well to a human.
First two lines of the output on v5e-8 (4 hosts, 8 chips total):
Caveats
As this is the MVP, note that there may still be limitations in performance and accuracy.