diff --git a/lib/omniauth/microsoft_graph/domain_verifier.rb b/lib/omniauth/microsoft_graph/domain_verifier.rb index 4401a5e..d9455cb 100644 --- a/lib/omniauth/microsoft_graph/domain_verifier.rb +++ b/lib/omniauth/microsoft_graph/domain_verifier.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true + require 'jwt' # for token signature validation -require 'omniauth' # to inherit from OmniAuth::Error +require 'omniauth-oauth2' # to use CallbackError require 'oauth2' # to rescue OAuth2::Error module OmniAuth @@ -11,8 +12,6 @@ module MicrosoftGraph OIDC_CONFIG_URL = 'https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration' COMMON_JWKS_URL = 'https://login.microsoftonline.com/common/discovery/v2.0/keys' - class DomainVerificationError < OmniAuth::Error; end - class DomainVerifier def self.verify!(auth_hash, access_token, options) new(auth_hash, access_token, options).verify! @@ -41,7 +40,13 @@ def verify! skip_verification == true || (skip_verification.is_a?(Array) && skip_verification.include?(email_domain)) || domain_verified_jwt_claim - raise DomainVerificationError, verification_error_message + + # Use CallbackError to ensure the error is properly caught by the callback_phase + # rescue clause and converted to an OmniAuth failure instead of bubbling up as a 500 error. + raise OmniAuth::Strategies::OAuth2::CallbackError.new( + :domain_verification_failed, + verification_error_message + ) end private diff --git a/spec/omniauth/microsoft_graph/domain_verifier_spec.rb b/spec/omniauth/microsoft_graph/domain_verifier_spec.rb index 777695b..6d8891f 100644 --- a/spec/omniauth/microsoft_graph/domain_verifier_spec.rb +++ b/spec/omniauth/microsoft_graph/domain_verifier_spec.rb @@ -105,8 +105,11 @@ context 'when all verification strategies fail' do before { allow(access_token).to receive(:get).and_raise(::OAuth2::Error.new('whoops')) } - it 'raises a DomainVerificationError' do - expect { result }.to raise_error OmniAuth::MicrosoftGraph::DomainVerificationError + it 'raises a CallbackError with domain_verification_failed' do + expect { result }.to raise_error(OmniAuth::Strategies::OAuth2::CallbackError) do |error| + expect(error.error).to eq(:domain_verification_failed) + expect(error.error_reason).to include('not a verified domain') + end end end end diff --git a/spec/omniauth/strategies/microsoft_graph_oauth2_spec.rb b/spec/omniauth/strategies/microsoft_graph_oauth2_spec.rb index 01482d5..5f34d49 100644 --- a/spec/omniauth/strategies/microsoft_graph_oauth2_spec.rb +++ b/spec/omniauth/strategies/microsoft_graph_oauth2_spec.rb @@ -282,7 +282,12 @@ context 'when email verification fails' do let(:response_hash) { { mail: 'something@domain.invalid' } } - let(:error) { OmniAuth::MicrosoftGraph::DomainVerificationError.new } + let(:error) do + OmniAuth::Strategies::OAuth2::CallbackError.new( + :domain_verification_failed, + 'Domain verification failed' + ) + end before do allow(OmniAuth::MicrosoftGraph::DomainVerifier).to receive(:verify!).and_raise(error)