@@ -91,13 +91,11 @@ fn main() -> Result<()> {
9191 . with_n_threads_batch ( std:: thread:: available_parallelism ( ) ?. get ( ) . try_into ( ) ?)
9292 . with_embeddings ( true )
9393 . with_pooling_type ( pooling_type) ;
94- println ! ( "ctx_params: {:?}" , ctx_params ) ;
94+ println ! ( "ctx_params: {ctx_params :?}" ) ;
9595 let mut ctx = model
9696 . new_context ( & backend, ctx_params)
9797 . with_context ( || "unable to create the llama_context" ) ?;
9898
99- let n_embd = model. n_embd ( ) ;
100-
10199 let prompt_lines = {
102100 let mut lines = Vec :: new ( ) ;
103101 for doc in documents {
@@ -107,13 +105,13 @@ fn main() -> Result<()> {
107105 lines
108106 } ;
109107
110- println ! ( "prompt_lines: {:?}" , prompt_lines ) ;
108+ println ! ( "prompt_lines: {prompt_lines :?}" ) ;
111109 // tokenize the prompt
112110 let tokens_lines_list = prompt_lines
113111 . iter ( )
114112 . map ( |line| model. str_to_token ( line, AddBos :: Always ) )
115113 . collect :: < Result < Vec < _ > , _ > > ( )
116- . with_context ( || format ! ( "failed to tokenize {:?}" , prompt_lines ) ) ?;
114+ . with_context ( || format ! ( "failed to tokenize {prompt_lines :?}" ) ) ?;
117115
118116 let n_ctx = ctx. n_ctx ( ) as usize ;
119117 let n_ctx_train = model. n_ctx_train ( ) ;
@@ -169,7 +167,7 @@ fn main() -> Result<()> {
169167 max_seq_id_batch,
170168 & mut output,
171169 normalise,
172- pooling. clone ( ) ,
170+ & pooling,
173171 ) ?;
174172 max_seq_id_batch = 0 ;
175173 batch. clear ( ) ;
@@ -185,31 +183,21 @@ fn main() -> Result<()> {
185183 max_seq_id_batch,
186184 & mut output,
187185 normalise,
188- pooling. clone ( ) ,
186+ & pooling,
189187 ) ?;
190188
191189 let t_main_end = ggml_time_us ( ) ;
192190
193191 for ( j, embeddings) in output. iter ( ) . enumerate ( ) {
194- if pooling == "none" {
195- eprintln ! ( "embedding {j}: " ) ;
196- for i in 0 ..n_embd as usize {
197- if !normalise {
198- eprint ! ( "{:6.5} " , embeddings[ i] ) ;
199- } else {
200- eprint ! ( "{:9.6} " , embeddings[ i] ) ;
201- }
202- }
203- eprintln ! ( ) ;
204- } else if pooling == "rank" {
192+ if pooling == "rank" {
205193 eprintln ! ( "rerank score {j}: {:8.3}" , embeddings[ 0 ] ) ;
206194 } else {
207195 eprintln ! ( "embedding {j}: " ) ;
208- for i in 0 ..n_embd as usize {
209- if ! normalise {
210- eprint ! ( "{:6.5 } " , embeddings [ i ] ) ;
196+ for embedding in embeddings {
197+ if normalise {
198+ eprint ! ( "{embedding:9.6 } " ) ;
211199 } else {
212- eprint ! ( "{:9.6 } " , embeddings [ i ] ) ;
200+ eprint ! ( "{embedding:6.5 } " ) ;
213201 }
214202 }
215203 eprintln ! ( ) ;
@@ -236,7 +224,7 @@ fn batch_decode(
236224 s_batch : i32 ,
237225 output : & mut Vec < Vec < f32 > > ,
238226 normalise : bool ,
239- pooling : String ,
227+ pooling : & str ,
240228) -> Result < ( ) > {
241229 eprintln ! (
242230 "{}: n_tokens = {}, n_seq = {}" ,
@@ -256,9 +244,9 @@ fn batch_decode(
256244 . with_context ( || "Failed to get sequence embeddings" ) ?;
257245 let normalized = if normalise {
258246 if pooling == "rank" {
259- normalize_embeddings ( & embeddings, -1 )
247+ normalize_embeddings ( embeddings, -1 )
260248 } else {
261- normalize_embeddings ( & embeddings, 2 )
249+ normalize_embeddings ( embeddings, 2 )
262250 }
263251 } else {
264252 embeddings. to_vec ( )
@@ -281,27 +269,30 @@ fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec<f32> {
281269 0 => {
282270 // max absolute
283271 let max_abs = input. iter ( ) . map ( |x| x. abs ( ) ) . fold ( 0.0f32 , f32:: max) / 32760.0 ;
284- max_abs as f64
272+ f64:: from ( max_abs )
285273 }
286274 2 => {
287275 // euclidean norm
288276 input
289277 . iter ( )
290- . map ( |x| ( * x as f64 ) . powi ( 2 ) )
278+ . map ( |x| f64 :: from ( * x) . powi ( 2 ) )
291279 . sum :: < f64 > ( )
292280 . sqrt ( )
293281 }
294282 p => {
295283 // p-norm
296- let sum = input. iter ( ) . map ( |x| ( x. abs ( ) as f64 ) . powi ( p) ) . sum :: < f64 > ( ) ;
297- sum. powf ( 1.0 / p as f64 )
284+ let sum = input
285+ . iter ( )
286+ . map ( |x| f64:: from ( x. abs ( ) ) . powi ( p) )
287+ . sum :: < f64 > ( ) ;
288+ sum. powf ( 1.0 / f64:: from ( p) )
298289 }
299290 } ;
300291
301292 let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 } ;
302293
303294 for i in 0 ..n {
304- output[ i] = ( input[ i] as f64 * norm) as f32 ;
295+ output[ i] = ( f64 :: from ( input[ i] ) * norm) as f32 ;
305296 }
306297
307298 output
0 commit comments