Status: Open
Labels: bug
Description
Your current environment
The output of python collect_env.py
==============================
System Info
==============================
OS : Amazon Linux 2023.9.20251110 (x86_64)
GCC version : (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version : Could not collect
CMake version : version 3.22.2
Libc version : glibc-2.34
==============================
PyTorch Info
==============================
PyTorch version : 2.9.0+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0] (64-bit runtime)
Python platform : Linux-6.1.158-178.288.amzn2023.x86_64-x86_64-with-glibc2.34
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : 12.9.86
CUDA_MODULE_LOADING set to :
GPU models and configuration :
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version : 580.95.05
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R13 Processor
CPU family: 25
BogoMIPS: 5300.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 48 MiB (96 instances)
L3 cache: 384 MiB (12 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47,96-143
NUMA node1 CPU(s): 48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Mitigation; Clear CPU buffers
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.5.3
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.16.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cutlass-dsl==4.3.1
[pip3] nvidia-ml-py==13.580.82
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvshmem-cu12==3.3.20
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.9.0
[pip3] torchaudio==2.9.0
[pip3] torchvision==0.24.0
[pip3] transformers==4.57.1
[pip3] triton==3.5.0
[conda] flashinfer-python 0.5.3 pypi_0 pypi
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cudnn-frontend 1.16.0 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-cufile-cu12 1.13.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-cutlass-dsl 4.3.1 pypi_0 pypi
[conda] nvidia-ml-py 13.580.82 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvshmem-cu12 3.3.20 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] pyzmq 27.1.0 pypi_0 pypi
[conda] torch 2.9.0 pypi_0 pypi
[conda] torchaudio 2.9.0 pypi_0 pypi
[conda] torchvision 0.24.0 pypi_0 pypi
[conda] transformers 4.57.1 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi
==============================
vLLM Info
==============================
ROCM Version : Could not collect
vLLM Version : 0.12.0
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 0-47,96-143 0 N/A
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 0-47,96-143 0 N/A
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 0-47,96-143 0 N/A
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 0-47,96-143 0 N/A
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 48-95,144-191 1 N/A
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 48-95,144-191 1 N/A
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 48-95,144-191 1 N/A
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X 48-95,144-191 1 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
==============================
Environment Variables
==============================
LD_LIBRARY_PATH=/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/opt/amazon/openmpi/lib64:/opt/amazon/efa/lib64:/opt/amazon/ofi-nccl/lib64:/usr/local/lib:/usr/lib:/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/opt/amazon/openmpi/lib64:/opt/amazon/efa/lib64:/opt/amazon/ofi-nccl/lib64:/usr/local/lib:/usr/lib:/lib
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
🐛 Describe the bug
Command to reproduce the problem:
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct \
  --served-model-name Qwen3-Next-80B-A3B-Instruct \
  --port 8801 \
  --tensor-parallel-size 8 \
  --max-model-len 6400 \
  --gpu-memory-utilization 0.9 \
  --max-num-seqs 4 \
  --enable-lora \
  --lora-modules sql-lora=/path/to/mylora \
  --tool-call-parser hermes \
  --enable-auto-tool-choice
The error message I got:
(Worker_TP5 pid=116613) Exception ignored in: <function ExactWeakKeyDictionary.__setitem__.<locals>.<lambda> at 0x7f7a0ce67a30>
(Worker_TP5 pid=116613) Traceback (most recent call last):
(Worker_TP5 pid=116613) File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 988, in <lambda>
(Worker_TP5 pid=116613) self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx))
(Worker_TP5 pid=116613) File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/v1/executor/multiproc_executor.py", line 689, in signal_handler
(Worker_TP5 pid=116613) raise SystemExit()
(Worker_TP5 pid=116613) SystemExit:
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] WorkerProc hit an exception.
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] Traceback (most recent call last):
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/v1/executor/multiproc_executor.py", line 817, in worker_busy_loop
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] output = func(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return func(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/v1/worker/gpu_worker.py", line 324, in determine_available_memory
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] self.model_runner.profile_run()
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 4357, in profile_run
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] hidden_states, last_hidden_states = self._dummy_run(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return func(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/v1/worker/gpu_model_runner.py", line 4071, in _dummy_run
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] outputs = self.model(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/compilation/cuda_graph.py", line 126, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self.runnable(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._call_impl(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return forward_call(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/qwen3_next.py", line 1226, in forward
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] hidden_states = self.model(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 514, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/compilation/wrapper.py", line 171, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._compiled_callable(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 832, in compile_wrapper
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return fn(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/models/qwen3_next.py", line 999, in forward
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] def forward(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return fn(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/compilation/caching.py", line 54, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self.optimized_call(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._wrapped_call(self, *args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/fx/graph_module.py", line 413, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] raise e
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._call_impl(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return forward_call(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "<eval_with_key>.98", line 625, in forward
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] submod_2 = self.submod_2(getitem_4, s72, getitem_3, l_self_modules_layers_modules_0_modules_linear_attn_modules_norm_parameters_weight_, l_self_modules_layers_modules_0_modules_linear_attn_modules_out_proj_modules_base_layer_parameters_weight_, l_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkvz_punica_wrapper_token_mapping_meta_token_lora_mapping, l_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkvz_punica_wrapper_token_mapping_meta_token_indices_sorted_by_lora_ids, l_self_modules_layers_modules_0_modules_linear_attn_modules_out_proj_lora_a_stacked_0_, l_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkvz_punica_wrapper_token_mapping_meta_num_tokens_per_lora, l_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkvz_punica_wrapper_token_mapping_meta_lora_token_start_loc, l_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkvz_punica_wrapper_token_mapping_meta_active_lora_ids, l_self_modules_layers_modules_0_modules_linear_attn_modules_in_proj_qkvz_punica_wrapper_token_mapping_meta_no_lora_flag_cpu, l_self_modules_layers_modules_0_modules_linear_attn_modules_out_proj_lora_b_stacked_0_, getitem_5, l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_, getitem_6, l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_qkvz_modules_base_layer_parameters_weight_, l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_qkvz_lora_a_stacked_0_, l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_qkvz_lora_b_stacked_0_, l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_ba_modules_base_layer_parameters_weight_, l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_ba_lora_a_stacked_0_, 
l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_ba_lora_b_stacked_0_); getitem_4 = getitem_3 = l_self_modules_layers_modules_0_modules_linear_attn_modules_norm_parameters_weight_ = l_self_modules_layers_modules_0_modules_linear_attn_modules_out_proj_modules_base_layer_parameters_weight_ = l_self_modules_layers_modules_0_modules_linear_attn_modules_out_proj_lora_a_stacked_0_ = l_self_modules_layers_modules_0_modules_linear_attn_modules_out_proj_lora_b_stacked_0_ = getitem_5 = l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = getitem_6 = l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_qkvz_modules_base_layer_parameters_weight_ = l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_qkvz_lora_a_stacked_0_ = l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_qkvz_lora_b_stacked_0_ = l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_ba_modules_base_layer_parameters_weight_ = l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_ba_lora_a_stacked_0_ = l_self_modules_layers_modules_1_modules_linear_attn_modules_in_proj_ba_lora_b_stacked_0_ = None
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/compilation/cuda_graph.py", line 126, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self.runnable(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/compilation/piecewise_backend.py", line 93, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self.compiled_graph_for_general_shape(*args)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_inductor/standalone_compile.py", line 63, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._compiled_fn(*args)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return fn(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1130, in forward
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return compiled_fn(full_args)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 353, in runtime_wrapper
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] all_outs = call_func_at_runtime_with_args(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 129, in call_func_at_runtime_with_args
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] out = normalize_as_list(f(args))
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 724, in inner_fn
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] outs = compiled_fn(args)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 526, in wrapper
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return compiled_fn(runtime_args)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 613, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self.current_callable(inputs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2962, in run
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] out = model(new_inputs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/tmp/torchinductor_ec2-user/do/cdoiit2cy4jrkx3gfahekjzvxkxtlxm2ob5bmntvubln7ayez35b.py", line 1377, in call
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] buf14 = torch.ops.vllm.moe_forward_shared.default(buf12, buf13, 'model.layers.0.mlp.experts')
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_ops.py", line 841, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._op(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 2129, in moe_forward_shared
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self.forward_impl(hidden_states, router_logits)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 1960, in forward_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] final_hidden_states = self.quant_method.apply(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py", line 117, in apply
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] result = self.fused_experts(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._call_impl(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return forward_call(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/lora/layers/fused_moe.py", line 153, in wrapper
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] result = func(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1307, in forward
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] fused_out = self._fused_experts(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1124, in _fused_experts
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] self.fused_experts.apply(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 2120, in apply
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] self.activation(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/lora/layers/fused_moe.py", line 194, in wrapper
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] ) = self.punica_wrapper.moe_lora_align_block_size(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 340, in moe_lora_align_block_size
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] ops.moe_lora_align_block_size(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/_custom_ops.py", line 1923, in moe_lora_align_block_size
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] torch.ops._moe_C.moe_lora_align_block_size(
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] File "/home/ec2-user/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_ops.py", line 1255, in __call__
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] return self._op(*args, **kwargs)
(Worker_TP5 pid=116613) ERROR 12-04 04:49:03 [multiproc_executor.py:822] RuntimeError: Shared memory usage exceeds device limit, and global memory fallback is not implemented yet.
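The failing call is torch.ops._moe_C.moe_lora_align_block_size, the LoRA variant of vLLM's MoE "align block size" step. As a rough, non-authoritative illustration of what the base (non-LoRA) step computes, the pure-Python sketch below groups the flattened top-k expert assignments by expert, pads each group to a multiple of block_size with a sentinel index, and records which expert owns each block. The function name moe_align_block_size_ref is hypothetical, and the within-expert ordering is an assumption (the CUDA kernel need not preserve token order). The LoRA variant presumably does similar bookkeeping per adapter, which this sketch does not model.

```python
from typing import List, Tuple


def moe_align_block_size_ref(
    topk_ids: List[List[int]], block_size: int, num_experts: int
) -> Tuple[List[int], List[int], int]:
    """Hypothetical pure-Python reference of MoE block alignment.

    Returns (sorted_token_ids, expert_ids, num_tokens_post_pad):
    - sorted_token_ids: flat indices into topk_ids, grouped by expert, each
      group padded to a multiple of block_size with the sentinel len(flat).
    - expert_ids: the expert that owns each block of block_size slots.
    """
    flat = [e for row in topk_ids for e in row]
    pad_id = len(flat)  # sentinel index that points past the real entries
    sorted_token_ids: List[int] = []
    expert_ids: List[int] = []
    for expert in range(num_experts):
        members = [i for i, e in enumerate(flat) if e == expert]
        if not members:
            continue  # experts with no tokens get no blocks
        padded_len = -(-len(members) // block_size) * block_size  # ceil to block
        sorted_token_ids.extend(members + [pad_id] * (padded_len - len(members)))
        expert_ids.extend([expert] * (padded_len // block_size))
    return sorted_token_ids, expert_ids, len(sorted_token_ids)


# 3 tokens, top-2 routing over 3 experts, block_size=2.
ids, experts, total = moe_align_block_size_ref([[0, 1], [1, 2], [1, 0]], 2, 3)
print(ids)      # [0, 5, 1, 2, 4, 6, 3, 6]
print(experts)  # [0, 1, 1, 2]
print(total)    # 8
```

The CUDA op fills preallocated output tensors instead of returning lists, but the per-expert counting and offset computation is the part whose size grows with the number of experts.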
I can deploy Qwen3-Next successfully if the LoRA adapter is disabled (i.e., without --enable-lora and --lora-modules).
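One plausible reading of the error: the align kernel keeps per-expert counters in shared memory, and with Qwen3-Next-80B-A3B's 512 experts (per its model card) that table no longer fits in one H100 thread block. The non-LoRA path presumably survives because it has a global-memory fallback, as the wording of the error suggests. The sketch below estimates the requirement assuming the layout of vLLM's classic moe_align_block_size kernel, ((num_threads + 1) × num_experts + num_experts + 1) int32 values with num_threads = max(num_experts, 32), and compares it with the H100's 227 KiB opt-in per-block limit. The layout, the constants, and the helper names (align_kernel_smem_bytes, pick_path) are assumptions for illustration; the LoRA kernel's actual footprint may differ.

```python
# Rough shared-memory estimate for an MoE "align block size" kernel.
# ASSUMPTION: layout mirrors vLLM's classic (non-LoRA) moe_align_block_size
# kernel; the LoRA kernel's real footprint may differ.
INT32_BYTES = 4
WARP_SIZE = 32
H100_SMEM_PER_BLOCK_OPTIN = 227 * 1024  # 232448 bytes on compute capability 9.0


def align_kernel_smem_bytes(num_experts: int) -> int:
    """Bytes of shared memory under the assumed counter layout."""
    num_threads = max(num_experts, WARP_SIZE)
    token_counts = (num_threads + 1) * num_experts  # per-thread, per-expert counts
    cumsum = num_experts + 1  # per-expert prefix sums (offsets)
    return (token_counts + cumsum) * INT32_BYTES


def pick_path(num_experts: int, smem_limit: int, has_global_fallback: bool) -> str:
    """Hypothetical dispatch: shared memory if it fits, else global, else fail."""
    if align_kernel_smem_bytes(num_experts) <= smem_limit:
        return "shared"
    if has_global_fallback:
        return "global"
    raise RuntimeError(
        "Shared memory usage exceeds device limit, "
        "and global memory fallback is not implemented yet."
    )


if __name__ == "__main__":
    for n in (64, 128, 512):
        print(n, align_kernel_smem_bytes(n))
    # 512 experts on a path that has a fallback: runs from global memory.
    print(pick_path(512, H100_SMEM_PER_BLOCK_OPTIN, has_global_fallback=True))
    # 512 experts on a path without a fallback: raises like the log above.
    try:
        pick_path(512, H100_SMEM_PER_BLOCK_OPTIN, has_global_fallback=False)
    except RuntimeError as exc:
        print("RuntimeError:", exc)
```

Under these assumptions, 512 experts need 1,052,676 bytes (about 1 MiB), roughly 4.5× the per-block limit, while 128 experts (66,564 bytes) would still fit.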