1- ///! OpenAI API client library
1+ ///! ` OpenAI` API client library
22#[ macro_use]
33extern crate derive_builder;
44
5- use reqwest:: header:: HeaderMap ;
65use thiserror:: Error ;
76
87type Result < T > = std:: result:: Result < T , OpenAIError > ;
@@ -215,6 +214,9 @@ pub enum OpenAIError {
215214 /// An error the client discovers before talking to the API
216215 #[ error( "Bad arguments" ) ]
217216 BadArguments ( String ) ,
217+ /// Network / protocol related errors
218+ #[ error( "Error at the protocol level" ) ]
219+ ProtocolError ( surf:: Error ) ,
218220}
219221
220222impl From < api:: ErrorMessage > for OpenAIError {
@@ -229,49 +231,86 @@ impl From<String> for OpenAIError {
229231 }
230232}
231233
234+ impl From < surf:: Error > for OpenAIError {
235+ fn from ( e : surf:: Error ) -> Self {
236+ OpenAIError :: ProtocolError ( e)
237+ }
238+ }
239+
232240/// Client object. Must be constructed to talk to the API.
233241pub struct OpenAIClient {
234- client : reqwest:: Client ,
235- root : String ,
242+ client : surf:: Client ,
243+ }
244+
245+ /// Authentication middleware
246+ struct BearerToken {
247+ token : String ,
248+ }
249+
250+ impl std:: fmt:: Debug for BearerToken {
251+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
252+ // Get the first few characters to help debug, but not accidentally log key
253+ write ! (
254+ f,
255+ r#"Bearer {{ token: "{}" }}"# ,
256+ self . token. get( 0 ..8 ) . ok_or( std:: fmt:: Error ) ?
257+ )
258+ }
259+ }
260+
261+ impl BearerToken {
262+ fn new ( token : & str ) -> Self {
263+ Self {
264+ token : String :: from ( token) ,
265+ }
266+ }
267+ }
268+
269+ #[ surf:: utils:: async_trait]
270+ impl surf:: middleware:: Middleware for BearerToken {
271+ async fn handle (
272+ & self ,
273+ mut req : surf:: Request ,
274+ client : surf:: Client ,
275+ next : surf:: middleware:: Next < ' _ > ,
276+ ) -> surf:: Result < surf:: Response > {
277+ log:: debug!( "Request: {:?}" , req) ;
278+ req. insert_header ( "Authorization" , format ! ( "Bearer {}" , self . token) ) ;
279+ let response = next. run ( req, client) . await ?;
280+ log:: debug!( "Response: {:?}" , response) ;
281+ Ok ( response)
282+ }
236283}
237284
238285impl OpenAIClient {
239286 /// Creates a new `OpenAIClient` given an api token
287+ #[ must_use]
240288 pub fn new ( token : & str ) -> Self {
241- let mut headers = HeaderMap :: new ( ) ;
242- headers. insert (
243- reqwest:: header:: AUTHORIZATION ,
244- reqwest:: header:: HeaderValue :: from_str ( & format ! ( "Bearer {}" , token) )
245- . expect ( "Client library error. Header value badly formatted" ) ,
289+ let mut client = surf:: client ( ) ;
290+ client. set_base_url (
291+ surf:: Url :: parse ( "https://api.openai.com/v1/" ) . expect ( "Static string should parse" ) ,
246292 ) ;
247- Self {
248- client : reqwest:: Client :: builder ( )
249- . default_headers ( headers)
250- . build ( )
251- . expect ( "Client library error. Should have constructed a valid http client." ) ,
252- root : "https://api.openai.com/v1" . into ( ) ,
253- }
293+ client = client. with ( BearerToken :: new ( token) ) ;
294+ Self { client }
295+ }
296+
297+ /// Allow setting the api root in the tests
298+ #[ cfg( test) ]
299+ fn set_api_root ( & mut self , url : surf:: Url ) {
300+ self . client . set_base_url ( url) ;
254301 }
255302
256303 /// Private helper for making gets
257304 async fn get < T : serde:: de:: DeserializeOwned > ( & self , endpoint : & str ) -> Result < T > {
258- let url = & format ! ( "{}/{}" , self . root, endpoint) ;
259- let response = self
260- . client
261- . get ( url)
262- . send ( )
263- . await
264- . expect ( "Client error. Should have passed a valid url" ) ;
265- if response. status ( ) != 200 {
266- return Err ( OpenAIError :: APIError (
267- response
268- . json :: < api:: ErrorWrapper > ( )
269- . await
270- . expect ( "The API has returned something funky" )
271- . error ,
272- ) ) ;
305+ let mut response = self . client . get ( endpoint) . await ?;
306+ if let surf:: StatusCode :: Ok = response. status ( ) {
307+ Ok ( response. body_json :: < T > ( ) . await ?)
308+ } else {
309+ {
310+ let err = response. body_json :: < api:: ErrorWrapper > ( ) . await ?. error ;
311+ Err ( OpenAIError :: APIError ( err) )
312+ }
273313 }
274- Ok ( response. json :: < T > ( ) . await . unwrap ( ) )
275314 }
276315
277316 /// Lists the currently available engines.
@@ -295,29 +334,26 @@ impl OpenAIClient {
295334
296335 // Private helper to generate post requests. Needs to be a bit more flexible than
297336 // get because it should support SSE eventually
298- async fn post < B : serde:: ser:: Serialize > (
299- & self ,
300- endpoint : & str ,
301- body : B ,
302- ) -> Result < reqwest:: Response > {
303- let url = & format ! ( "{}/{}" , self . root, endpoint) ;
304- let response = self
337+ async fn post < B , R > ( & self , endpoint : & str , body : B ) -> Result < R >
338+ where
339+ B : serde:: ser:: Serialize ,
340+ R : serde:: de:: DeserializeOwned ,
341+ {
342+ let mut response = self
305343 . client
306- . post ( url)
307- . json ( & body)
308- . send ( )
309- . await
310- . expect ( "Client library error, json failed to parse" ) ;
311- if response. status ( ) != 200 {
312- return Err ( OpenAIError :: APIError (
344+ . post ( endpoint)
345+ . body ( surf:: Body :: from_json ( & body) ?)
346+ . await ?;
347+ match response. status ( ) {
348+ surf:: StatusCode :: Ok => Ok ( response. body_json :: < R > ( ) . await ?) ,
349+ _ => Err ( OpenAIError :: APIError (
313350 response
314- . json :: < api:: ErrorWrapper > ( )
351+ . body_json :: < api:: ErrorWrapper > ( )
315352 . await
316353 . expect ( "The API has returned something funky" )
317354 . error ,
318- ) ) ;
355+ ) ) ,
319356 }
320- Ok ( response)
321357 }
322358 /// Get predicted completion of the prompt
323359 ///
@@ -330,11 +366,7 @@ impl OpenAIClient {
330366 let args = prompt. into ( ) ;
331367 Ok ( self
332368 . post ( & format ! ( "engines/{}/completions" , args. engine) , args)
333- . await ?
334- //.text()
335- . json ( )
336- . await
337- . expect ( "Client error. JSON didn't parse correctly." ) )
369+ . await ?)
338370 }
339371}
340372
@@ -344,8 +376,11 @@ mod unit {
344376 use crate :: { api, OpenAIClient , OpenAIError } ;
345377
346378 fn mocked_client ( ) -> OpenAIClient {
379+ let _ = env_logger:: builder ( ) . is_test ( true ) . try_init ( ) ;
347380 let mut client = OpenAIClient :: new ( "bogus" ) ;
348- client. root = mockito:: server_url ( ) ;
381+ client. set_api_root (
382+ surf:: Url :: parse ( & mockito:: server_url ( ) ) . expect ( "mockito url didn't parse" ) ,
383+ ) ;
349384 client
350385 }
351386
@@ -493,9 +528,9 @@ mod integration {
493528 use api:: ErrorMessage ;
494529
495530 use crate :: { api, OpenAIClient , OpenAIError } ;
496-
497531 /// Used by tests to get a client to the actual api
498532 fn get_client ( ) -> OpenAIClient {
533+ let _ = env_logger:: builder ( ) . is_test ( true ) . try_init ( ) ;
499534 let sk = std:: env:: var ( "OPENAI_SK" ) . expect (
500535 "To run integration tests, you must put set the OPENAI_SK env var to your api token" ,
501536 ) ;
0 commit comments