|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import multiprocessing.connection |
| 15 | +import time |
| 16 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
| 17 | + |
| 18 | +from attrs import define |
| 19 | + |
| 20 | +from qualtran import Bloq |
| 21 | +from qualtran.simulation.tensor import cbloq_to_quimb |
| 22 | + |
| 23 | + |
@define
class _Pending:
    """Bookkeeping record for one task process launched by `ExecuteWithTimeout`.

    An instance is created when a task's process is started and is kept in the
    executor's list of pending tasks until that process finishes or is terminated.
    """

    # The process that is running the task.
    p: multiprocessing.Process
    # Receiving half of the one-way pipe over which the task reports its result.
    recv: multiprocessing.connection.Connection
    # Timestamp recorded just before the process was started; compared against
    # the executor's timeout.
    start_time: float
    # Keyword arguments associated with the task; handed back to the caller
    # together with the task's result.
    kwargs: Dict[str, Any]
| 32 | + |
| 33 | + |
class ExecuteWithTimeout:
    """Execute tasks in processes where each task will be killed if it exceeds `timeout`.

    Seemingly all the existing "timeout" parameters in the various built-in concurrency
    primitives in Python won't actually terminate the process. This one does.

    Typical use is to `submit` a batch of tasks and then call `next_result` in a loop
    for as long as `work_to_be_done` is non-zero.

    Args:
        timeout: Maximum number of seconds a single task may run before it is terminated.
        max_workers: Maximum number of task processes allowed to run at the same time.
    """

    # Seconds to sleep between scans of the pending tasks while waiting for one of
    # them to finish. Without a pause the waiting loop would burn a full CPU core
    # and compete with the very workers it is waiting on.
    _POLL_INTERVAL = 0.01

    # Seconds a terminated process is given to exit before it is forcibly killed.
    _TERMINATE_GRACE = 1.0

    def __init__(self, timeout: float, max_workers: int):
        self.timeout = timeout
        self.max_workers = max_workers

        self.queued: List[Tuple[Callable, Dict[str, Any]]] = []
        self.pending: List[_Pending] = []

    @property
    def work_to_be_done(self) -> int:
        """The number of tasks currently executing or queued."""
        return len(self.queued) + len(self.pending)

    def submit(self, func: Callable, kwargs: Dict[str, Any]) -> None:
        """Add a task to the queue.

        `func` must be a callable that can accept `kwargs` in addition to
        a keyword argument `cxn` which is a multiprocessing `Connection` object that forms
        the sending-half of a `mp.Pipe`. The callable must call `cxn.send(...)`
        to return a result.

        The `kwargs` dictionary itself is not modified; the `cxn` entry is added to
        a copy when the task is launched.
        """
        self.queued.append((func, kwargs))

    def _submit_from_queue(self):
        # helper method that takes an item from the queue, launches a process,
        # and records it in the `pending` attribute. This must only be called
        # if we're allowed to spawn a new process.
        func, kwargs = self.queued.pop(0)
        recv, send = multiprocessing.Pipe(duplex=False)

        # Build a fresh dict for the child so the caller's `kwargs` is left untouched
        # and the pipe handle does not leak back out through `next_result`.
        p = multiprocessing.Process(target=func, kwargs={**kwargs, 'cxn': send})
        start_time = time.monotonic()
        try:
            p.start()
        except Exception:
            recv.close()
            raise
        finally:
            # The child owns its own handle to the sending end once `start()` has
            # returned. Closing ours makes the child the sole writer, so the
            # receiving end reports end-of-file when the child exits without sending.
            send.close()

        self.pending.append(_Pending(p=p, recv=recv, start_time=start_time, kwargs=kwargs))

    def _fill_workers(self) -> None:
        # helper method that launches queued tasks until either the queue is empty
        # or `max_workers` processes are running.
        while self.queued and len(self.pending) < self.max_workers:
            self._submit_from_queue()

    def _stop(self, proc) -> None:
        # helper method that terminates `proc` and reaps it so it does not linger as
        # a zombie. Falls back to `kill()` if the process does not exit in time.
        proc.terminate()
        proc.join(self._TERMINATE_GRACE)
        if proc.is_alive():
            proc.kill()
            proc.join()

    @staticmethod
    def _readable(cxn) -> bool:
        # helper method reporting whether `cxn` has something to read: either data or
        # end-of-file. A broken pipe means the writer is gone, which counts as ready.
        try:
            return cxn.poll()
        except (EOFError, OSError):
            return True

    def _scan_pendings(self) -> Optional[_Pending]:
        # helper method that goes through the currently pending tasks, terminates the ones
        # that have been going on too long, and accounts for ones that have finished.
        # Returns the `_Pending` of the killed or completed job or `None` if each pending
        # task is still running but none have exceeded the timeout.
        now = time.monotonic()
        for i, pen in enumerate(self.pending):
            # A task also counts as completed as soon as its pipe is readable: a child
            # sending a result larger than the OS pipe buffer blocks in `send()` until
            # we read, so waiting for it to exit first would turn it into a timeout.
            if self._readable(pen.recv) or not pen.p.is_alive():
                return self.pending.pop(i)

            if now - pen.start_time > self.timeout:
                self._stop(pen.p)
                return self.pending.pop(i)

        return None

    def _collect(self, pen: _Pending) -> Optional[Any]:
        # helper method that reads the result of a task returned by `_scan_pendings`,
        # closes its pipe and reaps its process. Returns the received object if the
        # process exited with code 0, otherwise `None`.
        payload = None
        try:
            # `exitcode` is None while the process is still running (its result is
            # waiting in the pipe) and 0 after a clean exit. Killed or crashed
            # processes are never read from.
            if pen.p.exitcode in (None, 0) and self._readable(pen.recv):
                payload = pen.recv.recv()
        except (EOFError, OSError):
            # The process closed the pipe or exited without sending a complete result.
            payload = None
        finally:
            pen.recv.close()

        if pen.p.is_alive():
            # The result has been drained, so the process should exit promptly. Give it
            # the rest of its time budget (at least one poll interval), then stop it.
            remaining = self.timeout - (time.monotonic() - pen.start_time)
            pen.p.join(max(remaining, self._POLL_INTERVAL))
            if pen.p.is_alive():
                self._stop(pen.p)
        else:
            pen.p.join()

        return payload if pen.p.exitcode == 0 else None

    def next_result(self) -> Tuple[Dict[str, Any], Optional[Any]]:
        """Get the next available result.

        This call is blocking, but should never take longer than `self.timeout`. This should
        be called in a loop to make sure the queue continues to be processed.

        Returns:
            task kwargs: The keyword arguments used to submit the task.
            result: If the process finished successfully, this is the object that was
                sent through the multiprocessing pipe as the result. Otherwise, the result
                is None.

        Raises:
            RuntimeError: If no task is running after the queue has been processed,
                i.e. nothing was submitted or `max_workers` is less than one. Waiting
                in that situation would never return.
        """
        self._fill_workers()

        if not self.pending:
            raise RuntimeError(
                "next_result() was called but no task is running: nothing has been "
                "submitted, or max_workers is less than 1."
            )

        while True:
            finished = self._scan_pendings()
            if finished is not None:
                break
            time.sleep(self._POLL_INTERVAL)

        result = self._collect(finished)

        # A worker slot has just been freed; keep the pool saturated.
        self._fill_workers()

        return (finished.kwargs, result)
| 126 | + |
| 127 | + |
def report_on_tensors(name: str, cls_name: str, bloq: Bloq, cxn) -> None:
    """Time the tensor-network pipeline for one bloq and send the measurements over `cxn`.

    Intended to be run as a task of `ExecuteWithTimeout`. Three stages are timed in
    order: flattening the bloq, converting the flattened bloq with `cbloq_to_quimb`,
    and computing the contraction width of the resulting network. If any stage raises,
    the exception's message is stored under 'err' and the timings recorded so far
    are kept.

    Args:
        name: Stored in the record under 'name'.
        cls_name: Stored in the record under 'cls'.
        bloq: The bloq to flatten and convert.
        cxn: Object whose `send` method receives the finished record dictionary.
    """
    clock = time.perf_counter
    record: Dict[str, Any] = dict(name=name, cls=cls_name)

    try:
        began = clock()
        flattened = bloq.as_composite_bloq().flatten()
        record.update(flat_dur=clock() - began)

        began = clock()
        network = cbloq_to_quimb(flattened)
        record.update(tn_dur=clock() - began)

        began = clock()
        width = network.contraction_width()
        # Keyword order keeps 'width' ahead of 'width_dur' in the record.
        record.update(width=width, width_dur=clock() - began)

    except Exception as exc:  # pylint: disable=broad-exception-caught
        record.update(err=str(exc))

    cxn.send(record)
0 commit comments