[SPMD] Mesh to support custom device order. #4162
Conversation
), "PyTorch/XLA SPMD requires PJRT_DEVICE={CPU, TPU}, GPU is currently not supported." | ||
) | ||
@unittest.skipIf(not using_pjrt() or xm.get_xla_supported_devices("GPU"), | ||
f"Requires PJRT_DEVICE set to `TPU` or `CPU`.") |
@will-cromar I think PJRT-GPU single core is ready now?
It's blocked on our SPMD side; once we support TPU, the transition to GPU should be easier -- maybe sometime next year, once we are done with the basic/core SPMD features?
Args:
    device_ids (Union[np.ndarray, List]): A raveled list of device IDs in a custom order. The list is reshaped
        to a `mesh_shape` array, filling the elements using C-like index order. For example,
where is the example lol?
oh ok it is below, you might want to change the wording here.
Done
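For reference, here is a minimal numpy sketch of the C-like (row-major) reshape the docstring describes; the device IDs and mesh shape below are hypothetical:

```python
import numpy as np

# Hypothetical: 4 devices listed in reversed order, reshaped into a
# (2, 2) logical mesh using C-like (row-major) index order.
device_ids = [3, 2, 1, 0]
mesh_shape = (2, 2)
logical_mesh = np.array(device_ids).reshape(mesh_shape)
print(logical_mesh)
# [[3 2]
#  [1 0]]
```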
Force-pushed from 970536f to 716865c
Thanks!
LGTM, thanks!
mesh_shape (Tuple[Union[int, None]]): An int tuple describing the logical topology
    of the device mesh; each element describes the number of devices in
    the corresponding axis.
Looks like `mesh_shape` can be removed here.
Good catch :)
def test_custom_tile_assignment(self):
    xt = torch.randn(10, 20).to(device=xm.xla_device())
    mesh_shape = (1, self.n_devices)
I see the tests have all devices mapped to a single axis - is there anything stopping us from using e.g. `mesh_shape = (2, self.n_devices / 2)`?
No, but for unit testing a flat mesh is easier to work with, since we don't know how many devices we will have (e.g., for CPU, we will have 1).
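For what it's worth, a sketch of the suggested 2D variant under the same constructor; the import path `torch_xla.experimental.xla_sharding` is an assumption, and the guard reflects that the split only works for an even device count:

```python
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # assumed module path

n_devices = len(xm.get_xla_supported_devices())
if n_devices % 2 == 0:
    # 2D logical mesh: 2 rows, n_devices // 2 columns.
    device_ids = np.arange(n_devices)
    mesh = xs.Mesh(device_ids, mesh_shape=(2, n_devices // 2))
```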
def __init__(self,
             device_ids: Union[np.ndarray, List],
             mesh_shape: Tuple[int, ...],
             axis_names: Tuple[str, ...] = None):
Just curious - how will `axis_names` be used long-term? Is it just for annotating the mesh?
Good question. Mesh axis annotation is useful since it makes the annotation logic more readable. We can also build partitioning rules based on axis names instead of int indices.
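A hedged sketch of what that might look like; the axis names and the name-based partition rule below are speculative, not part of this PR:

```python
import numpy as np
import torch_xla.experimental.xla_sharding as xs  # assumed module path

# Hypothetical axis names annotating a (2, 2) logical mesh.
device_ids = np.arange(4)
mesh = xs.Mesh(device_ids, mesh_shape=(2, 2), axis_names=('data', 'model'))

# Today a partition spec refers to mesh axes by int index, e.g. (0, 1);
# a future name-based rule could instead read ('data', 'model').
```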
Force-pushed from 716865c to 15a30e4
Force-pushed from 3d6bc93 to eea8e9c
Force-pushed from eea8e9c to 234871f
lgtm!
This implements the `Mesh` class from #3871, to support custom device order in the logical XLA device mesh topology.
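For a quick end-to-end picture, a minimal sketch of the custom device order this enables, mirroring `test_custom_tile_assignment` above; the module path and the `mark_sharding` signature are assumptions based on the SPMD RFC (#3871):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # assumed module path

n_devices = len(xm.get_xla_supported_devices())
device_ids = np.flip(np.arange(n_devices))  # reversed tile assignment
mesh = xs.Mesh(device_ids, mesh_shape=(1, n_devices))

t = torch.randn(10, 20).to(xm.xla_device())
# Shard dim 1 of `t` across the reversed device order (assumed signature).
xs.mark_sharding(t, mesh, (0, 1))
```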