Bug Description
After compiling my model with Torch-TensorRT, I observed a significant discrepancy between the outputs of the original model and the compiled model. The difference is large enough to affect results in practical applications. Both models received identical inputs: I used the same input data for the comparison.
To Reproduce
Steps to reproduce the behavior:
1. Compile the model with Torch-TensorRT using the following configuration:
```python
import torch
import torch.nn as nn
import torch_tensorrt

# text_model, device, dummy_input_ids and dummy_attention_mask are defined
# elsewhere; their definitions were not included in this report.


class CLIPTextWrapper(nn.Module):
    """
    Wrapper class to make the text model compatible with TensorRT, with simplified device handling
    """

    def __init__(self, text_model):
        super().__init__()
        # Move the model to CUDA once, during initialization
        self.text_model = text_model.cuda().eval()
        # Freeze all parameters; the wrapper is used for inference only
        for param in self.text_model.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        # No device movement in the forward pass, to avoid dynamo tracing issues
        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Second element of the model output (the pooled output, if text_model
        # is a Hugging Face CLIPTextModel)
        return outputs[1]


wrapped_model = CLIPTextWrapper(text_model).to(device).eval()
trt_model = torch_tensorrt.compile(
    wrapped_model,
    inputs=[
        torch_tensorrt.Input(
            shape=dummy_input_ids.shape,
            dtype=torch.int64,
        ),
        torch_tensorrt.Input(
            shape=dummy_attention_mask.shape,
            dtype=torch.int64,
        ),
    ],
    ir="dynamo",
    truncate_double=True,
    enabled_precisions={torch.float32},
    device=torch.device("cuda:0"),
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,
)
```
2. Run inference with both the original model and the compiled model on the same input data:
```python
test_output1 = wrapped_model(dummy_input_ids, dummy_attention_mask)
```
3. Compare the outputs of the two models.
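Steps 2 and 3 can be sketched as a small helper. This is a minimal sketch, assuming both models return a single 2-D tensor of shape `(batch, hidden)`, as `CLIPTextWrapper.forward` does; the name `compare_outputs` is hypothetical and not part of Torch-TensorRT. It reports the same per-row metric used in this report, plus the largest absolute element-wise difference, after casting both outputs to float32 on the same device.

```python
import torch
import torch.nn.functional as F


def compare_outputs(reference: torch.Tensor, candidate: torch.Tensor, eps: float = 1e-6):
    """Hypothetical helper: per-row cosine similarity and max absolute difference.

    The cosine similarity matches torch.nn.CosineSimilarity(dim=1, eps=1e-6).
    """
    # Compare in float32 on the same device so dtype/device differences
    # do not influence the result
    reference = reference.detach().float()
    candidate = candidate.detach().float().to(reference.device)
    if reference.shape != candidate.shape:
        raise ValueError(
            f"shape mismatch: {tuple(reference.shape)} vs {tuple(candidate.shape)}"
        )
    cos = F.cosine_similarity(reference, candidate, dim=1, eps=eps)
    max_abs_diff = (reference - candidate).abs().max().item()
    return cos, max_abs_diff


# Hypothetical usage with the objects from the steps above:
# with torch.no_grad():
#     ref_out = wrapped_model(dummy_input_ids, dummy_attention_mask)
#     trt_out = trt_model(dummy_input_ids, dummy_attention_mask)
# cos, max_diff = compare_outputs(ref_out, trt_out)
# print(cos, max_diff)
```

Running both models under `torch.no_grad()` on the very same tensors keeps the comparison focused on the compiled graph itself; the maximum absolute difference complements the cosine similarity by showing the scale of the mismatch.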
Expected behavior
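I expect the outputs of the original and compiled models to have a cosine similarity close to 1.
Instead, the cosine similarity between the original and compiled model outputs, computed with `torch.nn.CosineSimilarity(dim=1, eps=1e-6)`, is only about 0.1 for each of the 8 rows: [0.1022, 0.1013, 0.1064, 0.1080, 0.1107, 0.1136, 0.1111, 0.1120]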
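To put these numbers in perspective, here is a small pure-Python sketch of row-wise cosine similarity, the quantity `torch.nn.CosineSimilarity(dim=1)` computes for each row. Clamping each norm at `eps` is an approximation of PyTorch's handling, and the vectors are made-up illustrations, not data from this report. A tiny relative perturbation, such as rounding noise, keeps the similarity essentially at 1, while an unrelated (here exactly orthogonal) vector gives 0. Values around 0.1 are therefore much closer to "unrelated" than to "slightly off".

```python
import math


def cosine_similarity(x, y, eps=1e-6):
    """Cosine similarity of two equal-length vectors, with each norm clamped at eps."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = max(math.sqrt(sum(a * a for a in x)), eps)
    norm_y = max(math.sqrt(sum(b * b for b in y)), eps)
    return dot / (norm_x * norm_y)


def rowwise_cosine(xs, ys, eps=1e-6):
    """Per-row cosine similarity for two batches of vectors (like dim=1)."""
    return [cosine_similarity(x, y, eps) for x, y in zip(xs, ys)]


# Illustrative vectors only
reference = [1.0, 2.0, 3.0, 4.0]
# Tiny alternating relative noise: similarity stays essentially 1
noisy = [v * (1 + 1e-4 * (-1) ** i) for i, v in enumerate(reference)]
# A vector orthogonal to reference: similarity is 0
unrelated = [4.0, -3.0, 2.0, -1.0]

print(rowwise_cosine([reference, reference], [noisy, unrelated]))
```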
Environment
Build information about Torch-TensorRT can be found by turning on debug messages.
Torch-TensorRT version: 2.5.0+cu118
PyTorch version: 2.5.1+cu118
CPU architecture: x86_64 (Intel i7-14700K)
OS: Ubuntu 20.04.6
How you installed PyTorch (conda, pip, libtorch, source): pip
Build command you used (if compiling from source):
Are you using local sources or building from archives:
Python version: 3.12.7
CUDA version: 11.8
GPU models and configuration:
Any other relevant information:
Additional context