This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 4212741

[LLM Runtime] Remove use_cache in WOQ (#818)
1 parent 5e607e6 commit 4212741

5 files changed: +13 additions, -15 deletions

intel_extension_for_transformers/llm/runtime/graph/README.md

Lines changed: 0 additions & 1 deletion
@@ -128,7 +128,6 @@ Argument description of WeightOnlyQuantConfig:
 | scale_dtype | String | Data type of scales: fp32/bf16 (default fp32) |
 | use_ggml | Bool | Enable ggml for quantization and inference (default: False) |
 | use_quant | Bool | Determine whether or not the model will be quantized. (default: True) |
-| use_cache | Bool | Use local quantized model if file exists (default: False) |
 
 Argument description of generate function:
 | Argument | Type | Description |
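
After this change, use_cache is no longer an accepted argument of WeightOnlyQuantConfig: an already-quantized binary on disk is reused automatically, and deleting that file forces a rebuild. A minimal usage sketch of the post-change API, assuming the intel_extension_for_transformers.transformers imports used by the tests in this commit (the model id is a placeholder):

import torch
from intel_extension_for_transformers.transformers import AutoModel, WeightOnlyQuantConfig

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; any supported Hugging Face model id

# No use_cache flag anymore: if the quantized .bin already exists it is reused,
# otherwise the model is converted and quantized during this call.
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
model = AutoModel.from_pretrained(model_name,
                                  quantization_config=woq_config,
                                  use_llm_runtime=True,
                                  trust_remote_code=True)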

intel_extension_for_transformers/llm/runtime/graph/__init__.py

Lines changed: 7 additions & 5 deletions
@@ -75,7 +75,7 @@ def get_model_type(model_config):
             model_type = "chatglm2"
         return model_type
 
-    def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **quant_kwargs):
+    def init(self, model_name, use_quant=True, use_gptq=False, **quant_kwargs):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.model_type = Model.get_model_type(self.config)
@@ -106,15 +106,18 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **qu
             self.bin_file = fp32_bin
         else:
             self.bin_file = quant_bin
-        if use_cache and os.path.exists(self.bin_file):
+
+        if os.path.exists(self.bin_file):
+            print("{} existed, will use cache file. Otherwise please remove the file".
+                  format(self.bin_file))
             return
 
         if use_gptq:
             convert_model(model_name, quant_bin, "f32")
             return
 
 
-        if not use_cache or not os.path.exists(fp32_bin):
+        if not os.path.exists(fp32_bin):
             convert_model(model_name, fp32_bin, "f32")
         assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
 
@@ -125,8 +128,7 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **qu
         assert os.path.exists(quant_bin), "Fail to quantize model"
 
         # clean
-        if not use_cache:
-            os.remove(fp32_bin)
+        os.remove(fp32_bin)
 
     def init_from_bin(self, model_type, model_path, **generate_kwargs):
         self.__import_package(model_type)
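
Taken together, the three hunks above change Model.init() so that an existing binary at self.bin_file is always reused (previously this required use_cache=True) and the intermediate fp32 dump is always removed after quantization. A condensed, hypothetical sketch of the resulting control flow; convert_model and quantize_model stand in for the real helpers, and the path arguments are placeholders, not the actual method:

import os

def prepare_bin(fp32_bin, quant_bin, bin_file, use_gptq, convert_model, quantize_model):
    # Approximation of the post-commit Model.init() flow (illustrative only).
    if os.path.exists(bin_file):
        # Cached binary is reused unconditionally; delete the file to force a rebuild.
        print("{} existed, will use cache file. Otherwise please remove the file".format(bin_file))
        return
    if use_gptq:
        convert_model(quant_bin, "f32")
        return
    if not os.path.exists(fp32_bin):
        convert_model(fp32_bin, "f32")
    assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
    quantize_model(fp32_bin, quant_bin)
    assert os.path.exists(quant_bin), "Fail to quantize model"
    # clean: the fp32 intermediate is now always removed (no use_cache gate).
    os.remove(fp32_bin)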

intel_extension_for_transformers/llm/runtime/graph/tests/test_llm_runtime.py

Lines changed: 6 additions & 6 deletions
@@ -55,7 +55,7 @@ def test_llm_runtime(self):
         print(tokenizer.decode(pt_generate_ids))
 
         # check output ids
-        woq_config = WeightOnlyQuantConfig(use_cache=True, use_quant=False)
+        woq_config = WeightOnlyQuantConfig(use_quant=False)
         itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_llm_runtime=True, trust_remote_code=True)
         itrex_generate_ids = itrex_model.generate(inputs.input_ids, do_sample=False, max_new_tokens=100)[0]
         print(tokenizer.decode(itrex_generate_ids))
@@ -64,10 +64,10 @@ def test_llm_runtime(self):
 
         # check diff of logits
         woq_configs = {
-            "fp32": WeightOnlyQuantConfig(use_cache=True, use_quant=False),
-            # "ggml_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True, use_ggml=True),
-            "jblas_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True),
-            # "jblas_int8": WeightOnlyQuantConfig(compute_dtype="bf16", weight_dtype="int8", use_cache=True),
+            "fp32": WeightOnlyQuantConfig(use_quant=False),
+            # "ggml_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4",use_ggml=True),
+            "jblas_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4"),
+            # "jblas_int8": WeightOnlyQuantConfig(compute_dtype="bf16", weight_dtype="int8"),
         }
         for config_type in woq_configs:
             itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_configs[config_type],
@@ -98,7 +98,7 @@ def test_beam_search(self):
         pt_generate_ids = torch.load("/tf_dataset2/inc-ut/nlptoolkit_ut_model/beam_pt_generate_ids.pth").tolist()
 
         # llm runtime fp32
-        woq_config = WeightOnlyQuantConfig(use_quant=False, use_cache=True)
+        woq_config = WeightOnlyQuantConfig(use_quant=False)
         itrex_model = AutoModelForCausalLM.from_pretrained(
             model_name, quantization_config=woq_config, trust_remote_code=True)
         itrex_generate_ids = itrex_model.generate(

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 0 additions & 1 deletion
@@ -184,7 +184,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 compute_dtype=quantization_config.compute_dtype,
                 use_ggml=quantization_config.use_ggml,
                 use_quant=quantization_config.use_quant,
-                use_cache=quantization_config.use_cache,
                 use_gptq=quantization_config.use_gptq,
             )
             return model

intel_extension_for_transformers/transformers/utils/quantization_config.py

Lines changed: 0 additions & 2 deletions
@@ -43,7 +43,6 @@ def __init__(
         algorithm="RTN",
         use_ggml=False,
         use_quant=True,
-        use_cache=False,
         use_gptq=False,
         **kwargs,
     ):
@@ -70,7 +69,6 @@ def __init__(
         self.calib_iters = kwargs.pop("calib_iters", 100)
         self.use_ggml = use_ggml
         self.use_quant = use_quant
-        self.use_cache = use_cache
         self.use_gptq = use_gptq
 
         if compute_dtype is None:
