
Commit 2b2e7f3

chore: mistral-7b [tested on single GPU]
1 parent 6011421 commit 2b2e7f3

File tree

1 file changed: +223 −0 lines changed


Mistral-7B/classification.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
import os
import random
import functools
import csv
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, multilabel_confusion_matrix, roc_curve, auc
from skmultilearn.model_selection import iterative_train_test_split
from datasets import Dataset, DatasetDict
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import cycle

# tokenize the text field and carry the multilabel targets through
def tokenize_examples(examples, tokenizer):
    tokenized_inputs = tokenizer(examples['text'])
    tokenized_inputs['labels'] = examples['labels']
    return tokenized_inputs

# define custom batch preprocessor
def collate_fn(batch, tokenizer):
    dict_keys = ['input_ids', 'attention_mask', 'labels']
    d = {k: [dic[k] for dic in batch] for k in dict_keys}
    d['input_ids'] = torch.nn.utils.rnn.pad_sequence(
        d['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    d['attention_mask'] = torch.nn.utils.rnn.pad_sequence(
        d['attention_mask'], batch_first=True, padding_value=0
    )
    d['labels'] = torch.stack(d['labels'])
    return d

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.sigmoid(torch.tensor(logits)).numpy() > 0.5
    labels = np.asarray(labels)  # eval_pred already provides numpy arrays

    # Calculate F1 scores
    f1_micro = f1_score(labels, predictions, average='micro')
    f1_macro = f1_score(labels, predictions, average='macro')
    f1_weighted = f1_score(labels, predictions, average='weighted')

    # Plot Confusion Matrix for each label
    conf_matrices = multilabel_confusion_matrix(labels, predictions)
    fig, ax = plt.subplots(1, len(conf_matrices), figsize=(15, 5))
    if len(conf_matrices) > 1:
        for idx, cm in enumerate(conf_matrices):
            plot_confusion_matrix(cm, idx, ax[idx])
    else:
        plot_confusion_matrix(conf_matrices[0], 0, ax)
    plt.tight_layout()
    plt.show()

    # Plot ROC Curves
    plot_multilabel_roc(labels, torch.sigmoid(torch.tensor(logits)).numpy(), num_classes=labels.shape[1])

    return {'f1_micro': f1_micro, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}

# create custom trainer class to be able to pass label weights and calculate multilabel loss
class CustomTrainer(Trainer):

    def __init__(self, label_weights, **kwargs):
        super().__init__(**kwargs)
        self.label_weights = label_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")

        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # compute custom loss
        loss = F.binary_cross_entropy_with_logits(logits, labels.to(torch.float32), pos_weight=self.label_weights)
        return (loss, outputs) if return_outputs else loss

# set random seed
random.seed(0)

# load data
with open('train.csv', newline='') as csvfile:
    data = list(csv.reader(csvfile, delimiter=','))
    header_row = data.pop(0)

# shuffle data
random.shuffle(data)

# reshape
idx, text, labels = list(zip(*[(int(row[0]), f'Title: {row[1].strip()}\n\nAbstract: {row[2].strip()}', row[3:]) for row in data]))
labels = np.array(labels, dtype=int)

# create label weights
label_weights = 1 - labels.sum(axis=0) / labels.sum()

# stratified train test split for multilabel ds
row_ids = np.arange(len(labels))
train_idx, y_train, val_idx, y_val = iterative_train_test_split(row_ids[:, np.newaxis], labels, test_size=0.1)
x_train = [text[i] for i in train_idx.flatten()]
x_val = [text[i] for i in val_idx.flatten()]

# create hf dataset
ds = DatasetDict({
    'train': Dataset.from_dict({'text': x_train, 'labels': y_train}),
    'val': Dataset.from_dict({'text': x_val, 'labels': y_val})
})

# model name
model_name = 'mistralai/Mistral-7B-v0.1'

# preprocess dataset with tokenizer (tokenize_examples is defined above)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenized_ds = ds.map(functools.partial(tokenize_examples, tokenizer=tokenizer), batched=True)
tokenized_ds = tokenized_ds.with_format('torch')

# quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, # enable 4-bit quantization
    bnb_4bit_quant_type='nf4', # information theoretically optimal dtype for normally distributed weights
    bnb_4bit_use_double_quant=True, # quantize quantized weights //insert xzibit meme
    bnb_4bit_compute_dtype=torch.bfloat16 # optimized fp format for ML
)

# lora config
lora_config = LoraConfig(
    r=16, # the dimension of the low-rank matrices
    lora_alpha=8, # scaling factor for LoRA activations vs pre-trained weight activations
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    lora_dropout=0.05, # dropout probability of the LoRA layers
    bias='none', # whether to train bias weights; 'none' trains no bias parameters
    task_type='SEQ_CLS'
)

# load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    num_labels=labels.shape[1]
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.config.pad_token_id = tokenizer.pad_token_id

# define training args
training_args = TrainingArguments(
    output_dir = 'multilabel_classification',
    learning_rate = 1e-4,
    per_device_train_batch_size = 8, # tested with 16gb gpu ram
    per_device_eval_batch_size = 8,
    num_train_epochs = 10,
    weight_decay = 0.01,
    evaluation_strategy = 'epoch',
    save_strategy = 'epoch',
    load_best_model_at_end = True
)

# set up trainer
trainer = CustomTrainer(
    model = model,
    args = training_args,
    train_dataset = tokenized_ds['train'],
    eval_dataset = tokenized_ds['val'],
    tokenizer = tokenizer,
    data_collator = functools.partial(collate_fn, tokenizer=tokenizer),
    compute_metrics = compute_metrics,
    label_weights = torch.tensor(label_weights, device=model.device)
)

# plotting helpers (called by compute_metrics at evaluation time)
def plot_confusion_matrix(conf_matrix, class_idx, ax, class_names=["Absent", "Present"]):
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title(f'Class {class_idx}')
    ax.xaxis.set_ticklabels(class_names)
    ax.yaxis.set_ticklabels(class_names)


def plot_multilabel_roc(labels, predictions, num_classes):
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(labels[:, i], predictions[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    colors = cycle(['blue', 'red', 'green', 'yellow', 'cyan', 'magenta', 'black'])
    plt.figure(figsize=(10, 8))
    for i, color in zip(range(num_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label=f'ROC curve of class {i} (area = {roc_auc[i]:.2f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-label')
    plt.legend(loc="lower right")
    plt.show()

# train (compute_metrics needs the plotting helpers above to be defined before the first evaluation)
trainer.train()

# save model
peft_model_id = 'multilabel_mistral'
trainer.model.save_pretrained(peft_model_id)
tokenizer.save_pretrained(peft_model_id)
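
For reference, a minimal inference sketch (not part of the commit), assuming the adapter and tokenizer were saved to multilabel_mistral as above: it rebuilds the 4-bit base model with the same BitsAndBytesConfig, attaches the trained LoRA weights with peft's PeftModel.from_pretrained, and scores one title/abstract string. The predict_labels helper, the num_labels value, and the threshold are hypothetical placeholders.

import torch
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig

peft_model_id = 'multilabel_mistral'           # directory written by save_pretrained above
base_model_name = 'mistralai/Mistral-7B-v0.1'

# same quantization settings as used for training
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
base_model = AutoModelForSequenceClassification.from_pretrained(
    base_model_name,
    quantization_config=quantization_config,
    num_labels=6  # hypothetical: must match labels.shape[1] from training
)
base_model.config.pad_token_id = tokenizer.pad_token_id

# attach the trained LoRA adapter on top of the quantized base model
model = PeftModel.from_pretrained(base_model, peft_model_id)
model.eval()

def predict_labels(text, threshold=0.5):
    # hypothetical helper: independent sigmoid per label, thresholded
    inputs = tokenizer(text, return_tensors='pt').to(base_model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits).squeeze(0)
    return (probs > threshold).int().tolist(), probs.tolist()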
