Executorch Slows Down after second or third response in Torchchat #1399
Comments
I transferred this over to the torchchat repo since it seems TC-related on its surface.
I haven't looked at the ET repo's runner in a while, but do our apps actually have a chat function, or do they just call generate each time and have to repopulate the cache with every new message? I remember having to fix that issue in the torchchat CLI chat command. Edit: it looks like the JNI layer has this for the multimodal runner: https://github.com/baseweight/executorch/blob/baseweight_demo/extension/android/jni/jni_layer_llama.cpp#L290 Now I'm going to check whether that's the same runner used for text-only models. The runner.cpp you linked above doesn't have a way to generate at start_pos > 0, which is why I'm concerned.
Thanks for spinning this up, @infil00p. The export in TC is based on that of ET, so my gut says either:
cc: @kirklandsign
OK, yeah, it looks like the demo app effectively starts from scratch on every chat message and treats the entire chat history as a new context from position zero, instead of prefilling only the new user message from start_pos == length of the chat history so far: https://github.com/baseweight/executorch/blob/baseweight_demo/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java#L705 This is probably the first thing someone should fix.
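To make the difference concrete, here is a minimal Python sketch of the two strategies. `ToyRunner`, `prefill`, `chat_from_scratch` and `chat_incremental` are hypothetical names, not the ExecuTorch runner API; the toy does no inference and only tracks the KV-cache position and how many tokens each turn has to prefill. Generated response tokens are left out for brevity, although a real chat loop would append them to the cache as well.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRunner:
    """Hypothetical stand-in for an LLM runner with a KV cache (no real inference)."""

    max_seq_len: int
    pos: int = 0  # next free KV-cache position
    prefilled: list = field(default_factory=list)  # tokens prefilled per call

    def reset(self) -> None:
        # Discard the cache; the next prefill must start at position 0.
        self.pos = 0

    def prefill(self, tokens: list, start_pos: int) -> None:
        # Positions [0, start_pos) are assumed to be in the cache already.
        if start_pos != self.pos:
            raise ValueError(f"start_pos {start_pos} != cache position {self.pos}")
        if start_pos + len(tokens) > self.max_seq_len:
            raise ValueError("prompt would overflow max_seq_len")
        self.prefilled.append(len(tokens))
        self.pos = start_pos + len(tokens)


def chat_from_scratch(runner: ToyRunner, turns: list) -> None:
    """Re-send the whole history from position 0 on every turn."""
    history = []
    for turn in turns:
        history.extend(turn)
        runner.reset()
        runner.prefill(history, start_pos=0)


def chat_incremental(runner: ToyRunner, turns: list) -> None:
    """Prefill only the new message, starting where the cache left off."""
    for turn in turns:
        runner.prefill(turn, start_pos=runner.pos)


turns = [[1] * 100, [2] * 100, [3] * 100]  # three 100-token user messages
a = ToyRunner(max_seq_len=4096)
chat_from_scratch(a, turns)
b = ToyRunner(max_seq_len=4096)
chat_incremental(b, turns)
print(a.prefilled)  # [100, 200, 300]
print(b.prefilled)  # [100, 100, 100]
```

Both runners end at the same cache position; the difference is that the from-scratch loop redoes all earlier prefill work on every turn.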
@infil00p If you run the model with generate instead of chat, do you still hit the same performance throttling? Is generate(4096) significantly faster than N chats summing up to 4096 tokens?
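A rough way to frame that question is to count prefilled tokens. If every turn adds k tokens and the app re-prefills the whole history each time, total prefill work is k * N * (N + 1) / 2 tokens over N turns, versus k * N when only the new message is prefilled. The sketch below assumes a hypothetical split of 4096 tokens into 8 turns of 512; it ignores decode time and per-token attention cost, so it is a token count, not a benchmark.

```python
def total_prefill_tokens(turn_lengths: list, reprefill_history: bool) -> int:
    """Sum of tokens prefilled across all turns.

    reprefill_history=True models re-sending the full history each turn;
    False models prefilling only the new turn from the current cache position.
    """
    total = 0
    history = 0
    for n in turn_lengths:
        history += n
        total += history if reprefill_history else n
    return total


turns = [512] * 8  # hypothetical: 8 turns of 512 tokens, 4096 in total
print(total_prefill_tokens(turns, reprefill_history=True))   # 18432
print(total_prefill_tokens(turns, reprefill_history=False))  # 4096
```

Under these assumptions the from-scratch approach prefills 4.5x as many tokens, and the ratio keeps growing with the number of turns.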
Yes, I'm actually experiencing this in our own app, which just calls generate every time. I strongly suspect the lack of generate_from_pos() on the runner is the issue here. I haven't tested this with a multimodal model yet to confirm.
@kirklandsign Would you be the right person to do this, either on the ET or TC side? (Or if @infil00p figures it out, we'd love the contribution to ExecuTorch/torchchat.)
🐛 Describe the bug
As discussed earlier in pytorch/executorch#3674, to increase max_seq_len we have to both increase it in the export scripts and bump up the hardcoded max_seq_len in runner.cpp. We're using ExecuTorch in the proof-of-concept demo we're planning to release at NeurIPS, and we discovered this bug when using the LlamaRunner with Ktor. We also see it with torchchat, but since torchchat runs locally, it won't simply time out if Llama fails to generate in time.
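One way to keep the two settings from drifting apart is a startup check. The sketch below assumes (this is not confirmed in the thread) that the exported model's KV cache holds exactly the export-time max_seq_len positions, so the runner's limit must not exceed it. `effective_max_seq_len` and `remaining_tokens` are hypothetical helpers, and the 2048 figures are illustrative only.

```python
def effective_max_seq_len(export_max_seq_len: int, runner_max_seq_len: int) -> int:
    """Return the context limit the runner can safely use.

    Assumption: the exported KV cache has export_max_seq_len positions, so a
    runner limit above that would index past the end of the cache.
    """
    if runner_max_seq_len > export_max_seq_len:
        raise ValueError(
            f"runner max_seq_len ({runner_max_seq_len}) exceeds the exported "
            f"model's max_seq_len ({export_max_seq_len}); re-export the model"
        )
    return runner_max_seq_len


def remaining_tokens(limit: int, history_tokens: int, prompt_tokens: int) -> int:
    """Tokens left for generation once the history and new prompt are cached."""
    return max(limit - history_tokens - prompt_tokens, 0)


limit = effective_max_seq_len(export_max_seq_len=2048, runner_max_seq_len=2048)
print(remaining_tokens(limit, history_tokens=1500, prompt_tokens=300))  # 248
```

Failing loudly on a mismatch is a design choice: raising only the runner value without re-exporting would otherwise surface as a confusing runtime failure deep in a chat session.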
Step 1. Update export.py, as done in this forked repo: https://github.com/baseweight/torchchat/blob/hardcoded_default/torchchat/export.py#L393
Step 2. Update the runner, as done in this forked repo: https://github.com/baseweight/executorch/blob/baseweight_demo/examples/models/llama/runner/runner.cpp#L53
Step 3. Follow the instructions to export the model and build the AAR. I used Llama-3.2-3b-instruct, since it produces genuinely good demo results about Vancouver (because NeurIPS).
Step 4. Copy the model onto a phone and load it in torchchat. I used a Pixel 9 running Android 15, and I also confirmed this on a OnePlus 12R.
Step 5. Type a prompt (e.g. "Tell me about Nardwuar").
Step 6. Type a follow-up prompt (e.g. "And the Evaporators?").
Step 7. Attempt to type another follow-up prompt.
It seems that this MIGHT be the limit for actual chat on an Android phone with ExecuTorch, since the device starts to overheat. Or maybe that's not the case and I'm just missing something?
Versions
Here's the info from the gaming PC I'm using to build ExecuTorch. I have a conda environment set up for this.
Collecting environment information...
PyTorch version: 2.6.0.dev20241007+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.39
Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 13:27:36) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-48-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti SUPER
Nvidia driver version: 555.58.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 7800X3D 8-Core Processor
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
CPU(s) scaling MHz: 52%
CPU max MHz: 5050.0000
CPU min MHz: 545.0000
BogoMIPS: 8383.77
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 96 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] executorch==0.5.0a0+72b3bb3
[pip3] flake8==6.0.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==23.6.5
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241007+cpu
[pip3] torchao==0.5.0
[pip3] torchaudio==2.5.0.dev20241007+cpu
[pip3] torchsr==1.0.4
[pip3] torchtune==0.4.0.dev20241010+cu121
[pip3] torchvision==0.20.0.dev20241007+cpu
[conda] executorch 0.5.0a0+aa67cd9 pypi_0 pypi
[conda] numpy 2.0.2 pypi_0 pypi
[conda] torch 2.6.0.dev20241112+cpu pypi_0 pypi
[conda] torch-stoi 0.2.3 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20241112+cpu pypi_0 pypi
[conda] torchgen 0.0.1 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20241112+cpu pypi_0 pypi