Feature/transformer sequence sharding #67
base: develop
…transformer_sequence_sharding
Very nice contribution :-)
@@ -130,6 +199,36 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor
    return _SplitSequenceParallelSection.apply(input_, shapes, mgroup)


def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> Tensor:
I was wondering: we now have
halo_exchange
_halo_exchange
_HaloExchange
Would it make sense to come up with more distinct, more descriptive names for these? I think the current ones might be a bit confusing. I admit that the names of the other routines (shard_heads etc.) are not great either.
changed it to:
halo_exchange
_HaloExchangeParallelSection
_halo_comm
@@ -97,6 +97,7 @@ def __init__(
        num_heads: int = 16,
        mlp_hidden_ratio: int = 4,
        dropout_p: float = 0.1,
        shard_strategy: str = "shard_heads",
Add to doc string below?
is this value configurable? (how can one override the default?)
> Add to doc string below?
Done.
> is this value configurable? (how can one override the default?)
Ah yes, it is configurable via config.model.processor.shard_strategy; I forgot to add that to anemoi-core.
I also set the default strategy to shard_sequence.
            einops.rearrange(
                t,
                "(batch grid) (heads vars) -> batch heads grid vars",
        if self.shard_strategy == "shard_sequence":
This is now very long. Can we introduce something like:
`if self.shard_strategy == "shard_sequence":
    x = self.shard_sequence(x)
query, key, value = self.lin_qkv(x).chunk(3, -1)
query, key, value = (
    einops.rearrange(
        t,
        "(batch grid) (heads vars) -> batch heads grid vars",
        batch=batch_size,
        heads=self.num_heads,
    )
    for t in (query, key, value)
)
if self.shard_strategy == "shard_heads":
    query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
    key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
    value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
.
.
.
.
`
agreed, the if & else blocks should be refactored as separate (member) functions
> agreed, the if & else blocks should be refactored as separate (member) functions
I moved it to the member functions get_qkv_shard_[heads/sequence].
Changed it to query, key, value = self.get_qkv_shard_sequence and self.get_qkv_shard_heads, respectively.
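The chunk-and-rearrange step shared by both strategies in the suggestion above can be sketched in isolation. This is only an illustration in NumPy (not the PyTorch/einops code from the PR) so that it runs standalone; the name `split_qkv` and its signature are hypothetical and do not appear in the PR.

```python
import numpy as np


def split_qkv(qkv: np.ndarray, batch_size: int, num_heads: int):
    """Split a fused (batch*grid, 3*heads*vars) array into query, key, value.

    Each returned array has shape (batch, heads, grid, vars), mirroring the
    einops pattern "(batch grid) (heads vars) -> batch heads grid vars".
    Hypothetical helper for illustration only.
    """
    # Equivalent of tensor.chunk(3, -1): three equal slices along the last axis.
    query, key, value = np.split(qkv, 3, axis=-1)

    def to_heads(t: np.ndarray) -> np.ndarray:
        grid = t.shape[0] // batch_size
        head_dim = t.shape[1] // num_heads
        # Un-flatten both axes (batch and heads are the outer factors),
        # then move the heads axis in front of the grid axis.
        return t.reshape(batch_size, grid, num_heads, head_dim).transpose(0, 2, 1, 3)

    return to_heads(query), to_heads(key), to_heads(value)


# Example: batch=2, grid=5, heads=4, vars=3 -> fused channel size 3 * 4 * 3 = 36
qkv = np.arange(2 * 5 * 36, dtype=np.float64).reshape(10, 36)
query, key, value = split_qkv(qkv, batch_size=2, num_heads=4)
```

A strategy-specific member function would then only differ in what it does before this step (sharding the sequence and adding halos) or after it (sharding across heads).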
@@ -104,7 +144,11 @@ def forward(
            dropout_p=dropout_p,
        )  # expects (batch heads grid variable) format

        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
        if self.shard_strategy == "shard_sequence":
            out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :]  # remove halos
I would prefer if this would happen in a function that lives at the same place as halo_exchange, e.g. call halo_expand first and then halo_contract (not best names).
maybe just remove_halos
> I would prefer if this would happen in a function that lives at the same place as halo_exchange, e.g. call halo_expand first and then halo_contract (not best names).
Good idea, I added add_halos and remove_halos.
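To make the add/remove pairing concrete, here is a minimal single-process sketch in NumPy. It only borrows the names add_halos and remove_halos from the comment above; the signatures, the list-of-shards representation, and the absence of any real inter-GPU communication are assumptions made for illustration, not the PR's implementation.

```python
import numpy as np


def add_halos(shards, halo_size):
    """Pad each shard along the grid axis (-2) with rows from its neighbours.

    `shards` is a list of arrays of shape (batch, heads, grid_local, vars)
    standing in for the per-GPU sequence shards (assumed layout).
    Returns the padded shards and the (left, right) halo sizes per shard.
    """
    padded, halo_sizes = [], []
    for rank, shard in enumerate(shards):
        parts, left, right = [], 0, 0
        if rank > 0:  # take the tail of the left neighbour
            prev = shards[rank - 1]
            parts.append(prev[..., prev.shape[-2] - halo_size :, :])
            left = halo_size
        parts.append(shard)
        if rank < len(shards) - 1:  # take the head of the right neighbour
            parts.append(shards[rank + 1][..., :halo_size, :])
            right = halo_size
        padded.append(np.concatenate(parts, axis=-2))
        halo_sizes.append((left, right))
    return padded, halo_sizes


def remove_halos(x, halo_size_left, halo_size_right):
    """Drop the halo rows again, same slicing as in the diff above."""
    return x[..., halo_size_left : x.shape[-2] - halo_size_right, :]


full = np.arange(1 * 2 * 12 * 3, dtype=np.float64).reshape(1, 2, 12, 3)
shards = np.split(full, 3, axis=-2)  # three "GPUs", 4 grid points each
padded, halo_sizes = add_halos(shards, halo_size=2)
restored = [remove_halos(p, left, right) for p, (left, right) in zip(padded, halo_sizes)]
```

Keeping both functions side by side means the halo sizes computed when padding are the same ones used when trimming, which is the point of the review comment.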
        if self.shard_strategy == "shard_sequence":
            assert (
                shapes[-1][0] // 2 >= self.window_size[0]
            ), "Sharded sequence length must be at least twice the window size"
Could we have the assert print the sharded sequence length and window size, so the user sees the values that raised the error?
good point, done
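One way such a message could look is sketched below. The standalone function `check_shard_length` is hypothetical and exists only to make the example runnable; in the PR the check sits inline and uses shapes[-1][0] and self.window_size[0], and the exact wording of the final message is not shown in this thread.

```python
def check_shard_length(shard_len: int, window_size: int) -> None:
    """Hypothetical standalone version of the inline assert from the diff."""
    assert shard_len // 2 >= window_size, (
        "Sharded sequence length must be at least twice the window size "
        f"(sharded sequence length: {shard_len}, window size: {window_size})"
    )


check_shard_length(shard_len=1024, window_size=512)  # passes silently
```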
Excellent work! 👏
This PR adds a new sharding strategy shard_sequence for the transformer processor.
The current implementation (shard_heads) alternates between sharding across the sequence and sharding across heads for the sliding window attention mechanism. This requires two all-to-all communication steps per layer.
The shard_sequence strategy simplifies this process by keeping a sequence shard on each GPU and computing the sliding window attention locally. This requires a halo communication to exchange overlapping window segments (halos) between neighboring sequence shards.
Instead of two all-to-all communication steps per layer, the halo exchange only requires a single point-to-point communication between neighbouring GPUs, reducing communication time and improving the scalability of model sharding across multiple GPUs.
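The reason local computation is sufficient can be checked numerically: with a halo at least as wide as the attention window, every position inside a shard sees its full window. The sketch below is a naive single-process NumPy illustration under that assumption (symmetric window, halo size equal to the window, single head, no batching); the function names are hypothetical and it is not the PR's attention kernel or communication code.

```python
import numpy as np


def window_attention(q, k, v, window):
    """Naive sliding-window attention: position i attends to j with |i - j| <= window."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= window
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def sharded_window_attention(q, k, v, window, num_shards):
    """Compute the same result shard by shard, using halos of size `window`."""
    n = q.shape[0]
    assert n % num_shards == 0, "sequence length must divide evenly in this sketch"
    shard_len = n // num_shards
    outputs = []
    for rank in range(num_shards):
        start, stop = rank * shard_len, (rank + 1) * shard_len
        lo, hi = max(start - window, 0), min(stop + window, n)  # shard plus halos
        local = window_attention(q[lo:hi], k[lo:hi], v[lo:hi], window)
        # Outputs at halo positions saw truncated windows, so discard them.
        outputs.append(local[start - lo : local.shape[0] - (hi - stop)])
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((24, 8)) for _ in range(3))
reference = window_attention(q, k, v, window=3)
sharded = sharded_window_attention(q, k, v, window=3, num_shards=4)
matches = np.allclose(reference, sharded)
```

In this sketch each shard only ever reads rows from its immediate neighbours, which is why a neighbour-to-neighbour exchange can replace the two all-to-all steps.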
The following benchmarking results show that a 2-neighbor all-to-all (orange) is the best communication strategy for implementing the halo exchange, and that it consistently outperforms the old head-sharding strategy (blue):
This is an isolated forward+backward pass of 16 transformer layers with o96 input shapes and 1024 channels.
For a full training run on n320, o96 hidden we get the following increases in throughput (aligning with the benchmark results):
[mlflow](https://mlflow.ecmwf.int/#/metric?runs=%5B%22ff99c1c794be4c69849ca6ad7e98e21e%22,%222fb2e79ac56c4fcea0d33d05569098c8%22,%2248e3ec3a3e854702adfbd29622fac8e9%22,%22d1b8c835c9cc4fc9b40e014bc10f7333%22%5D&metric=%22train_wmse_step%22&experiments=%5B%2245%22%5D&plot_metric_keys=%5B%22train_wmse_step%22%5D&plot_layout=%7B%22autosize%22:true,%22xaxis%22:%7B%7D,%22yaxis%22:%7B%7D%7D&x_axis=relative&y_axis_scale=linear&line_smoothness=1&show_point=false&deselected_curves=%5B%5D&last_linear_y_axis_range=%5B%
@mishooax @ssmmnn11