Floating point conversion issue with use_tucker #7

Open
jgreener64 opened this issue Apr 18, 2023 · 2 comments

@jgreener64

The following works fine for me:

python eat.py --lookup test.fasta --queries test.fasta --output test/

But when I add --use_tucker 1 I get:

Start loading ProtT5...
Finished loading Rostlab/prot_t5_xl_half_uniref50-enc in 28.2[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 1.4[s]
Start generating embeddings for 50 proteins.This process might take a few minutes.Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 0.7[s]
No existing model found. Start downloading pre-trained ProtTucker(ProtT5)...
Loading Tucker checkpoint from: temp/tucker_weights.pt
Traceback (most recent call last):
  File "/home/jgreener/soft/EAT/eat.py", line 515, in <module>
    main()
  File "/home/jgreener/soft/EAT/eat.py", line 496, in main
    eater = EAT(lookup_p, query_p, output_d,
  File "/home/jgreener/soft/EAT/eat.py", line 220, in __init__
    self.lookup_embs = self.tucker_embeddings(self.lookup_embs)
  File "/home/jgreener/soft/EAT/eat.py", line 245, in tucker_embeddings
    dataset = model.single_pass(dataset)
  File "/home/jgreener/soft/EAT/eat.py", line 36, in single_pass
    return self.tucker(x)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Half

I am on Python 3.9.16, PyTorch 1.10.0, h5py 3.6.0, numpy 1.22.0, scikit-learn 0.24.2 and transformers 4.17.0. test.fasta is uploaded as test.txt.
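The failure reproduces outside EAT with any default (fp32) Linear layer fed fp16 input; a minimal standalone sketch (hypothetical layer sizes, not EAT code):

```python
import torch

# A Linear layer has float32 weights by default, like the Tucker checkpoint
layer = torch.nn.Linear(1024, 128)

# fp16 input, like the embeddings from half-precision ProtT5
emb = torch.randn(2, 1024).half()

try:
    layer(emb)  # dtype mismatch: fp32 weights vs. fp16 input
except RuntimeError as err:
    print(f"RuntimeError: {err}")

# Up-casting the input to fp32 avoids the mismatch
out = layer(emb.to(torch.float))
print(out.dtype)  # torch.float32
```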

@mheinzinger
Contributor

Hey :)
first of all: thanks for your feedback!
On your issue: the problem is that the current embedder, ProtT5, runs in half precision and therefore also produces embeddings of this datatype. In its current version, ProtTucker's weights are still loaded in full precision, which causes this RuntimeError.
So you can either a) up-cast the embeddings to fp32 before feeding them to Tucker, or b) down-cast ProtTucker to fp16.
Depending on the size of your set and how speed-sensitive your application is: I would go for solution a) if your set is small enough, and for b) only if you want to search large sets (millions of proteins) against each other.
For a) you would need to up-cast the embeddings before this line via something like self.lookup_embs = self.lookup_embs.to(torch.float) (do this for both lookup and queries; the embeddings are torch tensors at this point, so .to() rather than NumPy's astype). For b) you would only need to add model = model.half() somewhere here.

Hope this helps; let me know if this solves your issue.

@jgreener64
Author

That worked, thanks. I added the following lines before https://github.com/Rostlab/EAT/blob/main/eat.py#L220:

            self.lookup_embs = self.lookup_embs.to(torch.float)
            self.query_embs = self.query_embs.to(torch.float)
