Distributed mlx_lm.evaluate #1174

Open · wants to merge 2 commits into main

Conversation

@barronalex (Collaborator) commented Dec 19, 2024

Add a distributed version of mlx_lm.evaluate that runs on multiple nodes and produces identical outputs.

Also fix a few bugs:

  • Add masking so that changing the batch_size no longer affects the output
  • Fix a bug in loglikelihood_rolling tasks, e.g. WikiText

Benchmark command:

mlx_lm.evaluate --model mlx-community/Qwen2.5-7B-Instruct-bf16 --tasks winogrande

On 1 M2 Ultra:

Acc:   0.6992896606156275
Time (post init): 64 sec 

On 4 M2 Ultras:

Acc:   0.6985003946329913
Time (post init): 16 sec 
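
For reference, the multi-node run presumably needs a distributed launcher; a hypothetical MPI invocation (the hostfile name and process count are illustrative, not from the PR):

mpirun -np 4 --hostfile hosts.txt mlx_lm.evaluate --model mlx-community/Qwen2.5-7B-Instruct-bf16 --tasks winogrande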

@ivanfioravanti (Contributor) commented:

This is great! I'm testing it with an M2 Ultra + 2 M4 Max. WOW! Great job @barronalex.
When will this be reviewed and merged?

Comment on lines +42 to +43
lengths = mx.array([len(x) for x in inputs])
maxlen = lengths.max()
Member:

I would avoid doing small computations like this which require just a couple kernel launches + a graph eval in MLX. It will be like 100x (or more) faster to do:

maxlen = max(len(x) for x in inputs)

T = inp.shape[1]

offset = cache[0].offset
mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
Member:

I think a nicer solution might be to allow boolean masks in our SDPA. It's a really easy mistake to make the mask the wrong type and get inadvertent up-casting. If you send a bool instead it will always cast to the right type. (Just making a comment for later). It will also use a lot less memory for large contexts.
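
A minimal sketch of the up-casting pitfall described above, using plain array ops rather than the SDPA kernel itself (shapes and dtypes are illustrative):

import mlx.core as mx

scores = mx.zeros((4, 4), dtype=mx.bfloat16)        # stand-in for attention scores
causal = mx.tril(mx.ones((4, 4), dtype=mx.bool_))   # boolean causal mask

# An additive mask built in float32 silently promotes the bfloat16 scores.
additive = mx.where(causal, mx.array(0.0), mx.array(float("-inf"))).astype(mx.float32)
print((scores + additive).dtype)  # float32, the inadvertent up-cast

# Applying the boolean mask directly keeps the scores in their own precision.
masked = mx.where(causal, scores, mx.array(float("-inf"), dtype=scores.dtype))
print(masked.dtype)  # bfloat16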

Comment on lines +44 to +45
padded = mx.stack(
    [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
Member:

I'd probably also do this in Python.. but it's minor since padded goes into the main graph.

Suggested change:

-padded = mx.stack(
-    [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
+padded = mx.array([x + [0] * (maxlen - len(x)) for x in inputs])

if score_spans is None:  # full sequence score
    l = length[j].item()
    score = scores[j][:l].astype(mx.float32).sum()
    ig = is_greedy[j][:l].astype(mx.int32).sum()
Member:

It's not really necessary to cast this here since mx.array([False, True]).sum() has type mx.int32

else:  # subsequence score
-    start, end = sorted_spans[i + j]
+    start, end = score_spans[i + j]
    score = scores[j][start:end].astype(mx.float32).sum()
    ig = is_greedy[j][start:end].astype(mx.int32).sum()
Member:

Same for the cast to int32 here.

scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
for i in tqdm(range(0, len(texts), self._batch_size)):
    batch = texts[i : i + self._batch_size]
    scores, length, is_greedy = self._score_fn(batch)
    for j in range(len(batch)):
@awni (Member) commented Jan 6, 2025:

It's probably over-optimization, but having a bunch of evals in a loop like this is not such a good pattern. It hits latency really hard. Especially things like:

for i in range(len(lengths)):
  l = lengths[i].item()

Which does a kernel launch for the gather + full GPU synch at each iteration.

I know it probably will make little to no difference in runtime since this isn't the bottleneck, but perhaps still good to change it to set a good example.

Member:

One way to do it would be (a rough sketch follows the list):

  • have a mask making function which takes start_offset = mx.array, end_offset = mx.array and makes the mask.
  • Then multiply scores and is_greedy by the mask.
  • Then sum along the time axis.
  • Then eval everything in one shot mx.eval(scores, is_greedy)
  • Then convert them all to lists, zip and return
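
A rough sketch of that approach (the helper name make_span_mask and the variables starts / ends are illustrative, not from the PR; it uses where / & instead of a literal multiply to keep dtypes tidy):

import mlx.core as mx

def make_span_mask(starts, ends, T):
    # starts, ends: (B,) arrays of span offsets; returns a (B, T) boolean mask for [start, end)
    positions = mx.arange(T)[None, :]
    return (positions >= starts[:, None]) & (positions < ends[:, None])

def score_spans_batched(scores, is_greedy, starts, ends):
    # scores, is_greedy: (B, T); reduce each row over its span in a single graph
    mask = make_span_mask(starts, ends, scores.shape[1])
    span_scores = mx.where(mask, scores, mx.array(0.0, dtype=scores.dtype)).sum(axis=-1)
    span_greedy = (is_greedy & mask).sum(axis=-1)
    mx.eval(span_scores, span_greedy)  # a single eval for the whole batch
    return list(zip(span_scores.tolist(), span_greedy.tolist()))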

Member:

Although it looks like keeping them as mx.array could be useful rather than converting back and forth here and at L205

Comment on lines +193 to +197
group = mx.distributed.init() if mx.distributed.is_available() else None
if group is not None:
    # split strided so we have approximately the same lengths on each node
    shortened = shortened[group.rank() :: group.size()]
    completion_spans = completion_spans[group.rank() :: group.size()]
Member:

You don't necessarily need to check is_available(). If it's not available then rank = 0 and size = 1 and everything should work fine. So you can avoid needing to condition code on group is not None.
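
A sketch of the simplification being suggested, relying on the behavior described above (a single-process group reports rank 0 and size 1):

group = mx.distributed.init()
# With one process this is a no-op slice (rank 0, size 1), so no `if group is not None` guard is needed
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]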
