@@ -22,18 +22,21 @@ use crate::{
2222 page:: TemplateData ,
2323 } ,
2424} ;
25- use anyhow:: Context as _;
25+ use anyhow:: { Context as _, anyhow } ;
2626use axum:: body:: Bytes ;
2727use axum:: { Router , body:: Body , http:: Request , response:: Response as AxumResponse } ;
2828use axum_extra:: headers:: { ETag , HeaderMapExt as _} ;
2929use fn_error_context:: context;
3030use futures_util:: stream:: TryStreamExt ;
31- use http:: { HeaderMap , StatusCode , header:: CACHE_CONTROL } ;
31+ use http:: {
32+ HeaderMap , HeaderName , HeaderValue , StatusCode ,
33+ header:: { CACHE_CONTROL , CONTENT_TYPE } ,
34+ } ;
3235use http_body_util:: BodyExt ;
3336use opentelemetry_sdk:: metrics:: InMemoryMetricExporter ;
3437use serde:: de:: DeserializeOwned ;
3538use sqlx:: Connection as _;
36- use std:: { fs, future:: Future , panic, rc:: Rc , str:: FromStr , sync:: Arc } ;
39+ use std:: { collections :: HashMap , fs, future:: Future , panic, rc:: Rc , str:: FromStr , sync:: Arc } ;
3740use tokio:: { runtime, task:: block_in_place} ;
3841use tower:: ServiceExt ;
3942use tracing:: error;
@@ -137,11 +140,13 @@ pub(crate) trait AxumRouterTestExt {
137140 config : & Config ,
138141 ) -> Result < AxumResponse > ;
139142 async fn assert_not_found ( & self , path : & str ) -> Result < ( ) > ;
140- async fn assert_success_and_conditional_get (
143+ async fn assert_conditional_get (
141144 & self ,
142- path : & str ,
143- expected_body : & str ,
145+ initial_path : & str ,
146+ uncached_response : & AxumResponse ,
144147 ) -> Result < ( ) > ;
148+ async fn assert_success_and_conditional_get ( & self , path : & str ) -> Result < ( ) > ;
149+
145150 async fn assert_success_cached (
146151 & self ,
147152 path : & str ,
@@ -187,37 +192,66 @@ impl AxumRouterTestExt for axum::Router {
187192 Ok ( response)
188193 }
189194
190- async fn assert_success_and_conditional_get (
195+ async fn assert_conditional_get (
191196 & self ,
192- path : & str ,
193- expected_body : & str ,
197+ initial_path : & str ,
198+ uncached_response : & AxumResponse ,
194199 ) -> Result < ( ) > {
195- let etag: ETag = {
196- // uncached response
197- let response = self . assert_success ( path ) . await ? ;
198- let etag : ETag = response . headers ( ) . typed_get ( ) . unwrap ( ) ;
200+ let etag: ETag = uncached_response
201+ . headers ( )
202+ . typed_get ( )
203+ . ok_or_else ( || anyhow ! ( "missing ETag header" ) ) ? ;
199204
200- assert_eq ! ( response . text ( ) . await ? , expected_body ) ;
205+ let if_none_match = IfNoneMatch :: from ( etag . clone ( ) ) ;
201206
202- etag
203- } ;
207+ // general rule:
208+ //
209+ // if a header influences how any client or intermediate proxy should treat the response,
210+ // it should be repeated on the 304 response.
211+ //
212+ // This logic assumes _all_ headers have to be repeated, except for a few known ones.
213+ const NON_CACHE_HEADERS : & [ & HeaderName ] = & [ & CONTENT_TYPE ] ;
204214
205- let if_none_match = IfNoneMatch :: from ( etag. clone ( ) ) ;
215+ // store original headers, to assert that they are repeated on the 304 response.
216+ let original_headers: HashMap < HeaderName , HeaderValue > = uncached_response
217+ . headers ( )
218+ . iter ( )
219+ . filter ( |( k, _) | !NON_CACHE_HEADERS . contains ( k) )
220+ . map ( |( k, v) | ( k. clone ( ) , v. clone ( ) ) )
221+ . collect ( ) ;
206222
207223 {
208- // cached response
209- let response = self
210- . get_with_headers ( path , |headers| {
224+ let cached_response = self
225+ . get_with_headers ( initial_path , |headers| {
226+ headers . insert ( X_RLNG_SOURCE_CDN , HeaderValue :: from_static ( "fastly" ) ) ;
211227 headers. typed_insert ( if_none_match) ;
212228 } )
213229 . await ?;
214- assert_eq ! ( response. status( ) , StatusCode :: NOT_MODIFIED ) ;
215- // etag is repeated
216- assert_eq ! ( response. headers( ) . typed_get:: <ETag >( ) . unwrap( ) , etag) ;
230+ assert_eq ! ( cached_response. status( ) , StatusCode :: NOT_MODIFIED ) ;
231+
232+ // most headers are repeated on the 304 response.
233+ let cached_response_headers: HashMap < HeaderName , HeaderValue > = cached_response
234+ . headers ( )
235+ . iter ( )
236+ . filter_map ( |( k, v) | {
237+ if original_headers. contains_key ( k) {
238+ Some ( ( k. clone ( ) , v. clone ( ) ) )
239+ } else {
240+ None
241+ }
242+ } )
243+ . collect ( ) ;
244+
245+ assert_eq ! ( original_headers, cached_response_headers) ;
217246 }
218247 Ok ( ( ) )
219248 }
220249
250+ async fn assert_success_and_conditional_get ( & self , path : & str ) -> Result < ( ) > {
251+ self . assert_conditional_get ( path, & self . assert_success ( path) . await ?)
252+ . await
253+ }
254+
221255 async fn assert_not_found ( & self , path : & str ) -> Result < ( ) > {
222256 let response = self . get ( path) . await ?;
223257
0 commit comments