🐛 [Bug] Encountered bug when using CudaGraph in Torch-TensorRT #3349

Open
yjjinjie opened this issue Jan 9, 2025 · 14 comments
Labels
bug Something isn't working

Comments

@yjjinjie

yjjinjie commented Jan 9, 2025

Bug Description

When I enable CUDA graphs with torch_tensorrt.runtime.set_cudagraphs_mode(True), the program occasionally fails with:

RuntimeError: CUDA error: operation would make the legacy stream depend on a capturing blocking stream
[default0]:CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[default0]:For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[default0]:Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

To Reproduce

Steps to reproduce the behavior:

My code base is too large to paste here. It runs model prediction from multiple threads.

Expected behavior

Environment

Byte Order:                      Little Endian
CPU(s):                          104
On-line CPU(s) list:             0-103
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              2
Core(s) per socket:              26
Socket(s):                       2
Stepping:                        7
CPU max MHz:                     3800.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        5000.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       1.6 MiB (52 instances)
L1i cache:                       1.6 MiB (52 instances)
L2 cache:                        52 MiB (52 instances)
L3 cache:                        71.5 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-103
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Tsx async abort:   Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] torch==2.5.0+cu121
[pip3] torch_tensorrt==2.5.0
[pip3] torchmetrics==1.0.3
[pip3] torchrec==1.0.0+cu121
[pip3] triton==3.1.0
[conda] numpy                     1.26.3                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] torch                     2.5.0+cu121              pypi_0    pypi
[conda] torch-tensorrt            2.5.0               pypi_0    pypi
[conda] torchmetrics              1.0.3                    pypi_0    pypi
[conda] torchrec                  1.0.0+cu121              pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
@yjjinjie yjjinjie added the bug Something isn't working label Jan 9, 2025
@yjjinjie
Author

yjjinjie commented Jan 9, 2025

@keehyuna I am on Torch-TensorRT 2.5.0. I applied https://github.com/pytorch/TensorRT/pull/3289/files to solve the dynamic shape error and https://github.com/pytorch/TensorRT/pull/3310/files to solve the mutex issue, but I still get this error when I use CUDA graphs.

@yjjinjie
Author

@keehyuna

wget http://automl-nni.oss-cn-beijing.aliyuncs.com/trt/test_demo/test_demo2.tar.gz

# ok
python test_model_trt.py

# error
python test_model_trt_cudagraph.py

You can reproduce it with the commands above.

@keehyuna
Collaborator

Thanks @yjjinjie, I could reproduce the issue with your sample.
I raised the log level in the C++ runtime to verbose; this is the log from a run where the problem reproduced.
+torch.ops.tensorrt.set_logging_level(4)

If https://github.com/pytorch/TensorRT/pull/3310/files is applied, we expect the Input/Output Name logging to be atomic,
but it seems the model was serialized with a torch_tensorrt build without PR 3310.
That would also explain why test_resnet.py shows no issue. Could you check the issue with a correct model (serialized with the fixed version of Torch-TensorRT)?

DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 1
DEBUG: [Torch-TensorRT] - Resetting Cudagraph on New Shape Key (10,201)(10,41,41)(10,2,17)(10)(10,41)(10)(10,17)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 1
DEBUG: [Torch-TensorRT] - Resetting Cudagraph on New Shape Key (10,201)(10,1,41)(10,2,17)(10)(10,41)(10)(10,17)
DEBUG: [Torch-TensorRT] - Input Name: args_0 Shape: [10, 201]  <-- race
DEBUG: [Torch-TensorRT] - Input Name: args_0 Shape: [10, 201]  <-- race
DEBUG: [Torch-TensorRT] - Input Name: args_2 Shape: [10, 41, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_2 Shape: [10, 1, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_5 Shape: [10, 2, 17]

@yjjinjie
Author

@keehyuna OK, that model was old, but I exported a new model with the new Torch-TensorRT (with PR 3310) and it still has the problem.

Because the problem occurs at runtime, the version used to script the model is not related.

wget http://automl-nni.oss-cn-beijing.aliyuncs.com/trt/test_demo/test_demo3.tar.gz

# ok
python test_model_trt.py

# error
python test_model_trt_cudagraph.py

@yjjinjie
Author

yjjinjie commented Jan 10, 2025

@keehyuna Does your Torch-TensorRT build contain PR 3310 or not? The problem is in the runtime: whether the model is old or newly exported with PR 3310, the logging is always atomic, so it is not related to the scripted model.

If your build does not contain PR 3310, the results from threads will not match the results from processes.

DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 1
DEBUG: [Torch-TensorRT] - Resetting Cudagraph on New Shape Key (10,201)(10,41,41)(10,2,17)(10)(10,41)(10)(10,17)
DEBUG: [Torch-TensorRT] - Input Name: args_0 Shape: [10, 201]
DEBUG: [Torch-TensorRT] - Input Name: args_2 Shape: [10, 41, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_5 Shape: [10, 2, 17]
DEBUG: [Torch-TensorRT] - Input Name: args_3 Shape: [10]
DEBUG: [Torch-TensorRT] - Input Name: args_1 Shape: [10, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_6 Shape: [10]
DEBUG: [Torch-TensorRT] - Input Name: args_4 Shape: [10, 17]
DEBUG: [Torch-TensorRT] - Output Name: output1 Shape: [10]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [10]
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 1
DEBUG: [Torch-TensorRT] - Resetting Cudagraph on New Shape Key (10,201)(10,1,41)(10,2,17)(10)(10,41)(10)(10,17)
DEBUG: [Torch-TensorRT] - Input Name: args_0 Shape: [10, 201]
DEBUG: [Torch-TensorRT] - Input Name: args_2 Shape: [10, 1, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_5 Shape: [10, 2, 17]
DEBUG: [Torch-TensorRT] - Input Name: args_3 Shape: [10]
DEBUG: [Torch-TensorRT] - Input Name: args_1 Shape: [10, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_6 Shape: [10]
DEBUG: [Torch-TensorRT] - Input Name: args_4 Shape: [10, 17]
DEBUG: [Torch-TensorRT] - Output Name: output1 Shape: [10]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [10]
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 1
DEBUG: [Torch-TensorRT] - Input Name: args_0 Shape: [10, 201]
DEBUG: [Torch-TensorRT] - Input Name: args_2 Shape: [10, 1, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_5 Shape: [10, 2, 17]
DEBUG: [Torch-TensorRT] - Input Name: args_3 Shape: [10]
DEBUG: [Torch-TensorRT] - Input Name: args_1 Shape: [10, 41]
DEBUG: [Torch-TensorRT] - Input Name: args_6 Shape: [10]
DEBUG: [Torch-TensorRT] - Input Name: args_4 Shape: [10, 17]
DEBUG: [Torch-TensorRT] - Output Name: output1 Shape: [10]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [10]
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 1

@yjjinjie
Author

yjjinjie commented Jan 10, 2025

@keehyuna I also found that ResNet has no problem. If you change the thread count from 10 to 8 in test_model_trt_cudagraph, my model has no problem either. I think it is related to multi-threading, to timing, or to the mutex in the CUDA graph path?

@keehyuna
Collaborator

> @keehyuna Does your Torch-TensorRT build contain PR 3310 or not? The problem is in the runtime: whether the model is old or newly exported with PR 3310, the logging is always atomic, so it is not related to the scripted model.
>
> If your build does not contain PR 3310, the results from threads will not match the results from processes.


You are right, I was testing with the wrong Torch-TensorRT version. I am checking on it.

@keehyuna
Collaborator

Hi @yjjinjie
Your model has Torch and TRT submodules. The CUDA error occurred when a Torch API call and CUDA graph capture were running concurrently.
A thread lock should fix the issue, but if you compile and run with a single TRT runtime you will not see it.
If you can provide the model source rather than the serialized model, I can check whether there is another way to fix this problem.
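
The thread-lock approach mentioned above could look roughly like the sketch below. It is only an illustration under assumptions: `LockedPredictor` and `fake_model` are hypothetical names, and `fake_model` stands in for the real wrapper module (Torch ops plus the TRT engine). The idea is that one process-wide lock covers the whole forward call, so no thread issues Torch work while another thread may be inside CUDA graph capture.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class LockedPredictor:
    """Serialize calls to a model that is not safe to run concurrently."""

    def __init__(self, model):
        self._model = model
        self._lock = threading.Lock()

    def __call__(self, *args, **kwargs):
        # Hold the lock for the whole forward pass so that no other thread
        # can launch work while this one may be capturing a CUDA graph.
        with self._lock:
            return self._model(*args, **kwargs)


# Hypothetical stand-in for the real module; it records whether two
# calls ever overlap in time.
active = 0
max_active = 0
counter_lock = threading.Lock()


def fake_model(x):
    global active, max_active
    with counter_lock:
        active += 1
        max_active = max(max_active, active)
    result = x * 2
    with counter_lock:
        active -= 1
    return result


predictor = LockedPredictor(fake_model)
# Ten caller threads, mirroring the thread count used in the reproduction.
with ThreadPoolExecutor(max_workers=10) as pool:
    results = list(pool.map(predictor, range(100)))
```

The trade-off is that the forward pass is fully serialized, so extra threads only help with work done outside the lock, such as pre- and post-processing.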

@yjjinjie
Author

@keehyuna Yes, my model is embedding + TRT (dense). How can I give you the model source? Are you Chinese? We could discuss by voice.

@yjjinjie
Author

@keehyuna If TRT doesn't support some op, the model will always be a mix of Torch API calls and TRT modules. Will it then always have this issue?

@yjjinjie
Author

@keehyuna My code base is large, so you can use the repository and container: https://github.com/alibaba/TorchEasyRec

git clone git@github.com:alibaba/TorchEasyRec.git


cd TorchEasyRec

sudo docker run -it --rm \
   --gpus all  \
   --shm-size=2g --ulimit memlock=-1 --network host --ulimit stack=67108864 \
   --workdir /larec/ \
   -v "$(pwd):/larec/" \
   -v "$(pwd)/data:/data/" \
   mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easyrec/tzrec-devel:0.6

bash scripts/gen_proto.sh
pip install -r requirements.txt
export PYTHONPATH=/larec

# ok
python tzrec/tests/rank_integration_test.py RankIntegrationTest.test_multi_tower_with_fg_train_eval_export_trt

# occasional error
# To enable CUDA graphs, add the following to https://github.com/alibaba/TorchEasyRec/blob/master/tzrec/main.py:

import torch_tensorrt
torch_tensorrt.runtime.set_cudagraphs_mode(True)

# Then change predict_threads from None to 10 in https://github.com/alibaba/TorchEasyRec/blob/master/tzrec/predict.py#L57

@keehyuna
Collaborator

Thanks @yjjinjie. Your model is a wrapper module that contains both PyTorch and TRT models.
torch_tensorrt.runtime.set_cudagraphs_mode(True) only controls CUDA graphs inside the TRT module and has no visibility outside of it.
Hence proper synchronization has to be handled in the application layer. You might be able to implement CUDA graphs with the Torch API, but that may still need proper synchronization in a multi-threaded environment. Please let me know if you have questions.
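
One possible way to handle this at the application layer, sketched under assumptions, is to keep all model calls on a single dedicated thread and have the prediction threads submit requests to it. `InferenceWorker` and `fake_model` below are hypothetical names, not part of Torch-TensorRT; `fake_model` stands in for the real wrapper module.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class InferenceWorker:
    """Run every model call on a single dedicated thread."""

    def __init__(self, model):
        self._model = model
        # One worker thread means a capture or replay never overlaps
        # with another model call issued by this application.
        self._executor = ThreadPoolExecutor(max_workers=1)

    def predict(self, *args, **kwargs):
        # Blocks the calling thread until the worker thread has finished.
        return self._executor.submit(self._model, *args, **kwargs).result()

    def close(self):
        self._executor.shutdown(wait=True)


seen_threads = set()


def fake_model(x):
    # Hypothetical stand-in; records which thread actually runs the model.
    seen_threads.add(threading.get_ident())
    return x + 1


worker = InferenceWorker(fake_model)
# Ten caller threads, mirroring predict_threads=10 in the report.
with ThreadPoolExecutor(max_workers=10) as callers:
    outputs = list(callers.map(worker.predict, range(50)))
worker.close()
```

Compared with a lock around each call, this keeps every model invocation on the same thread, at the cost of an extra hand-off per request. Whether that is enough for a given model would still need to be verified against the real module.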

@yjjinjie
Author

@keehyuna Hello. If the whole model is compiled as a TRT module but TRT doesn't support all ops, it may end up as Torch op + TRT engine op + Torch op. Do CUDA graphs support this?

Also, my full model (embedding + dense) cannot be converted to TRT because of the problems in #3355.

@keehyuna
Collaborator

Yes, we recently added support for CUDA graphs with graph breaks.
If there is a graph break, a wrapper module is used, and it is captured and replayed. Please refer to the document below.

https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_export_cudagraphs.html
