Start loading ProtT5...
Finished loading Rostlab/prot_t5_xl_half_uniref50-enc in 28.2[s]
Start generating embeddings for 50 proteins.
This process might take a few minutes.
Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 1.4[s]
Start generating embeddings for 50 proteins.
This process might take a few minutes.
Using batch-processing! If you run OOM/RuntimeError, you should use single-sequence embedding by setting max_batch=1.
Creating per-protein embeddings took: 0.7[s]
No existing model found. Start downloading pre-trained ProtTucker(ProtT5)...
Loading Tucker checkpoint from: temp/tucker_weights.pt
Traceback (most recent call last):
  File "/home/jgreener/soft/EAT/eat.py", line 515, in <module>
    main()
  File "/home/jgreener/soft/EAT/eat.py", line 496, in main
    eater = EAT(lookup_p, query_p, output_d,
  File "/home/jgreener/soft/EAT/eat.py", line 220, in __init__
    self.lookup_embs = self.tucker_embeddings(self.lookup_embs)
  File "/home/jgreener/soft/EAT/eat.py", line 245, in tucker_embeddings
    dataset = model.single_pass(dataset)
  File "/home/jgreener/soft/EAT/eat.py", line 36, in single_pass
    return self.tucker(x)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/jgreener/soft/miniconda3/envs/pyt10b/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Half
I am on Python 3.9.16, PyTorch 1.10.0, h5py 3.6.0, numpy 1.22.0, scikit-learn 0.24.2 and transformers 4.17.0. test.fasta is uploaded as test.txt.
Hey :)
first of all: thanks for your feedback!
On your issue: the current embedder, ProtT5, runs in half precision, so the embeddings it produces are fp16 as well. ProtTucker's weights are currently still loaded in full precision (fp32), and this dtype mismatch causes the RuntimeError.
So you can either a) up-cast the embeddings to fp32 before feeding them to Tucker or b) down-cast ProtTucker to fp16.
Which one to pick depends on the size of your set and how speed-sensitive your application is: I would probably go for a) if your set is small enough, and for b) only if you want to search large sets (millions of proteins) against each other.
For a), you would need to up-cast the embeddings before this line, with something like self.lookup_embs = self.lookup_embs.astype(np.float32) (do this for both lookup & targets). For b), you would only need to add model = model.half() somewhere here.
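As a minimal, self-contained sketch of both options: the small nn.Sequential below is only a hypothetical stand-in for ProtTucker, and its layer sizes, the random embeddings, and the helper names upcast_embeddings and downcast_model are assumptions for illustration, not code from eat.py. It works on torch tensors; for NumPy arrays the astype call above plays the role of .float().

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in for the ProtTucker head (sizes are illustrative only).
# nn.Linear weights are created in fp32 by default.
model = nn.Sequential(nn.Linear(1024, 256), nn.Tanh(), nn.Linear(256, 128))

# Random fp16 tensors standing in for embeddings from a half-precision embedder.
embs = torch.randn(50, 1024).half()


def upcast_embeddings(x: torch.Tensor) -> torch.Tensor:
    """Option a): cast fp16 embeddings to fp32 so they match fp32 weights."""
    return x.float()


def downcast_model(m: nn.Module) -> nn.Module:
    """Option b): return an fp16 copy of the model, leaving the original in fp32."""
    return copy.deepcopy(m).half()


# The mismatch behind the RuntimeError: fp32 weights vs. fp16 inputs.
weight_dtype = next(model.parameters()).dtype
print("weights:", weight_dtype, "| embeddings:", embs.dtype)

# Option a): fp32 embeddings into the fp32 model.
out_a = model(upcast_embeddings(embs))
print("option a) output dtype:", out_a.dtype)

# Option b): fp16 embeddings into an fp16 model. The forward pass is only run
# on a GPU here, because fp16 matmul on CPU is not available in every PyTorch version.
half_model = downcast_model(model)
if torch.cuda.is_available():
    out_b = half_model.cuda()(embs.cuda())
    print("option b) output dtype:", out_b.dtype)
```

Option a) doubles the memory needed for the embeddings but runs anywhere; option b) keeps everything in fp16, which is why it is the better fit for very large lookup sets on a GPU.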
Hope this helps; let me know if this solves your issue.
The following works fine for me:
But when I add --use_tucker 1, I get the RuntimeError shown in the traceback above.