Skip to content

Commit 94d633e

Browse files
Make this a better library (#7)
Some additions to make this more useful as a library - Move testing only dependencies into dev-dependencies - The library itself now uses surf instead of reqwest, which is more flexible over runtimes - Add some crate features to allow configuring which backend to use - Add logging for requests and responses
1 parent 697acd0 commit 94d633e

File tree

2 files changed

+102
-61
lines changed

2 files changed

+102
-61
lines changed

Cargo.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,23 @@ repository = "https://github.com/deontologician/openai-api-rust/"
1010
keywords = ["openai", "gpt3"]
1111
categories = ["api-bindings", "asynchronous"]
1212

13+
[features]
14+
default = ["hyper"]
1315

14-
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
16+
hyper = ["surf/hyper-client"]
17+
curl = ["surf/curl-client"]
18+
h1 = ["surf/h1-client"]
1519

1620
[dependencies]
17-
reqwest = { version = "0.10.9", features = ["json"] }
21+
surf = { version = "^2.1.0", default-features = false }
1822
thiserror = "^1.0.22"
1923
serde = { version = "^1.0.117", features = ["derive"] }
20-
tokio = { version = "^0.2.5", features = ["full"]}
21-
serde_json = "^1.0"
22-
derive_builder = "0.9.0"
24+
derive_builder = "^0.9.0"
25+
log = "^0.4.11"
2326

2427
[dev-dependencies]
2528
mockito = "0.28.0"
2629
maplit = "1.0.2"
30+
tokio = { version = "^0.2.5", features = ["full"]}
31+
serde_json = "^1.0"
32+
env_logger = "0.8.2"

src/lib.rs

Lines changed: 91 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
///! OpenAI API client library
1+
///! `OpenAI` API client library
22
#[macro_use]
33
extern crate derive_builder;
44

5-
use reqwest::header::HeaderMap;
65
use thiserror::Error;
76

87
type Result<T> = std::result::Result<T, OpenAIError>;
@@ -215,6 +214,9 @@ pub enum OpenAIError {
215214
/// An error the client discovers before talking to the API
216215
#[error("Bad arguments")]
217216
BadArguments(String),
217+
/// Network / protocol related errors
218+
#[error("Error at the protocol level")]
219+
ProtocolError(surf::Error),
218220
}
219221

220222
impl From<api::ErrorMessage> for OpenAIError {
@@ -229,49 +231,86 @@ impl From<String> for OpenAIError {
229231
}
230232
}
231233

234+
impl From<surf::Error> for OpenAIError {
235+
fn from(e: surf::Error) -> Self {
236+
OpenAIError::ProtocolError(e)
237+
}
238+
}
239+
232240
/// Client object. Must be constructed to talk to the API.
233241
pub struct OpenAIClient {
234-
client: reqwest::Client,
235-
root: String,
242+
client: surf::Client,
243+
}
244+
245+
/// Authentication middleware
246+
struct BearerToken {
247+
token: String,
248+
}
249+
250+
impl std::fmt::Debug for BearerToken {
251+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252+
// Get the first few characters to help debug, but not accidentally log key
253+
write!(
254+
f,
255+
r#"Bearer {{ token: "{}" }}"#,
256+
self.token.get(0..8).ok_or(std::fmt::Error)?
257+
)
258+
}
259+
}
260+
261+
impl BearerToken {
262+
fn new(token: &str) -> Self {
263+
Self {
264+
token: String::from(token),
265+
}
266+
}
267+
}
268+
269+
#[surf::utils::async_trait]
270+
impl surf::middleware::Middleware for BearerToken {
271+
async fn handle(
272+
&self,
273+
mut req: surf::Request,
274+
client: surf::Client,
275+
next: surf::middleware::Next<'_>,
276+
) -> surf::Result<surf::Response> {
277+
log::debug!("Request: {:?}", req);
278+
req.insert_header("Authorization", format!("Bearer {}", self.token));
279+
let response = next.run(req, client).await?;
280+
log::debug!("Response: {:?}", response);
281+
Ok(response)
282+
}
236283
}
237284

238285
impl OpenAIClient {
239286
/// Creates a new `OpenAIClient` given an api token
287+
#[must_use]
240288
pub fn new(token: &str) -> Self {
241-
let mut headers = HeaderMap::new();
242-
headers.insert(
243-
reqwest::header::AUTHORIZATION,
244-
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))
245-
.expect("Client library error. Header value badly formatted"),
289+
let mut client = surf::client();
290+
client.set_base_url(
291+
surf::Url::parse("https://api.openai.com/v1/").expect("Static string should parse"),
246292
);
247-
Self {
248-
client: reqwest::Client::builder()
249-
.default_headers(headers)
250-
.build()
251-
.expect("Client library error. Should have constructed a valid http client."),
252-
root: "https://api.openai.com/v1".into(),
253-
}
293+
client = client.with(BearerToken::new(token));
294+
Self { client }
295+
}
296+
297+
/// Allow setting the api root in the tests
298+
#[cfg(test)]
299+
fn set_api_root(&mut self, url: surf::Url) {
300+
self.client.set_base_url(url);
254301
}
255302

256303
/// Private helper for making gets
257304
async fn get<T: serde::de::DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
258-
let url = &format!("{}/{}", self.root, endpoint);
259-
let response = self
260-
.client
261-
.get(url)
262-
.send()
263-
.await
264-
.expect("Client error. Should have passed a valid url");
265-
if response.status() != 200 {
266-
return Err(OpenAIError::APIError(
267-
response
268-
.json::<api::ErrorWrapper>()
269-
.await
270-
.expect("The API has returned something funky")
271-
.error,
272-
));
305+
let mut response = self.client.get(endpoint).await?;
306+
if let surf::StatusCode::Ok = response.status() {
307+
Ok(response.body_json::<T>().await?)
308+
} else {
309+
{
310+
let err = response.body_json::<api::ErrorWrapper>().await?.error;
311+
Err(OpenAIError::APIError(err))
312+
}
273313
}
274-
Ok(response.json::<T>().await.unwrap())
275314
}
276315

277316
/// Lists the currently available engines.
@@ -295,29 +334,26 @@ impl OpenAIClient {
295334

296335
// Private helper to generate post requests. Needs to be a bit more flexible than
297336
// get because it should support SSE eventually
298-
async fn post<B: serde::ser::Serialize>(
299-
&self,
300-
endpoint: &str,
301-
body: B,
302-
) -> Result<reqwest::Response> {
303-
let url = &format!("{}/{}", self.root, endpoint);
304-
let response = self
337+
async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
338+
where
339+
B: serde::ser::Serialize,
340+
R: serde::de::DeserializeOwned,
341+
{
342+
let mut response = self
305343
.client
306-
.post(url)
307-
.json(&body)
308-
.send()
309-
.await
310-
.expect("Client library error, json failed to parse");
311-
if response.status() != 200 {
312-
return Err(OpenAIError::APIError(
344+
.post(endpoint)
345+
.body(surf::Body::from_json(&body)?)
346+
.await?;
347+
match response.status() {
348+
surf::StatusCode::Ok => Ok(response.body_json::<R>().await?),
349+
_ => Err(OpenAIError::APIError(
313350
response
314-
.json::<api::ErrorWrapper>()
351+
.body_json::<api::ErrorWrapper>()
315352
.await
316353
.expect("The API has returned something funky")
317354
.error,
318-
));
355+
)),
319356
}
320-
Ok(response)
321357
}
322358
/// Get predicted completion of the prompt
323359
///
@@ -330,11 +366,7 @@ impl OpenAIClient {
330366
let args = prompt.into();
331367
Ok(self
332368
.post(&format!("engines/{}/completions", args.engine), args)
333-
.await?
334-
//.text()
335-
.json()
336-
.await
337-
.expect("Client error. JSON didn't parse correctly."))
369+
.await?)
338370
}
339371
}
340372

@@ -344,8 +376,11 @@ mod unit {
344376
use crate::{api, OpenAIClient, OpenAIError};
345377

346378
fn mocked_client() -> OpenAIClient {
379+
let _ = env_logger::builder().is_test(true).try_init();
347380
let mut client = OpenAIClient::new("bogus");
348-
client.root = mockito::server_url();
381+
client.set_api_root(
382+
surf::Url::parse(&mockito::server_url()).expect("mockito url didn't parse"),
383+
);
349384
client
350385
}
351386

@@ -493,9 +528,9 @@ mod integration {
493528
use api::ErrorMessage;
494529

495530
use crate::{api, OpenAIClient, OpenAIError};
496-
497531
/// Used by tests to get a client to the actual api
498532
fn get_client() -> OpenAIClient {
533+
let _ = env_logger::builder().is_test(true).try_init();
499534
let sk = std::env::var("OPENAI_SK").expect(
500535
"To run integration tests, you must put set the OPENAI_SK env var to your api token",
501536
);

0 commit comments

Comments
 (0)