11use crate :: { Error , RefConfig } ;
22use base64:: prelude:: * ;
33use bytes:: Bytes ;
4- use http:: { HeaderMap , HeaderValue , StatusCode } ;
4+ use http:: { header :: ToStrError , HeaderMap , HeaderValue , StatusCode } ;
55use serde:: { Deserialize , Serialize } ;
66use std:: {
77 collections:: HashMap ,
8- convert :: TryFrom ,
8+ env ,
99 fmt:: Debug ,
1010 time:: { Duration , SystemTime } ,
1111} ;
1212use tokio_stream:: Stream ;
13+ use tracing:: Span ;
1314
1415#[ derive( Debug , Eq , PartialEq , Clone , Serialize , Deserialize ) ]
1516#[ serde( rename_all = "camelCase" ) ]
@@ -120,11 +121,10 @@ pub struct Context {
120121 pub env_config : RefConfig ,
121122}
122123
123- impl TryFrom < ( RefConfig , HeaderMap ) > for Context {
124- type Error = Error ;
125- fn try_from ( data : ( RefConfig , HeaderMap ) ) -> Result < Self , Self :: Error > {
126- let env_config = data. 0 ;
127- let headers = data. 1 ;
124+ impl Context {
125+ /// Create a new [Context] struct based on the fuction configuration
126+ /// and the incoming request data.
127+ pub fn new ( request_id : & str , env_config : RefConfig , headers : & HeaderMap ) -> Result < Self , Error > {
128128 let client_context: Option < ClientContext > = if let Some ( value) = headers. get ( "lambda-runtime-client-context" ) {
129129 serde_json:: from_str ( value. to_str ( ) ?) ?
130130 } else {
@@ -138,11 +138,7 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context {
138138 } ;
139139
140140 let ctx = Context {
141- request_id : headers
142- . get ( "lambda-runtime-aws-request-id" )
143- . expect ( "missing lambda-runtime-aws-request-id header" )
144- . to_str ( ) ?
145- . to_owned ( ) ,
141+ request_id : request_id. to_owned ( ) ,
146142 deadline : headers
147143 . get ( "lambda-runtime-deadline-ms" )
148144 . expect ( "missing lambda-runtime-deadline-ms header" )
@@ -165,13 +161,37 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context {
165161
166162 Ok ( ctx)
167163 }
168- }
169164
170- impl Context {
171165 /// The execution deadline for the current invocation.
172166 pub fn deadline ( & self ) -> SystemTime {
173167 SystemTime :: UNIX_EPOCH + Duration :: from_millis ( self . deadline )
174168 }
169+
170+ /// Create a new [`tracing::Span`] for an incoming invocation.
171+ pub ( crate ) fn request_span ( & self ) -> Span {
172+ match & self . xray_trace_id {
173+ Some ( trace_id) => {
174+ env:: set_var ( "_X_AMZN_TRACE_ID" , trace_id) ;
175+ tracing:: info_span!(
176+ "Lambda runtime invoke" ,
177+ requestId = & self . request_id,
178+ xrayTraceId = trace_id
179+ )
180+ }
181+ None => {
182+ env:: remove_var ( "_X_AMZN_TRACE_ID" ) ;
183+ tracing:: info_span!( "Lambda runtime invoke" , requestId = & self . request_id)
184+ }
185+ }
186+ }
187+ }
188+
189+ /// Extract the invocation request id from the incoming request.
190+ pub ( crate ) fn invoke_request_id ( headers : & HeaderMap ) -> Result < & str , ToStrError > {
191+ headers
192+ . get ( "lambda-runtime-aws-request-id" )
193+ . expect ( "missing lambda-runtime-aws-request-id header" )
194+ . to_str ( )
175195}
176196
177197/// Incoming Lambda request containing the event payload and context.
@@ -313,7 +333,7 @@ mod test {
313333 HeaderValue :: from_static ( "arn::myarn" ) ,
314334 ) ;
315335 headers. insert ( "lambda-runtime-trace-id" , HeaderValue :: from_static ( "arn::myarn" ) ) ;
316- let tried = Context :: try_from ( ( config, headers) ) ;
336+ let tried = Context :: new ( "id" , config, & headers) ;
317337 assert ! ( tried. is_ok( ) ) ;
318338 }
319339
@@ -324,7 +344,7 @@ mod test {
324344 let mut headers = HeaderMap :: new ( ) ;
325345 headers. insert ( "lambda-runtime-aws-request-id" , HeaderValue :: from_static ( "my-id" ) ) ;
326346 headers. insert ( "lambda-runtime-deadline-ms" , HeaderValue :: from_static ( "123" ) ) ;
327- let tried = Context :: try_from ( ( config, headers) ) ;
347+ let tried = Context :: new ( "id" , config, & headers) ;
328348 assert ! ( tried. is_ok( ) ) ;
329349 }
330350
@@ -355,7 +375,7 @@ mod test {
355375 ) ;
356376
357377 let config = Arc :: new ( Config :: default ( ) ) ;
358- let tried = Context :: try_from ( ( config, headers) ) ;
378+ let tried = Context :: new ( "id" , config, & headers) ;
359379 assert ! ( tried. is_ok( ) ) ;
360380 let tried = tried. unwrap ( ) ;
361381 assert ! ( tried. client_context. is_some( ) ) ;
@@ -369,7 +389,7 @@ mod test {
369389 headers. insert ( "lambda-runtime-aws-request-id" , HeaderValue :: from_static ( "my-id" ) ) ;
370390 headers. insert ( "lambda-runtime-deadline-ms" , HeaderValue :: from_static ( "123" ) ) ;
371391 headers. insert ( "lambda-runtime-client-context" , HeaderValue :: from_static ( "{}" ) ) ;
372- let tried = Context :: try_from ( ( config, headers) ) ;
392+ let tried = Context :: new ( "id" , config, & headers) ;
373393 assert ! ( tried. is_ok( ) ) ;
374394 assert ! ( tried. unwrap( ) . client_context. is_some( ) ) ;
375395 }
@@ -390,7 +410,7 @@ mod test {
390410 "lambda-runtime-cognito-identity" ,
391411 HeaderValue :: from_str ( & cognito_identity_str) . unwrap ( ) ,
392412 ) ;
393- let tried = Context :: try_from ( ( config, headers) ) ;
413+ let tried = Context :: new ( "id" , config, & headers) ;
394414 assert ! ( tried. is_ok( ) ) ;
395415 let tried = tried. unwrap ( ) ;
396416 assert ! ( tried. identity. is_some( ) ) ;
@@ -412,7 +432,7 @@ mod test {
412432 HeaderValue :: from_static ( "arn::myarn" ) ,
413433 ) ;
414434 headers. insert ( "lambda-runtime-trace-id" , HeaderValue :: from_static ( "arn::myarn" ) ) ;
415- let tried = Context :: try_from ( ( config, headers) ) ;
435+ let tried = Context :: new ( "id" , config, & headers) ;
416436 assert ! ( tried. is_err( ) ) ;
417437 }
418438
@@ -427,7 +447,7 @@ mod test {
427447 "lambda-runtime-client-context" ,
428448 HeaderValue :: from_static ( "BAD-Type,not JSON" ) ,
429449 ) ;
430- let tried = Context :: try_from ( ( config, headers) ) ;
450+ let tried = Context :: new ( "id" , config, & headers) ;
431451 assert ! ( tried. is_err( ) ) ;
432452 }
433453
@@ -439,7 +459,7 @@ mod test {
439459 headers. insert ( "lambda-runtime-aws-request-id" , HeaderValue :: from_static ( "my-id" ) ) ;
440460 headers. insert ( "lambda-runtime-deadline-ms" , HeaderValue :: from_static ( "123" ) ) ;
441461 headers. insert ( "lambda-runtime-cognito-identity" , HeaderValue :: from_static ( "{}" ) ) ;
442- let tried = Context :: try_from ( ( config, headers) ) ;
462+ let tried = Context :: new ( "id" , config, & headers) ;
443463 assert ! ( tried. is_err( ) ) ;
444464 }
445465
@@ -454,14 +474,13 @@ mod test {
454474 "lambda-runtime-cognito-identity" ,
455475 HeaderValue :: from_static ( "BAD-Type,not JSON" ) ,
456476 ) ;
457- let tried = Context :: try_from ( ( config, headers) ) ;
477+ let tried = Context :: new ( "id" , config, & headers) ;
458478 assert ! ( tried. is_err( ) ) ;
459479 }
460480
461481 #[ test]
462482 #[ should_panic]
463- #[ allow( unused_must_use) ]
464- fn context_with_missing_request_id_should_panic ( ) {
483+ fn context_with_missing_deadline_should_panic ( ) {
465484 let config = Arc :: new ( Config :: default ( ) ) ;
466485
467486 let mut headers = HeaderMap :: new ( ) ;
@@ -471,22 +490,34 @@ mod test {
471490 HeaderValue :: from_static ( "arn::myarn" ) ,
472491 ) ;
473492 headers. insert ( "lambda-runtime-trace-id" , HeaderValue :: from_static ( "arn::myarn" ) ) ;
474- Context :: try_from ( ( config, headers) ) ;
493+ let _ = Context :: new ( "id" , config, & headers) ;
475494 }
476495
477496 #[ test]
478- #[ should_panic]
479- #[ allow( unused_must_use) ]
480- fn context_with_missing_deadline_should_panic ( ) {
481- let config = Arc :: new ( Config :: default ( ) ) ;
497+ fn invoke_request_id_should_not_panic ( ) {
498+ let mut headers = HeaderMap :: new ( ) ;
499+ headers. insert ( "lambda-runtime-aws-request-id" , HeaderValue :: from_static ( "my-id" ) ) ;
500+ headers. insert ( "lambda-runtime-deadline-ms" , HeaderValue :: from_static ( "123" ) ) ;
501+ headers. insert (
502+ "lambda-runtime-invoked-function-arn" ,
503+ HeaderValue :: from_static ( "arn::myarn" ) ,
504+ ) ;
505+ headers. insert ( "lambda-runtime-trace-id" , HeaderValue :: from_static ( "arn::myarn" ) ) ;
506+
507+ let _ = invoke_request_id ( & headers) ;
508+ }
482509
510+ #[ test]
511+ #[ should_panic]
512+ fn invoke_request_id_should_panic ( ) {
483513 let mut headers = HeaderMap :: new ( ) ;
484514 headers. insert ( "lambda-runtime-deadline-ms" , HeaderValue :: from_static ( "123" ) ) ;
485515 headers. insert (
486516 "lambda-runtime-invoked-function-arn" ,
487517 HeaderValue :: from_static ( "arn::myarn" ) ,
488518 ) ;
489519 headers. insert ( "lambda-runtime-trace-id" , HeaderValue :: from_static ( "arn::myarn" ) ) ;
490- Context :: try_from ( ( config, headers) ) ;
520+
521+ let _ = invoke_request_id ( & headers) ;
491522 }
492523}
0 commit comments