Distributed mlx_lm.evaluate
#1174
base: main
Conversation
This is great! I'm testing it with M2 Ultra + 2 M4 Max. WOW! Great job @barronalex
```python
lengths = mx.array([len(x) for x in inputs])
maxlen = lengths.max()
```
I would avoid doing small computations like this which require just a couple of kernel launches + a graph eval in MLX. It will be like 100x (or more) faster to do:

```python
maxlen = max(len(x) for x in inputs)
```
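A quick way to see the gap (a hypothetical micro-benchmark, not part of this PR; the input sizes are made up):

```python
import time
import mlx.core as mx

inputs = [list(range(n)) for n in range(1, 513)]

t0 = time.perf_counter()
# Array route: builds an array, launches a reduction kernel, and syncs.
maxlen = mx.array([len(x) for x in inputs]).max().item()
t1 = time.perf_counter()

# Pure-Python route: no kernel launch or graph eval at all.
maxlen = max(len(x) for x in inputs)
t2 = time.perf_counter()

print(f"mlx: {t1 - t0:.6f}s  python: {t2 - t1:.6f}s")
```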
```python
T = inp.shape[1]

offset = cache[0].offset
mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
```
I think a nicer solution might be to allow boolean masks in our SDPA. It's a really easy mistake to make the mask the wrong type and get inadvertent up-casting. If you send a bool instead it will always cast to the right type. (Just making a comment for later). It will also use a lot less memory for large contexts.
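To illustrate the pitfall (a standalone snippet, not code from this diff): an additive mask created at `float32` silently promotes half-precision scores, while a `bool` mask carries no float dtype of its own to leak into the computation.

```python
import mlx.core as mx

scores = mx.zeros((8, 8), dtype=mx.float16)    # attention scores in half precision
additive = mx.full((8, 8), float("-inf"))      # defaults to float32
print((scores + additive).dtype)               # float32 -- silently up-cast

boolean = mx.tril(mx.ones((8, 8))).astype(mx.bool_)    # 1 bit of info per position
print(mx.where(boolean, scores, float("-inf")).dtype)  # float16 -- dtype preserved
```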
```python
padded = mx.stack(
    [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
```
I'd probably also do this in Python.. but it's minor since `padded` goes into the main graph.
```diff
-padded = mx.stack(
-    [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
+padded = mx.array([x + [0] * (maxlen - len(x)) for x in inputs])
```
```python
if score_spans is None:  # full sequence score
    l = length[j].item()
    score = scores[j][:l].astype(mx.float32).sum()
    ig = is_greedy[j][:l].astype(mx.int32).sum()
```
It's not really necessary to cast this here since `mx.array([False, True]).sum()` has type `mx.int32`.
```diff
 else:  # subsequence score
-    start, end = sorted_spans[i + j]
+    start, end = score_spans[i + j]
     score = scores[j][start:end].astype(mx.float32).sum()
     ig = is_greedy[j][start:end].astype(mx.int32).sum()
```
Same for the cast to `int32` here.
```diff
-scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
 for i in tqdm(range(0, len(texts), self._batch_size)):
     batch = texts[i : i + self._batch_size]
     scores, length, is_greedy = self._score_fn(batch)
     for j in range(len(batch)):
```
It's probably over-optimization.. but having a bunch of evals in a loop like this is not such a good pattern. It hits latency really hard. Especially things like:

```python
for i in range(len(lengths)):
    l = lengths[i].item()
```

which does a kernel launch for the gather + a full GPU synch at each iteration.
I know it probably will make little to no difference in runtime since this isn't the bottleneck, but perhaps still good to change it to set a good example.
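For example (a minimal sketch with made-up data, not this PR's code), a single `tolist()` replaces the per-iteration gather + sync:

```python
import mlx.core as mx

lengths = mx.array([3, 5, 2])

# One transfer to the host up front...
host_lengths = lengths.tolist()

# ...so the loop body does no kernel launches or GPU syncs at all.
for i, l in enumerate(host_lengths):
    print(i, l)
```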
One way to do it would be:
- Have a mask-making function which takes `start_offset = mx.array, end_offset = mx.array` and makes the mask.
- Then multiply `scores` and `is_greedy` by the mask.
- Then sum along the time axis.
- Then eval everything in one shot: `mx.eval(scores, is_greedy)`
- Then convert them all to lists, zip and return
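A rough sketch of that shape of solution (the helper name `span_mask` and the `(B, T)` layout of `scores`/`is_greedy` are assumptions, not this PR's code):

```python
import mlx.core as mx

def span_mask(starts: mx.array, ends: mx.array, T: int) -> mx.array:
    # (B, T) boolean mask that is True for positions start <= t < end.
    t = mx.arange(T)[None, :]
    return (t >= starts[:, None]) & (t < ends[:, None])

def span_scores(scores, is_greedy, starts, ends):
    T = scores.shape[1]
    mask = span_mask(mx.array(starts), mx.array(ends), T)
    # Multiply by the mask and sum along the time axis -- all one graph.
    score_sums = (scores * mask).sum(axis=-1)
    greedy_counts = (is_greedy * mask).sum(axis=-1)
    # Evaluate everything in one shot rather than one sync per row.
    mx.eval(score_sums, greedy_counts)
    return list(zip(score_sums.tolist(), greedy_counts.tolist()))

# Example: two rows of length-6 scores, spans [1, 4) and [0, 6).
scores = mx.random.normal((2, 6))
is_greedy = mx.random.normal((2, 6)) > 0
print(span_scores(scores, is_greedy, [1, 0], [4, 6]))
```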
Although it looks like keeping them as `mx.array` could be useful rather than converting back and forth here and at L205.
```python
group = mx.distributed.init() if mx.distributed.is_available() else None
if group is not None:
    # split strided so we have approximately the same lengths on each node
    shortened = shortened[group.rank() :: group.size()]
    completion_spans = completion_spans[group.rank() :: group.size()]
```
You don't necessarily need to check `is_available()`. If it's not available then `rank = 0` and `size = 1` and everything should work fine. So you can avoid needing to condition code on `group is not None`.
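i.e., the diff above could collapse to something like this (a sketch reusing the diff's variable names, with placeholder data so it runs standalone):

```python
import mlx.core as mx

shortened = ["a", "b", "c", "d"]  # placeholders for the PR's actual lists
completion_spans = [(0, 1)] * 4

# Per the comment above, init() gives rank 0 / size 1 when no distributed
# backend is available, so the strided split is a harmless no-op on a
# single node -- no None check required.
group = mx.distributed.init()
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]
```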
Add a distributed version of `mlx_lm.evaluate` that runs on multiple nodes and produces identical outputs.

Also fix a few bugs:
- `batch_size` no longer affects the output
- `loglikelihood_rolling` tasks, e.g. wiki text

On 1 M2 Ultra:
On 4 M2 Ultra: