From 7f36f28904ba936ca7c64d4c315e5778441fc29f Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Thu, 23 Oct 2025 02:41:53 -0400 Subject: [PATCH 1/4] add NVFP4 formal document Signed-off-by: He, Xin3 --- README.md | 17 ++++--- docs/source/3x/PT_NVFP4Quant.md | 83 +++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 docs/source/3x/PT_NVFP4Quant.md diff --git a/README.md b/README.md index 7831990f6ea..8d2310f0275 100644 --- a/README.md +++ b/README.md @@ -126,17 +126,22 @@ model = load( - Overview - Dynamic Quantization - Static Quantization + Overview + + + Dynamic Quantization + Static Quantization Smooth Quantization - Weight-Only Quantization - FP8 Quantization - MX Quantization + Weight-Only Quantization + FP8 Quantization Mixed Precision + + MX Quantization + NVFP4 Quantization + diff --git a/docs/source/3x/PT_NVFP4Quant.md b/docs/source/3x/PT_NVFP4Quant.md new file mode 100644 index 00000000000..c24a38faaf2 --- /dev/null +++ b/docs/source/3x/PT_NVFP4Quant.md @@ -0,0 +1,83 @@ +NVFP4 Quantization +================== + +1. [Introduction](#introduction) +2. [Get Started with NVFP4 Quantization API](#get-started-with-nvfp4-quantization-api) +3. [Examples](#examples) +4. [Reference](#reference) + +## Introduction + +Large language models (LLMs) have revolutionized fields such as natural language understanding, generation, and multimodal processing. As these models grow, their computational and memory requirements increase, making efficient deployment challenging. To address these issues, quantization methods are employed to reduce model size and accelerate inference with minimal loss in accuracy. + +NVFP4 is a specialized 4-bit floating-point format (FP4) developed by NVIDIA for deep learning workloads. Compared to traditional INT8 or FP16 formats, NVFP4 offers further reductions in memory footprint and computational resource use, enabling efficient inference for LLMs and other neural networks on supported hardware. + +The following table summarizes the NVFP4 quantization format: + + + + + + + + + + + + + + + + + + + + + + +
Format NameElement Data typeElement BitsScaling Block SizeScale Data TypeScale BitsGlobal Scale Data TypeGlobal Scale Bits
NVFP4E2M1416E4M38FP3232
+ +At similar accuracy levels, NVFP4 can deliver lower memory usage and improved compute efficiency for multiply-accumulate operations compared to higher-precision formats. Neural Compressor supports post-training quantization to NVFP4, providing recipes and APIs for users to quantize LLMs easily. To provide the best performance, the global scale for activation is static. + +## Get Started with NVFP4 Quantization API + +To quantize a model to the NVFP4 format, use the AutoRound Quantization API as shown below. + +```python +from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert +from transformers import AutoModelForCausalLM, AutoTokenizer + +fp32_model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + device_map="auto", +) +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True) +output_dir = "./saved_inc" + +# quantization configuration +quant_config = AutoRoundConfig( + tokenizer=tokenizer, + nsamples=32, + seqlen=32, + iters=20, + scheme="NVFP4", # NVFP4 format + export_format="llm_compressor", + output_dir=output_dir, # default is "temp_auto_round" +) + +# quantize the model and save to output_dir +model = prepare(model=fp32_model, quant_config=quant_config) +model = convert(model) + +# loading +model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto", device_map="auto") + +# inference +text = "There is a girl who likes adventure," +inputs = tokenizer(text, return_tensors="pt").to(model.device) +print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0])) +``` + +## Reference + +[1]: NVIDIA, Introducing NVFP4 for efficient and accurate low-precision inference,NVIDIA Developer Blog, Jun. 2025. [Online]. Available: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/ \ No newline at end of file From 58f57d64aeadd06299e6931af9d288707597c780 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Oct 2025 06:44:57 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/3x/PT_NVFP4Quant.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/3x/PT_NVFP4Quant.md b/docs/source/3x/PT_NVFP4Quant.md index c24a38faaf2..2b91cf9f313 100644 --- a/docs/source/3x/PT_NVFP4Quant.md +++ b/docs/source/3x/PT_NVFP4Quant.md @@ -80,4 +80,4 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0])) ## Reference -[1]: NVIDIA, Introducing NVFP4 for efficient and accurate low-precision inference,NVIDIA Developer Blog, Jun. 2025. [Online]. Available: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/ \ No newline at end of file +[1]: NVIDIA, Introducing NVFP4 for efficient and accurate low-precision inference,NVIDIA Developer Blog, Jun. 2025. [Online]. Available: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/ From f1fe2a67ae0db0a877452b443ee5f6507e3deb5f Mon Sep 17 00:00:00 2001 From: Xin He Date: Thu, 23 Oct 2025 14:46:08 +0800 Subject: [PATCH 3/4] Update PT_NVFP4Quant.md --- docs/source/3x/PT_NVFP4Quant.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/3x/PT_NVFP4Quant.md b/docs/source/3x/PT_NVFP4Quant.md index 2b91cf9f313..73fcbe82e2b 100644 --- a/docs/source/3x/PT_NVFP4Quant.md +++ b/docs/source/3x/PT_NVFP4Quant.md @@ -3,8 +3,7 @@ NVFP4 Quantization 1. [Introduction](#introduction) 2. [Get Started with NVFP4 Quantization API](#get-started-with-nvfp4-quantization-api) -3. [Examples](#examples) -4. [Reference](#reference) +3. [Reference](#reference) ## Introduction From 115e283654fc01c2aa4d4ffc39b36d91504afc6f Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Wed, 3 Dec 2025 04:42:26 -0500 Subject: [PATCH 4/4] update Signed-off-by: He, Xin3 --- docs/source/3x/PT_NVFP4Quant.md | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/docs/source/3x/PT_NVFP4Quant.md b/docs/source/3x/PT_NVFP4Quant.md index 73fcbe82e2b..d0a4151bcf6 100644 --- a/docs/source/3x/PT_NVFP4Quant.md +++ b/docs/source/3x/PT_NVFP4Quant.md @@ -21,22 +21,36 @@ The following table summarizes the NVFP4 quantization format: Scaling Block Size Scale Data Type Scale Bits - Global Scale Data Type - Global Scale Bits + Global Tensor-Wise Scale Data Type + Global Tensor-Wise Scale Bits NVFP4 E2M1 4 16 - E4M3 + UE4M3 8 FP32 32 -At similar accuracy levels, NVFP4 can deliver lower memory usage and improved compute efficiency for multiply-accumulate operations compared to higher-precision formats. Neural Compressor supports post-training quantization to NVFP4, providing recipes and APIs for users to quantize LLMs easily. To provide the best performance, the global scale for activation is static. +### Understanding the Scaling Mechanism + +NVFP4 uses a two-level scaling approach to maintain accuracy while reducing precision: + +- **Block-wise Scale**: The quantized tensor is divided into blocks of size 16 (the Scaling Block Size). Each block has its own scale factor stored in UE4M3 format (8 bits), which is used to convert the 4-bit E2M1 quantized values back to a higher precision representation. This fine-grained scaling helps preserve local variations in the data. + +- **Global Tensor-Wise Scale**: In addition to the block-wise scales, a single FP32 (32-bit) scale factor is applied to the entire tensor. This global scale provides an additional level of normalization for the whole weight or activation tensor. For activations, this global scale is static (computed during calibration and fixed during inference) to optimize performance. + +The dequantization formula can be expressed as: + +$$\text{dequantized\_value} = \text{quantized\_value} \times \text{block\_scale} \times \text{global\_scale}$$ + +This hierarchical scaling strategy balances compression efficiency with numerical accuracy, enabling NVFP4 to maintain model performance while significantly reducing memory footprint. + +At similar accuracy levels, NVFP4 can deliver lower memory usage and improved compute efficiency for multiply-accumulate operations compared to higher-precision formats. Neural Compressor supports post-training quantization to NVFP4, providing recipes and APIs for users to quantize LLMs easily. ## Get Started with NVFP4 Quantization API