@@ -22,18 +22,21 @@ use crate::{
2222 page:: TemplateData ,
2323 } ,
2424} ;
25- use anyhow:: Context as _;
25+ use anyhow:: { Context as _, anyhow } ;
2626use axum:: body:: Bytes ;
2727use axum:: { Router , body:: Body , http:: Request , response:: Response as AxumResponse } ;
2828use axum_extra:: headers:: { ETag , HeaderMapExt as _} ;
2929use fn_error_context:: context;
3030use futures_util:: stream:: TryStreamExt ;
31- use http:: { HeaderMap , StatusCode , header:: CACHE_CONTROL } ;
31+ use http:: {
32+ HeaderMap , HeaderName , HeaderValue , StatusCode ,
33+ header:: { CACHE_CONTROL , CONTENT_TYPE } ,
34+ } ;
3235use http_body_util:: BodyExt ;
3336use opentelemetry_sdk:: metrics:: InMemoryMetricExporter ;
3437use serde:: de:: DeserializeOwned ;
3538use sqlx:: Connection as _;
36- use std:: { fs, future:: Future , panic, rc:: Rc , str:: FromStr , sync:: Arc } ;
39+ use std:: { collections :: HashMap , fs, future:: Future , panic, rc:: Rc , str:: FromStr , sync:: Arc } ;
3740use tokio:: { runtime, task:: block_in_place} ;
3841use tower:: ServiceExt ;
3942use tracing:: error;
@@ -137,11 +140,13 @@ pub(crate) trait AxumRouterTestExt {
137140 config : & Config ,
138141 ) -> Result < AxumResponse > ;
139142 async fn assert_not_found ( & self , path : & str ) -> Result < ( ) > ;
140- async fn assert_success_and_conditional_get (
143+ async fn assert_conditional_get (
141144 & self ,
142- path : & str ,
143- expected_body : & str ,
145+ initial_path : & str ,
146+ uncached_response : & AxumResponse ,
144147 ) -> Result < ( ) > ;
148+ async fn assert_success_and_conditional_get ( & self , path : & str ) -> Result < ( ) > ;
149+
145150 async fn assert_success_cached (
146151 & self ,
147152 path : & str ,
@@ -187,37 +192,73 @@ impl AxumRouterTestExt for axum::Router {
187192 Ok ( response)
188193 }
189194
190- async fn assert_success_and_conditional_get (
195+ async fn assert_conditional_get (
191196 & self ,
192- path : & str ,
193- expected_body : & str ,
197+ initial_path : & str ,
198+ uncached_response : & AxumResponse ,
194199 ) -> Result < ( ) > {
195- let etag: ETag = {
196- // uncached response
197- let response = self . assert_success ( path ) . await ? ;
198- let etag : ETag = response . headers ( ) . typed_get ( ) . unwrap ( ) ;
200+ let etag: ETag = uncached_response
201+ . headers ( )
202+ . typed_get ( )
203+ . ok_or_else ( || anyhow ! ( "missing ETag header" ) ) ? ;
199204
200- assert_eq ! ( response . text ( ) . await ? , expected_body ) ;
205+ let if_none_match = IfNoneMatch :: from ( etag . clone ( ) ) ;
201206
202- etag
203- } ;
207+ // general rule:
208+ // if a header influences how any client or intermediate proxy should treat the response,
209+ // it should be repeated on the 304 response.
210+ // This logic assumes _all_ headers have to be repeated, except for a few known ones.
211+ const NON_CACHE_HEADERS : & [ & HeaderName ] = & [ & CONTENT_TYPE ] ;
204212
205- let if_none_match = IfNoneMatch :: from ( etag. clone ( ) ) ;
213+ // store original headers, to assert that they are repeated on the 304 response.
214+ let original_headers: HashMap < HeaderName , HeaderValue > = uncached_response
215+ . headers ( )
216+ . iter ( )
217+ . filter ( |( k, _) | !NON_CACHE_HEADERS . contains ( k) )
218+ . map ( |( k, v) | ( k. clone ( ) , v. clone ( ) ) )
219+ . collect ( ) ;
206220
207221 {
208222 // cached response
209223 let response = self
210- . get_with_headers ( path, |headers| {
224+ . get_with_headers ( initial_path, |headers| {
225+ headers. insert ( X_RLNG_SOURCE_CDN , HeaderValue :: from_static ( "fastly" ) ) ;
211226 headers. typed_insert ( if_none_match) ;
212227 } )
213228 . await ?;
214229 assert_eq ! ( response. status( ) , StatusCode :: NOT_MODIFIED ) ;
215230 // etag is repeated
216- assert_eq ! ( response. headers( ) . typed_get:: <ETag >( ) . unwrap( ) , etag) ;
231+ assert_eq ! (
232+ response
233+ . headers( )
234+ . typed_get:: <ETag >( )
235+ . ok_or_else( || anyhow!( "missing ETag header" ) ) ?,
236+ etag
237+ ) ;
238+
239+ // many other headers are repeated
240+ let cached_response_headers: HashMap < HeaderName , HeaderValue > = response
241+ . headers ( )
242+ . iter ( )
243+ . filter_map ( |( k, v) | {
244+ if original_headers. contains_key ( k) {
245+ Some ( ( k. clone ( ) , v. clone ( ) ) )
246+ } else {
247+ None
248+ }
249+ } )
250+ . collect ( ) ;
251+
252+ assert_eq ! ( original_headers, cached_response_headers) ;
217253 }
218254 Ok ( ( ) )
219255 }
220256
257+ async fn assert_success_and_conditional_get ( & self , path : & str ) -> Result < ( ) > {
258+ self . assert_conditional_get ( path, & self . assert_success ( path) . await ?)
259+ . await
260+ }
261+
221262 async fn assert_not_found ( & self , path : & str ) -> Result < ( ) > {
222263 let response = self . get ( path) . await ?;
223264
0 commit comments