11import json
22import logging
33import socket
4+ import time
5+ from datetime import datetime
46
57import requests
68import websocket
79
8- from browserdebuggertools .exceptions import ResultNotFoundError , TabNotFoundError , \
9- DomainNotEnabledError
10+ from browserdebuggertools .exceptions import (
11+ ResultNotFoundError , TabNotFoundError ,
12+ DomainNotEnabledError , DevToolsTimeoutException , DomainNotFoundError
13+ )
1014
1115logging .basicConfig (format = '%(levelname)s:%(message)s' )
1216
1317
18+ def open_connection_if_closed (socket_handler_method ):
19+
20+ def retry_if_exception (socket_handler_instance , * args , ** kwargs ):
21+
22+ try :
23+ return socket_handler_method (socket_handler_instance , * args , ** kwargs )
24+
25+ except websocket .WebSocketConnectionClosedException :
26+
27+ socket_handler_instance .increment_connection_closed_count ()
28+ retry_if_exception (socket_handler_instance , * args , ** kwargs )
29+
30+ return retry_if_exception
31+
32+
1433class SocketHandler (object ):
1534
16- CONN_TIMEOUT = 15 # Connection timeout
35+ MAX_CONNECTION_RETRIES = 3
36+ RETRY_COUNT_TIMEOUT = 300 # Seconds
37+ CONN_TIMEOUT = 15 # Connection timeout seconds
1738
18- def __init__ (self , port ):
19- websocket_url = self ._get_websocket_url (port )
20- self .websocket = websocket .create_connection (websocket_url , timeout = self .CONN_TIMEOUT )
21- self .websocket .settimeout (0 ) # Don"t wait for new messages
39+ def __init__ (self , port , timeout , domains = None ):
40+
41+ self .timeout = timeout
42+
43+ if not domains :
44+ domains = {}
45+
46+ self ._domains = domains
47+ self ._events = dict ([(k , []) for k in self ._domains ])
48+ self ._results = {}
2249
2350 self ._next_result_id = 0
24- self .domains = set ()
25- self .results = {}
26- self .events = {}
51+ self ._connection_last_closed = None
52+ self ._connection_closed_count = 0
53+
54+ self ._websocket_url = self ._get_websocket_url (port )
55+ self ._websocket = self ._setup_websocket ()
56+
57+ def __del__ (self ):
58+ try :
59+ self .close ()
60+ except :
61+ pass
62+
63+ def _setup_websocket (self ):
64+
65+ self ._websocket = websocket .create_connection (
66+ self ._websocket_url , timeout = self .CONN_TIMEOUT
67+ )
68+ self ._websocket .settimeout (0 ) # Don"t wait for new messages
69+
70+ for domain , params in self ._domains .items ():
71+ self .enable_domain (domain , params )
72+
73+ return self ._websocket
74+
75+ def increment_connection_closed_count (self ):
76+
77+ now = datetime .now ()
78+
79+ if (
80+ self ._connection_last_closed and
81+ (now - self ._connection_last_closed ).seconds > self .RETRY_COUNT_TIMEOUT
82+ ):
83+ self ._connection_closed_count = 0
84+
85+ self ._connection_last_closed = now
86+ self ._connection_closed_count += 1
87+
88+ if self ._connection_closed_count > self .MAX_CONNECTION_RETRIES :
89+ raise Exception ("Websocket connection found closed too many times" )
90+
91+ self ._setup_websocket ()
92+
93+ @open_connection_if_closed
94+ def _send (self , data ):
95+ data ['id' ] = self ._next_result_id
96+ self ._websocket .send (json .dumps (data , sort_keys = True ))
97+
98+ @open_connection_if_closed
99+ def _recv (self ):
100+ message = self ._websocket .recv ()
101+ if message :
102+ message = json .loads (message )
103+ return message
27104
28105 def _get_websocket_url (self , port ):
29106 targets = requests .get (
@@ -36,67 +113,113 @@ def _get_websocket_url(self, port):
36113 return tabs [0 ]["webSocketDebuggerUrl" ]
37114
38115 def close (self ):
39- self .websocket .close ()
116+ if hasattr (self , "_websocket" ):
117+ self ._websocket .close ()
40118
41119 def _append (self , message ):
120+
42121 if "result" in message :
43- self .results [message ["id" ]] = message .get ("result" )
122+ self ._results [message ["id" ]] = message .get ("result" )
44123 elif "error" in message :
45124 result_id = message .pop ("id" )
46- self .results [result_id ] = message
125+ self ._results [result_id ] = message
47126 elif "method" in message :
48127 domain , event = message ["method" ].split ("." )
49- self .events [domain ].append (message )
128+ self ._events [domain ].append (message )
50129 else :
51130 logging .warning ("Unrecognised message: {}" .format (message ))
52131
53- def flush_messages (self ):
132+ def _flush_messages (self ):
54133 """ Will only return once all the messages have been retrieved.
55134 and will hold the thread until so.
56135 """
57136 try :
58- message = self .websocket . recv ()
137+ message = self ._recv ()
59138 while message :
60- message = json .loads (message )
61139 self ._append (message )
62- message = self .websocket . recv ()
140+ message = self ._recv ()
63141 except socket .error :
64142 return
65143
66- def find_result (self , result_id ):
67- if result_id not in self .results :
68- self .flush_messages ()
144+ def _find_next_result (self ):
145+ if self . _next_result_id not in self ._results :
146+ self ._flush_messages ()
69147
70- if result_id not in self .results :
71- raise ResultNotFoundError ("Result not found for id: {} ." .format (result_id ))
148+ if self . _next_result_id not in self ._results :
149+ raise ResultNotFoundError ("Result not found for id: {} ." .format (self . _next_result_id ))
72150
73- return self .results .pop (result_id )
151+ return self ._results .pop (self . _next_result_id )
74152
75- def execute (self , method , params ):
76- self ._next_result_id += 1
77- self .websocket .send (json .dumps ({
78- "id" : self ._next_result_id , "method" : method , "params" : params if params else {}
79- }, sort_keys = True ))
80- return self ._next_result_id
153+ def execute (self , domainName , methodName , params = None ):
81154
82- def add_domain (self , domain ):
83- if domain not in self .domains :
84- self .domains .add (domain )
85- self .events [domain ] = []
155+ if params is None :
156+ params = {}
86157
87- def remove_domain (self , domain ):
88- if domain in self .domains :
89- self .domains .remove (domain )
158+ self ._next_result_id += 1
159+ method = "{}.{}" .format (domainName , methodName )
160+ self ._send ({
161+ "method" : method , "params" : params
162+ })
163+ return self ._wait_for_result ()
164+
165+ def _add_domain (self , domain , params ):
166+ if domain not in self ._domains :
167+ self ._domains [domain ] = params
168+ self ._events [domain ] = []
169+
170+ def _remove_domain (self , domain ):
171+ if domain in self ._domains :
172+ del self ._domains [domain ]
173+ del self ._events [domain ]
90174
91175 def get_events (self , domain , clear = False ):
92- if domain not in self .domains :
176+ if domain not in self ._domains :
93177 raise DomainNotEnabledError (
94178 'The domain "%s" is not enabled, try enabling it via the interface.' % domain
95179 )
96180
97- self .flush_messages ()
98- events = self .events [domain ][:]
181+ self ._flush_messages ()
182+ events = self ._events [domain ][:]
99183 if clear :
100- self .events [domain ] = []
184+ self ._events [domain ] = []
101185
102186 return events
187+
188+ def _wait_for_result (self ):
189+ """ Waits for a result to complete within the timeout duration then returns it.
190+ Raises a DevToolsTimeoutException if it cannot find the result.
191+
192+ :return: The result.
193+ """
194+ start = time .time ()
195+ while not self .timeout or (time .time () - start ) < self .timeout :
196+ try :
197+ return self ._find_next_result ()
198+ except ResultNotFoundError :
199+ time .sleep (0.5 )
200+ raise DevToolsTimeoutException (
201+ "Reached timeout limit of {}, waiting for a response message" .format (self .timeout )
202+ )
203+
204+ def enable_domain (self , domainName , parameters = None ):
205+
206+ if not parameters :
207+ parameters = {}
208+
209+ self ._add_domain (domainName , parameters )
210+ result = self .execute (domainName , "enable" , parameters )
211+ if "error" in result :
212+ self ._remove_domain (domainName )
213+ raise DomainNotFoundError ("Domain \" {}\" not found." .format (domainName ))
214+
215+ logging .info ("\" {}\" domain has been enabled" .format (domainName ))
216+
217+ def disable_domain (self , domainName ):
218+ """ Disables further notifications from the given domain.
219+ """
220+ self ._remove_domain (domainName )
221+ result = self .execute (domainName , "disable" , {})
222+ if "error" in result :
223+ logging .warn ("Domain \" {}\" doesn't exist" .format (domainName ))
224+ else :
225+ logging .info ("Domain {} has been disabled" .format (domainName ))
0 commit comments