import os
import random
import functools
import csv
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, multilabel_confusion_matrix, roc_curve, auc
from skmultilearn.model_selection import iterative_train_test_split
from datasets import Dataset, DatasetDict
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import cycle

# tokenize the text and carry the multi-hot label vectors through
def tokenize_examples(examples, tokenizer):
    tokenized_inputs = tokenizer(examples['text'])
    tokenized_inputs['labels'] = examples['labels']
    return tokenized_inputs

# define custom batch preprocessor
def collate_fn(batch, tokenizer):
    dict_keys = ['input_ids', 'attention_mask', 'labels']
    d = {k: [dic[k] for dic in batch] for k in dict_keys}
    d['input_ids'] = torch.nn.utils.rnn.pad_sequence(
        d['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    d['attention_mask'] = torch.nn.utils.rnn.pad_sequence(
        d['attention_mask'], batch_first=True, padding_value=0
    )
    d['labels'] = torch.stack(d['labels'])
    return d

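# e.g. a batch of two tokenized examples with lengths 5 and 8 collates into
# (2, 8) tensors: input_ids are padded with the tokenizer's pad token id,
# attention_mask is padded with 0, and the label vectors are simply stacked
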
# plotting helpers (defined before compute_metrics so they already exist when the
# Trainer runs its first evaluation)
def plot_confusion_matrix(conf_matrix, class_idx, ax, class_names=["Absent", "Present"]):
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title(f'Class {class_idx}')
    ax.xaxis.set_ticklabels(class_names)
    ax.yaxis.set_ticklabels(class_names)

def plot_multilabel_roc(labels, predictions, num_classes):
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(labels[:, i], predictions[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    colors = cycle(['blue', 'red', 'green', 'yellow', 'cyan', 'magenta', 'black'])
    plt.figure(figsize=(10, 8))
    for i, color in zip(range(num_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label=f'ROC curve of class {i} (area = {roc_auc[i]:.2f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-label')
    plt.legend(loc="lower right")
    plt.show()

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    predictions = probs > 0.5
    labels = np.asarray(labels)  # the Trainer already passes labels as a numpy array

    # calculate F1 scores
    f1_micro = f1_score(labels, predictions, average='micro')
    f1_macro = f1_score(labels, predictions, average='macro')
    f1_weighted = f1_score(labels, predictions, average='weighted')

    # plot a confusion matrix for each label
    conf_matrices = multilabel_confusion_matrix(labels, predictions)
    fig, ax = plt.subplots(1, len(conf_matrices), figsize=(15, 5))
    if len(conf_matrices) > 1:
        for idx, cm in enumerate(conf_matrices):
            plot_confusion_matrix(cm, idx, ax[idx])
    else:
        plot_confusion_matrix(conf_matrices[0], 0, ax)
    plt.tight_layout()
    plt.show()

    # plot ROC curves
    plot_multilabel_roc(labels, probs, num_classes=labels.shape[1])

    return {'f1_micro': f1_micro, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}

# create custom trainer class to be able to pass label weights and calculate multilabel loss
class CustomTrainer(Trainer):

    def __init__(self, label_weights, **kwargs):
        super().__init__(**kwargs)
        self.label_weights = label_weights

    # num_items_in_batch is accepted for compatibility with newer transformers
    # releases, which pass it into compute_loss; it is unused here
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")

        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # compute custom loss: weighted binary cross-entropy, one binary problem per label
        loss = F.binary_cross_entropy_with_logits(logits, labels.to(torch.float32), pos_weight=self.label_weights)
        return (loss, outputs) if return_outputs else loss

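# toy illustration of what pos_weight does (not executed during training):
# with logit 0 and target 1, BCE-with-logits is -log(sigmoid(0)) = log(2) ~ 0.693,
# while pos_weight=2 doubles the penalty for missing that positive label to ~1.386:
#   F.binary_cross_entropy_with_logits(torch.tensor([0.0]), torch.tensor([1.0]))
#   F.binary_cross_entropy_with_logits(torch.tensor([0.0]), torch.tensor([1.0]), pos_weight=torch.tensor([2.0]))
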
# set random seed
random.seed(0)

# load data
with open('train.csv', newline='') as csvfile:
    data = list(csv.reader(csvfile, delimiter=','))
    header_row = data.pop(0)

# shuffle data
random.shuffle(data)

# reshape: each row holds (id, title, abstract, multi-hot labels)
idx, text, labels = list(zip(*[(int(row[0]), f'Title: {row[1].strip()}\n\nAbstract: {row[2].strip()}', row[3:]) for row in data]))
labels = np.array(labels, dtype=int)

# create label weights: the rarer a class, the larger its pos_weight in the loss
label_weights = 1 - labels.sum(axis=0) / labels.sum()

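# worked example with hypothetical counts: if the per-class positive counts are
# [50, 30, 20], then label_weights = 1 - [0.5, 0.3, 0.2] = [0.5, 0.7, 0.8],
# so the rarest class gets the largest weight
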
# stratified train test split for multilabel dataset
row_ids = np.arange(len(labels))
train_idx, y_train, val_idx, y_val = iterative_train_test_split(row_ids[:, np.newaxis], labels, test_size=0.1)
x_train = [text[i] for i in train_idx.flatten()]
x_val = [text[i] for i in val_idx.flatten()]

# create hf dataset
ds = DatasetDict({
    'train': Dataset.from_dict({'text': x_train, 'labels': y_train}),
    'val': Dataset.from_dict({'text': x_val, 'labels': y_val})
})

# model name
model_name = 'mistralai/Mistral-7B-v0.1'

# preprocess dataset with tokenizer (tokenize_examples is defined near the top of the file)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Mistral ships without a pad token; reuse EOS for padding
tokenized_ds = ds.map(functools.partial(tokenize_examples, tokenizer=tokenizer), batched=True)
tokenized_ds = tokenized_ds.with_format('torch')

# quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # enable 4-bit quantization
    bnb_4bit_quant_type='nf4',             # information-theoretically optimal dtype for normally distributed weights
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants // insert xzibit meme
    bnb_4bit_compute_dtype=torch.bfloat16  # optimized fp format for ML
)

# lora config
lora_config = LoraConfig(
    r=16,                                  # the dimension of the low-rank matrices
    lora_alpha=8,                          # scaling factor for LoRA activations vs pre-trained weight activations
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    lora_dropout=0.05,                     # dropout probability of the LoRA layers
    bias='none',                           # train no bias parameters
    task_type='SEQ_CLS'                    # attach a sequence-classification head
)

# load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    num_labels=labels.shape[1]
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.config.pad_token_id = tokenizer.pad_token_id

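# optional sanity check: LoRA should leave only a small fraction of the weights trainable
model.print_trainable_parameters()
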
# define training args
training_args = TrainingArguments(
    output_dir='multilabel_classification',
    learning_rate=1e-4,
    per_device_train_batch_size=8,  # tested with 16gb gpu ram
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy='epoch',    # called eval_strategy in newer transformers releases
    save_strategy='epoch',
    load_best_model_at_end=True
)

# train
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['val'],
    tokenizer=tokenizer,
    data_collator=functools.partial(collate_fn, tokenizer=tokenizer),
    compute_metrics=compute_metrics,
    label_weights=torch.tensor(label_weights, dtype=torch.float32, device=model.device)  # float32 to match the logits
)

trainer.train()

# save model
peft_model_id = 'multilabel_mistral'
trainer.model.save_pretrained(peft_model_id)
tokenizer.save_pretrained(peft_model_id)

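# a minimal inference sketch (assumptions: the 4-bit base model is reloaded with the
# same config as above and the adapter saved to peft_model_id is applied on top;
# the sample text is hypothetical):
#
# from peft import PeftModel
# base = AutoModelForSequenceClassification.from_pretrained(
#     model_name, quantization_config=quantization_config, num_labels=labels.shape[1]
# )
# base.config.pad_token_id = tokenizer.pad_token_id
# loaded = PeftModel.from_pretrained(base, peft_model_id)
# enc = tokenizer('Title: ...\n\nAbstract: ...', return_tensors='pt').to(loaded.device)
# with torch.no_grad():
#     probs = torch.sigmoid(loaded(**enc).logits)
# print((probs > 0.5).int())  # one 0/1 prediction per label, same 0.5 threshold as compute_metrics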