# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License

import abc
import asyncio
import dataclasses
import datetime
import decimal
import logging
import platform
import subprocess
import sys

import lib.litani


class _MemoryProfiler:
    """Interface for objects that sample the memory use of a process tree.

    NOTE(review): this class does not derive from abc.ABC, so the
    abstractmethod marker below is not enforced when instantiating; a
    subclass that does not override snapshot() only fails once it is called.
    """

    @abc.abstractmethod
    async def snapshot(self, root_pid):
        """Return one memory sample for root_pid; subclasses must override."""
        raise NotImplementedError


class _UnixMemoryProfiler(_MemoryProfiler):
    """Memory profiler that parses the output of the `ps` command."""

    def __init__(self):
        # Every root PID ever passed to snapshot(); each snapshot walks the
        # process tree below all of them.
        self.pids = set()
        self.ps_cmd = "ps"
        self.ps_args = ["-x", "-o", "pid,ppid,rss,vsz"]

    async def snapshot(self, root_pid):
        """Return a dict containing memory usage of process and children

        The result always has a "time" key (current UTC time formatted with
        lib.litani.TIME_FORMAT_R) and an "rss" key; every other non-PID
        column reported by ps (e.g. "vsz") is summed in as well.
        """
        ret = {
            "time": datetime.datetime.now(datetime.timezone.utc).strftime(
                lib.litani.TIME_FORMAT_R),
            "rss": 0,
        }
        ps_output = await self._get_ps_output()
        self.pids.add(root_pid)

        # Breadth-first walk over the process tree; `visited` guards against
        # counting a process twice.
        frontier = set(self.pids)
        visited = set()
        while frontier:
            next_frontier = set()
            for pid in frontier:
                if pid in visited:
                    continue
                visited.add(pid)
                self._add_usage(pid, ps_output, ret)
                next_frontier.update(
                    child for child in self._children_of(pid, ps_output)
                    if child not in visited)
            frontier = next_frontier
        return ret

    @staticmethod
    def human_readable(memory):
        """Format a byte count as a string with one decimal and a unit.

        Values too large for the biggest known unit are still expressed in
        that unit instead of running off the end of the unit table.
        """
        units = ["B", "KiB", "MiB", "GiB", "TiB"]
        idx = 0
        memory = decimal.Decimal(memory)
        while memory > 1023 and idx < len(units) - 1:
            idx += 1
            memory /= 1024
        memory_str = memory.quantize(
            decimal.Decimal("0.1"), rounding=decimal.ROUND_HALF_UP)
        return f"{memory_str} {units[idx]}"

    def compute_peak(self, trace):
        """Return the per-field maximum over a list of snapshot dicts.

        The "time" field is ignored. For each field `k` the result holds
        both the raw peak under `k` and a formatted string under
        `human_readable_k`.
        """
        peak = {}
        for item in trace:
            for key, value in item.items():
                if key == "time":
                    continue
                peak[key] = max(peak.get(key, value), value)
        pretty = {
            f"human_readable_{key}": self.human_readable(value)
            for key, value in peak.items()
        }
        return {**peak, **pretty}

    async def _get_ps_output(self):
        """Run ps and return its parsed output.

        Format:
        {
            "fields": ["pid", "ppid", "rss", "vsz"],
            "processes": {
                123: {
                    "pid": 123,
                    "ppid": 457,
                    "rss": 435116,
                    "vsz": 5761524
                }
            }
        }

        Logs an error and exits the interpreter with status 1 if ps returns
        a non-zero exit code.
        """
        proc = await asyncio.create_subprocess_exec(
            self.ps_cmd, *self.ps_args, stdout=subprocess.PIPE)
        stdout, _ = await proc.communicate()
        if proc.returncode:
            # ps_cmd is a single string, so join it with its arguments rather
            # than joining its characters.
            logging.error(
                "%s exited (return code %d)",
                " ".join([self.ps_cmd, *self.ps_args]), proc.returncode)
            sys.exit(1)

        ret = {
            "fields": [],
            "processes": {}
        }
        for line in stdout.decode("utf-8").splitlines():
            columns = line.split()
            if not columns:
                continue
            if not ret["fields"]:
                # The first non-blank line is the header row.
                ret["fields"] = [name.lower() for name in columns]
                continue

            # The first column is always stored as the PID and used as key.
            pid = int(columns[0])
            record = {"pid": pid}
            # zip() stops at the shorter sequence, so a row with more columns
            # than the header cannot index past the field list.
            for field, raw in zip(ret["fields"][1:], columns[1:]):
                number = int(raw)
                # rss/vsz are scaled by 1024 (presumably ps reports them in
                # KiB, which makes the stored values bytes).
                record[field] = (
                    number * 1024 if field in ("rss", "vsz") else number)
            ret["processes"][pid] = record
        return ret

    @staticmethod
    def _children_of(pid, ps_output):
        """Yield the PID of every process whose parent PID is `pid`."""
        for child_pid, child_process in ps_output["processes"].items():
            if child_process["ppid"] == pid:
                yield child_pid

    @staticmethod
    def _add_usage(pid, ps_output, datum):
        """Add the memory fields of process `pid` into `datum` in place.

        PIDs that do not appear in the ps output are silently skipped.
        """
        process_record = ps_output["processes"].get(pid)
        if process_record is None:
            return
        for field, value in process_record.items():
            if field in ("pid", "ppid"):
                continue
            datum[field] = datum.get(field, 0) + value


@dataclasses.dataclass
class _MemoryProfileAccumulator:
    """Coroutine object that periodically records profiler snapshots."""

    profiler: _MemoryProfiler
    profile_interval: int
    pid: int = None
    # Holds "trace" (list of snapshots) and, once cancelled, "peak".
    trace: dict = dataclasses.field(default_factory=dict)

    def set_pid(self, pid):
        """Set the root PID to profile; sampling starts once this is set."""
        self.pid = pid

    async def __call__(self):
        """Sample until cancelled, then compute the peak over the trace.

        The CancelledError is deliberately not re-raised, so awaiting the
        cancelled task returns normally.
        """
        try:
            while True:
                if not self.pid:
                    # No process to profile yet; poll again shortly.
                    await asyncio.sleep(1)
                    continue
                result = await self.profiler.snapshot(self.pid)
                if result:
                    self.trace.setdefault("trace", []).append(result)
                await asyncio.sleep(self.profile_interval)
        except asyncio.CancelledError:
            if "trace" in self.trace:
                self.trace["peak"] = self.profiler.compute_peak(
                    self.trace["trace"])


@dataclasses.dataclass
class _Process:
    """Coroutine object that runs a shell command with an optional timeout."""

    command: str
    interleave_stdout_stderr: bool
    timeout: int
    cwd: str
    # Populated by __call__.
    proc: asyncio.subprocess.Process = None
    stdout: bytes = None
    stderr: bytes = None
    timeout_reached: bool = None

    async def __call__(self):
        """Run the command and store its output and timeout status.

        If the timeout expires the process is sent terminate(), given one
        second, then kill(); whatever output is available is collected.
        """
        if self.interleave_stdout_stderr:
            stderr_target = asyncio.subprocess.STDOUT
        else:
            stderr_target = asyncio.subprocess.PIPE

        proc = await asyncio.create_subprocess_shell(
            self.command, stdout=asyncio.subprocess.PIPE,
            stderr=stderr_target, cwd=self.cwd)
        self.proc = proc

        timeout_reached = False
        try:
            out, err = await asyncio.wait_for(
                proc.communicate(), timeout=self.timeout)
        except asyncio.TimeoutError:
            # The process may exit on its own at any point from here on, in
            # which case signalling it raises ProcessLookupError.
            try:
                proc.terminate()
            except ProcessLookupError:
                pass
            await asyncio.sleep(1)
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            out, err = await proc.communicate()
            timeout_reached = True

        self.stdout = out
        self.stderr = err
        self.timeout_reached = timeout_reached


class Runner:
    """Runs a shell command and, optionally, a memory profiler beside it."""

    @staticmethod
    def get_memory_profiler(profile_memory, system):
        """Return a profiler for `system`, or None.

        None is returned when profiling is not requested or when `system`
        is neither "Linux" nor "Darwin".
        """
        if not profile_memory:
            return None
        if system in ("Linux", "Darwin"):
            return _UnixMemoryProfiler()
        return None

    def __init__(
            self, command, interleave_stdout_stderr, cwd, timeout,
            profile_memory, profile_interval):
        self.tasks = []
        self.runner = _Process(
            command=command,
            interleave_stdout_stderr=interleave_stdout_stderr, cwd=cwd,
            timeout=timeout)
        self.tasks.append(self.runner)

        self.profiler = None
        profiler = self.get_memory_profiler(profile_memory, platform.system())
        if profiler:
            self.profiler = _MemoryProfileAccumulator(
                profiler, profile_interval)
            self.tasks.append(self.profiler)

    async def __call__(self):
        """Run all tasks until the first one completes, then cancel the rest.

        NOTE(review): awaiting a cancelled task re-raises CancelledError
        unless the task swallows it (the profile accumulator does; the
        command task does not).
        """
        tasks = [asyncio.create_task(task()) for task in self.tasks]

        if self.profiler:
            # Give the command task up to ten one-second polls to spawn the
            # process so that its PID can be handed to the profiler.
            for _ in range(10):
                if self.runner.proc:
                    self.profiler.set_pid(self.runner.proc.pid)
                    break
                await asyncio.sleep(1)

        _, pending = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
            await task

    def get_return_code(self):
        """Return the exit code of the command's process."""
        return self.runner.proc.returncode

    def get_stdout(self):
        """Return the command's stdout decoded as UTF-8, or None if empty."""
        if self.runner.stdout:
            return self.runner.stdout.decode("utf-8")
        return None

    def get_stderr(self):
        """Return the command's stderr decoded as UTF-8, or None if empty."""
        if self.runner.stderr:
            return self.runner.stderr.decode("utf-8")
        return None

    def reached_timeout(self):
        """Return whether the command was stopped because it timed out."""
        return self.runner.timeout_reached

    def get_memory_trace(self):
        """Return the recorded memory trace, or {} when not profiling."""
        return self.profiler.trace if self.profiler else {}