How do I get the sentence embedding from GritLM/emb_m7_nodes16_fast? #36
You should be able to just load it as follows: `model = GritLM("GritLM/emb_m7_nodes16_fast", torch_dtype="auto", mode='embedding')` and use it in the same way as …
I know standard GritLM generally uses mean pooling of the last hidden state for embeddings, and that it can use weighted mean, CLS, or last-token pooling instead. Mean pooling of the token embeddings is a common way to generate sentence embeddings, but I've also seen fully connected "pooler" layers: just one final dense layer that generates the embedding.
So you would apply that head over the sequence length? The problem is that the sequence length may change depending on the sample. You'd likely have to pad every input to the same number of tokens, and those padding tokens would then become part of the embedding, since they're part of the matrix multiply, which may hurt performance.
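To make the contrast above concrete, here is a small illustrative NumPy sketch (not GritLM code; the function names and toy shapes are made up). Masked mean pooling excludes padding positions entirely, while a fixed dense head over the flattened padded sequence necessarily multiplies the padding rows into the output:

```python
import numpy as np

def masked_mean_pool(token_embs, attention_mask):
    """Mean-pool token embeddings, ignoring padding positions."""
    mask = attention_mask[:, None].astype(float)  # (seq_len, 1)
    return (token_embs * mask).sum(axis=0) / mask.sum()

rng = np.random.default_rng(0)
real = rng.normal(size=(3, 4))                       # 3 real tokens, dim 4
padded = np.vstack([real, rng.normal(size=(2, 4))])  # padded to length 5
mask = np.array([1, 1, 1, 0, 0])

# Masked mean pooling: the padding rows have no effect on the result.
pooled = masked_mean_pool(padded, mask)

# A fixed dense "embedding head" over the flattened padded sequence:
# the padding rows enter the matrix multiply and change the output.
W = rng.normal(size=(5 * 4, 4))  # assumes a fixed padded length of 5
head_a = padded.reshape(-1) @ W
head_b = np.vstack([real, rng.normal(size=(2, 4))]).reshape(-1) @ W
```

`pooled` equals the mean of the three real token embeddings regardless of padding, whereas `head_a` and `head_b` differ even though the real tokens are identical, which is exactly the concern about padding tokens leaking into the embedding.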
Yeah, I see what you're saying. The reason the language model head works is that it only uses the last token embedding to generate the logits for the next token. I didn't realize that, so if I want something comparable I wouldn't make an "embedding head", I would just use the last-token approach, which you already support. I understand now, thanks for the response.
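For comparison, the last-token approach mentioned above can be sketched like this (illustrative only; assumes right padding, i.e. real tokens come first):

```python
import numpy as np

def last_token_pool(token_embs, attention_mask):
    """Select the hidden state of the last non-padding token.

    Assumes right padding: real tokens come before padding tokens.
    """
    last = int(attention_mask.sum()) - 1
    return token_embs[last]

hidden = np.arange(20, dtype=float).reshape(5, 4)  # toy (seq_len=5, dim=4)
mask = np.array([1, 1, 1, 0, 0])                   # 3 real tokens
emb = last_token_pool(hidden, mask)                # hidden state of token index 2
```

Like the language model head, this reads exactly one position, so padding length never enters the computation (with left padding, the index would instead be the final position of the sequence).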
Continuing from our conversation in #13; I think it needed a new ticket at this point.
I am trying to finetune embeddings only, so I took your (@Muennighoff's) recommendation of using GritLM/emb_m7_nodes16_fast, but I don't see an embedding for the entire sentence/article, only the token embeddings. Am I misunderstanding something?
The standard GritLM model is both generative and an encoder, so the forward function is generative and encode produces the embedding. I use model.encode(input_tokens, instruction), which returns a vector of shape (4096,), and that works great. The model you recommended has no generative part, so I assumed forward is the embedding function and there is no encode function, right? The issue I'm hitting is that when I run model(input_tokens) I get back a tuple with a 4096-dimensional embedding for each token, as opposed to a single embedding for the entire article. Should I be pooling these, or is there some other function I should use to get the embedding?
Here is some example code
Also, the embeddings won't be the same since they are different models, but they should produce similar similarity scores, right?
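On that last point, one way to sanity-check it is to compare cosine similarities rather than the raw vectors. A minimal sketch with made-up vectors (`cosine_sim` is a hypothetical helper, not part of the GritLM API):

```python
import numpy as np

def cosine_sim(a, b):
    """Cosine similarity between two 1-D embedding vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Two models produce different raw vectors for the same pair of texts...
model_a = {"cat": np.array([1.0, 0.0, 1.0]), "dog": np.array([1.0, 0.2, 0.9])}
model_b = {"cat": np.array([0.0, 2.0, 0.0]), "dog": np.array([0.1, 1.9, 0.2])}

# ...yet both can still score the pair as highly similar.
sim_a = cosine_sim(model_a["cat"], model_a["dog"])
sim_b = cosine_sim(model_b["cat"], model_b["dog"])
```

The raw vectors live in entirely different spaces, so comparing them directly is meaningless; what can reasonably be expected to transfer between models is the relative similarity structure.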