(
self,
session_id: str,
yaml_file: str,
tasks: List[BatchTask],
websocket_manager,
*,
max_parallel: int = 5,
file_base: str = "batch",
log_level: Optional[LogLevel] = None,
)
| 30 | self.logger = logging.getLogger(__name__) |
| 31 | |
| 32 | async def run_batch( |
| 33 | self, |
| 34 | session_id: str, |
| 35 | yaml_file: str, |
| 36 | tasks: List[BatchTask], |
| 37 | websocket_manager, |
| 38 | *, |
| 39 | max_parallel: int = 5, |
| 40 | file_base: str = "batch", |
| 41 | log_level: Optional[LogLevel] = None, |
| 42 | ) -> None: |
| 43 | batch_id = session_id |
| 44 | total = len(tasks) |
| 45 | |
| 46 | await websocket_manager.send_message( |
| 47 | session_id, |
| 48 | {"type": "batch_started", "data": {"batch_id": batch_id, "total": total}}, |
| 49 | ) |
| 50 | |
| 51 | semaphore = asyncio.Semaphore(max_parallel) |
| 52 | success_count = 0 |
| 53 | failure_count = 0 |
| 54 | result_rows: List[Dict[str, Any]] = [] |
| 55 | result_lock = asyncio.Lock() |
| 56 | |
| 57 | async def run_task(task: BatchTask) -> None: |
| 58 | nonlocal success_count, failure_count |
| 59 | task_id = task.task_id or str(uuid.uuid4()) |
| 60 | task_dir = self._sanitize_label(f"{file_base}-{task_id}") |
| 61 | |
| 62 | await websocket_manager.send_message( |
| 63 | session_id, |
| 64 | { |
| 65 | "type": "batch_task_started", |
| 66 | "data": { |
| 67 | "row_index": task.row_index, |
| 68 | "task_id": task_id, |
| 69 | "task_dir": task_dir, |
| 70 | }, |
| 71 | }, |
| 72 | ) |
| 73 | |
| 74 | try: |
| 75 | result = await asyncio.to_thread( |
| 76 | self._run_single_task, |
| 77 | session_id, |
| 78 | yaml_file, |
| 79 | task, |
| 80 | task_dir, |
| 81 | log_level, |
| 82 | ) |
| 83 | success_count += 1 |
| 84 | async with result_lock: |
| 85 | result_rows.append( |
| 86 | { |
| 87 | "row_index": task.row_index, |
| 88 | "task_id": task_id, |
| 89 | "task_dir": task_dir, |
no test coverage detected