33module OmniAuth
44 module Strategies
55 class MicrosoftGraph < OmniAuth ::Strategies ::OAuth2
6+ BASE_SCOPE_URL = 'https://graph.microsoft.com/'
7+ BASE_SCOPES = %w[ offline_access openid email profile ] . freeze
8+ DEFAULT_SCOPE = 'openid email profile User.Read' . freeze
9+
610 option :name , :microsoft_graph
711
812 option :client_options , {
@@ -11,13 +15,13 @@ class MicrosoftGraph < OmniAuth::Strategies::OAuth2
1115 authorize_url : 'common/oauth2/v2.0/authorize'
1216 }
1317
14- option :authorize_params , {
15- }
18+ option :authorize_options , %i[ state callback_url access_type display score auth_type scope prompt login_hint domain_hint response_mode ]
1619
1720 option :token_params , {
1821 }
1922
20- option :scope , "offline_access https://graph.microsoft.com/User.Read"
23+ option :scope , DEFAULT_SCOPE
24+ option :authorized_client_ids , [ ]
2125
2226 uid { raw_info [ "id" ] }
2327
@@ -34,17 +38,88 @@ class MicrosoftGraph < OmniAuth::Strategies::OAuth2
3438 extra do
3539 {
3640 'raw_info' => raw_info ,
37- 'params' => access_token . params
41+ 'params' => access_token . params ,
42+ 'aud' => options . client_id
3843 }
3944 end
4045
46+ def authorize_params
47+ super . tap do |params |
48+ options [ :authorize_options ] . each do |k |
49+ params [ k ] = request . params [ k . to_s ] unless [ nil , '' ] . include? ( request . params [ k . to_s ] )
50+ end
51+
52+ params [ :scope ] = get_scope ( params )
53+ params [ :access_type ] = 'offline' if params [ :access_type ] . nil?
54+
55+ session [ 'omniauth.state' ] = params [ :state ] if params [ :state ]
56+ end
57+ end
58+
4159 def raw_info
4260 @raw_info ||= access_token . get ( 'https://graph.microsoft.com/v1.0/me' ) . parsed
4361 end
4462
4563 def callback_url
4664 options [ :callback_url ] || full_host + script_name + callback_path
47- end
65+ end
66+
67+ def custom_build_access_token
68+ access_token = get_access_token ( request )
69+ access_token
70+ end
71+
72+ alias build_access_token custom_build_access_token
73+
74+ private
75+
76+ def get_access_token ( request )
77+ verifier = request . params [ 'code' ]
78+ redirect_uri = request . params [ 'redirect_uri' ] || request . params [ 'callback_url' ]
79+ if verifier && request . xhr?
80+ client_get_token ( verifier , redirect_uri || '/auth/microsoft_graph/callback' )
81+ elsif verifier
82+ client_get_token ( verifier , redirect_uri || callback_url )
83+ elsif verify_token ( request . params [ 'access_token' ] )
84+ ::OAuth2 ::AccessToken . from_hash ( client , request . params . dup )
85+ elsif request . content_type =~ /json/i
86+ begin
87+ body = JSON . parse ( request . body . read )
88+ request . body . rewind # rewind request body for downstream middlewares
89+ verifier = body && body [ 'code' ]
90+ client_get_token ( verifier , '/auth/microsoft_graph/callback' ) if verifier
91+ rescue JSON ::ParserError => e
92+ warn "[omniauth google-oauth2] JSON parse error=#{ e } "
93+ end
94+ end
95+ end
96+
97+ def client_get_token ( verifier , redirect_uri )
98+ client . auth_code . get_token ( verifier , get_token_options ( redirect_uri ) , get_token_params )
99+ end
100+
101+ def get_token_params
102+ deep_symbolize ( options . auth_token_params || { } )
103+ end
104+
105+ def get_token_options ( redirect_uri = '' )
106+ { redirect_uri : redirect_uri } . merge ( token_params . to_hash ( symbolize_keys : true ) )
107+ end
108+
109+ def get_scope ( params )
110+ raw_scope = params [ :scope ] || DEFAULT_SCOPE
111+ scope_list = raw_scope . split ( ' ' ) . map { |item | item . split ( ',' ) } . flatten
112+ scope_list . map! { |s | s =~ %r{^https?://} || BASE_SCOPES . include? ( s ) ? s : "#{ BASE_SCOPE_URL } #{ s } " }
113+ scope_list . join ( ' ' )
114+ end
115+
116+ def verify_token ( access_token )
117+ return false unless access_token
118+ # access_token.get('https://graph.microsoft.com/v1.0/me').parsed
119+ raw_response = client . request ( :get , 'https://graph.microsoft.com/v1.0/me' ,
120+ params : { access_token : access_token } ) . parsed
121+ ( raw_response [ 'aud' ] == options . client_id ) || options . authorized_client_ids . include? ( raw_response [ 'aud' ] )
122+ end
48123 end
49124 end
50125end
0 commit comments