Skip to content

Commit e18e8ad

Browse files
committed
Tensor report card
1 parent 6f378c0 commit e18e8ad

File tree

2 files changed

+414
-0
lines changed

2 files changed

+414
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import multiprocessing.connection
15+
import time
16+
from typing import Any, Callable, Dict, List, Optional, Tuple
17+
18+
from attrs import define
19+
20+
from qualtran import Bloq
21+
from qualtran.simulation.tensor import cbloq_to_quimb
22+
23+
24+
@define
class _Pending:
    """Helper dataclass to track currently executing processes in `ExecuteWithTimeout`."""

    # The launched worker process executing one task.
    p: multiprocessing.Process
    # Receiving half of the pipe; the task's `cxn.send(...)` result arrives here.
    recv: multiprocessing.connection.Connection
    # Timestamp (from `time.time()`) recorded just before `p.start()`; compared
    # against `ExecuteWithTimeout.timeout` to decide when to kill the process.
    start_time: float
    # Keyword arguments the task was launched with (includes the `cxn` send handle).
    kwargs: Dict[str, Any]
32+
33+
34+
class ExecuteWithTimeout:
35+
"""Execute tasks in processes where each task will be killed if it exceeds `timeout`.
36+
37+
Seemingly all the existing "timeout" parameters in the various built-in concurrency
38+
primitives in Python won't actually terminate the process. This one does.
39+
"""
40+
41+
def __init__(self, timeout: float, max_workers: int):
42+
self.timeout = timeout
43+
self.max_workers = max_workers
44+
45+
self.queued: List[Tuple[Callable, Dict[str, Any]]] = []
46+
self.pending: List[_Pending] = []
47+
48+
@property
49+
def work_to_be_done(self) -> int:
50+
"""The number of tasks currently executing or queued."""
51+
return len(self.queued) + len(self.pending)
52+
53+
def submit(self, func: Callable, kwargs: Dict[str, Any]) -> None:
54+
"""Add a task to the queue.
55+
56+
`func` must be a callable that can accept `kwargs` in addition to
57+
a keyword argument `cxn` which is a multiprocessing `Connection` object that forms
58+
the sending-half of a `mp.Pipe`. The callable must call `cxn.send(...)`
59+
to return a result.
60+
"""
61+
self.queued.append((func, kwargs))
62+
63+
def _submit_from_queue(self):
64+
# helper method that takes an item from the queue, launches a process,
65+
# and records it in the `pending` attribute. This must only be called
66+
# if we're allowed to spawn a new process.
67+
func, kwargs = self.queued.pop(0)
68+
recv, send = multiprocessing.Pipe(duplex=False)
69+
kwargs['cxn'] = send
70+
p = multiprocessing.Process(target=func, kwargs=kwargs)
71+
start_time = time.time()
72+
p.start()
73+
self.pending.append(_Pending(p=p, recv=recv, start_time=start_time, kwargs=kwargs))
74+
75+
def _scan_pendings(self) -> Optional[_Pending]:
76+
# helper method that goes through the currently pending tasks, terminates the ones
77+
# that have been going on too long, and accounts for ones that have finished.
78+
# Returns the `_Pending` of the killed or completed job or `None` if each pending
79+
# task is still running but none have exceeded the timeout.
80+
for i in range(len(self.pending)):
81+
pen = self.pending[i]
82+
83+
if not pen.p.is_alive():
84+
self.pending.pop(i)
85+
pen.p.join()
86+
return pen
87+
88+
if time.time() - pen.start_time > self.timeout:
89+
pen.p.terminate()
90+
self.pending.pop(i)
91+
return pen
92+
93+
return None
94+
95+
def next_result(self) -> Tuple[Dict[str, Any], Optional[Any]]:
96+
"""Get the next available result.
97+
98+
This call is blocking, but should never take longer than `self.timeout`. This should
99+
be called in a loop to make sure the queue continues to be processed.
100+
101+
Returns:
102+
task kwargs: The keyword arguments used to submit the task.
103+
result: If the process finished successfully, this is the object that was
104+
sent through the multiprocessing pipe as the result. Otherwise, the result
105+
is None.
106+
"""
107+
while len(self.queued) > 0 and len(self.pending) < self.max_workers:
108+
self._submit_from_queue()
109+
110+
while True:
111+
finished = self._scan_pendings()
112+
if finished is not None:
113+
break
114+
115+
if finished.p.exitcode == 0:
116+
result = finished.recv.recv()
117+
else:
118+
result = None
119+
120+
finished.recv.close()
121+
122+
while len(self.queued) > 0 and len(self.pending) < self.max_workers:
123+
self._submit_from_queue()
124+
125+
return (finished.kwargs, result)
126+
127+
128+
def report_on_tensors(name: str, cls_name: str, bloq: Bloq, cxn) -> None:
    """Get timing information for tensor functionality.

    This should be used with `ExecuteWithTimeout`. The resultant
    record dictionary is sent over `cxn`.
    """
    record: Dict[str, Any] = {'name': name, 'cls': cls_name}

    def _timed(thunk):
        # Run `thunk()` and return (result, elapsed seconds).
        begin = time.perf_counter()
        value = thunk()
        return value, time.perf_counter() - begin

    try:
        flat, record['flat_dur'] = _timed(lambda: bloq.as_composite_bloq().flatten())
        tn, record['tn_dur'] = _timed(lambda: cbloq_to_quimb(flat))
        record['width'], record['width_dur'] = _timed(tn.contraction_width)
    except Exception as e:  # pylint: disable=broad-exception-caught
        # Best effort: record the failure so the parent can report it, rather
        # than crashing the worker process.
        record['err'] = str(e)

    cxn.send(record)

0 commit comments

Comments
 (0)