2 changes: 1 addition & 1 deletion scalene/redirect_python.py
@@ -22,7 +22,7 @@ def redirect_python(
f"python{sys.version_info.major}.{sys.version_info.minor}{base_python_extension}",
]

shebang = "@echo off" if sys.platform == "win32" else "#!/bin/bash"
shebang = "@echo off" if sys.platform == "win32" else "#!/usr/bin/env bash"
all_args = "%*" if sys.platform == "win32" else '"$@"'

payload = f"{shebang}\n{preface} {sys.executable} -m scalene {cmdline} {all_args}\n"
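Note: for context, here is a minimal sketch of the wrapper payload this function writes after the change. The `preface` and `cmdline` pieces are omitted for brevity, so the command below is illustrative rather than the exact payload:

```python
import sys

# Sketch of the payload redirect_python() writes: a shebang line followed by
# one command that forwards every argument to Scalene. "#!/usr/bin/env bash"
# resolves bash via PATH instead of assuming it lives at /bin/bash (which it
# does not on, e.g., NixOS or FreeBSD).
shebang = "@echo off" if sys.platform == "win32" else "#!/usr/bin/env bash"
all_args = "%*" if sys.platform == "win32" else '"$@"'

# The real payload also interpolates a preface and the Scalene command-line
# options between the interpreter and the forwarded arguments.
payload = f"{shebang}\n{sys.executable} -m scalene {all_args}\n"
print(payload)
```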
205 changes: 205 additions & 0 deletions scalene/scalene_asyncio.py
@@ -0,0 +1,205 @@
import asyncio
import gc
import sys
import threading

from types import (
    AsyncGeneratorType,
    FrameType,
)
from typing import (
    Callable,
    List,
    Optional,
    Tuple,
    cast,
)


class ScaleneAsyncio:
    """Provides a set of methods to collect idle task frames."""

    should_trace: Optional[Callable[[str, str], bool]] = None
    loops: List[Tuple[asyncio.AbstractEventLoop, int]] = []
    current_task: Optional[asyncio.Task] = None

    @staticmethod
    def current_task_exists(tident: int) -> bool:
        """Given TIDENT, return True if a current task exists on that
        thread's event loop. Also returns True if no event loop is
        running on TIDENT."""
        current: object = True
        for loop, t in ScaleneAsyncio.loops:
            if t == tident:
                current = asyncio.current_task(loop)
                break
        return bool(current)

    @staticmethod
    def compute_suspended_frames_to_record(
        should_trace: Callable[[str, str], bool],
    ) -> List[Tuple[FrameType, int, FrameType]]:
        """Collect all frames that belong to suspended tasks."""
        # TODO: this is an ugly way to make should_trace accessible here.
        ScaleneAsyncio.should_trace = should_trace
        ScaleneAsyncio.loops = ScaleneAsyncio._get_event_loops()

        return ScaleneAsyncio._get_frames_from_loops(ScaleneAsyncio.loops)

    @staticmethod
    def _get_event_loops() -> List[Tuple[asyncio.AbstractEventLoop, int]]:
        """Return each thread's running event loop, or an empty list if
        there are none."""
        loops: List[Tuple[asyncio.AbstractEventLoop, int]] = []
        # Take one snapshot of all thread frames rather than one per thread.
        frames = sys._current_frames()
        for t in threading.enumerate():
            frame = frames.get(cast(int, t.ident))
            if frame:
                loop = ScaleneAsyncio._walk_back_until_loop(frame)
                # Duplicates shouldn't be possible, but just in case...
                # (compare against the loop element of each stored tuple).
                if loop and all(loop is not seen for seen, _ in loops):
                    loops.append((loop, cast(int, t.ident)))
        return loops

    @staticmethod
    def _walk_back_until_loop(
        frame: Optional[FrameType],
    ) -> Optional[asyncio.AbstractEventLoop]:
        """Helper for _get_event_loops.
        Walks back the call stack until we reach a method named '_run_once'
        whose 'self' is an instance of AbstractEventLoop, and returns that
        instance.
        This works because _run_once is one of the main methods asyncio uses
        to drive its event loop, and it is always on the stack while the
        loop runs."""
        while frame:
            if frame.f_code.co_name == '_run_once' and \
               'self' in frame.f_locals:
                loop = frame.f_locals['self']
                if isinstance(loop, asyncio.AbstractEventLoop):
                    return loop
            # Always advance, even past a _run_once frame whose 'self'
            # is not an event loop; otherwise this loop would never end.
            frame = frame.f_back
        return None

    @staticmethod
    def _get_frames_from_loops(
        loops: List[Tuple[asyncio.AbstractEventLoop, int]],
    ) -> List[Tuple[FrameType, int, FrameType]]:
        """Given LOOPS, return a flat list of frames corresponding to idle
        tasks."""
        return [
            (frame, tident, None)
            for loop, tident in loops
            for frame in ScaleneAsyncio._get_idle_task_frames(loop)
        ]

    @staticmethod
    def _get_idle_task_frames(
        loop: asyncio.AbstractEventLoop,
    ) -> List[FrameType]:
        """Given an asyncio event loop, return the list of idle task frames.
        We only care about idle task frames, as running tasks are already
        included elsewhere.
        A task is considered 'idle' if it is pending and not the current
        task."""
        idle: List[FrameType] = []

        # Set this when we start processing a loop.
        # It is required later, but I only want to set it once.
        ScaleneAsyncio.current_task = asyncio.current_task(loop)

        for task in asyncio.all_tasks(loop):
            if not ScaleneAsyncio._should_trace_task(task):
                continue

            coro = task.get_coro()

            frame = ScaleneAsyncio._get_deepest_traceable_frame(coro)
            if frame:
                idle.append(cast(FrameType, frame))

        return idle

    @staticmethod
    def _get_deepest_traceable_frame(coro) -> Optional[FrameType]:
        """Get the deepest frame of CORO that we care to trace.
        This is possible because each coroutine keeps a reference to the
        coroutine it is waiting on.
        Note that it cannot be the case that a task is suspended in a frame
        that does not belong to a coroutine; asyncio is very particular about
        that! This is also why we only track idle tasks this way."""
        curr = coro
        deepest_frame = None
        while curr:
            frame = getattr(curr, 'cr_frame', None)

            if not frame:
                curr = ScaleneAsyncio._search_awaitable(curr)
                if isinstance(curr, AsyncGeneratorType):
                    frame = getattr(curr, 'ag_frame', None)
                else:
                    break

            # An async generator may have no frame; guard before tracing.
            if frame and ScaleneAsyncio.should_trace(
                frame.f_code.co_filename, frame.f_code.co_name
            ):
                deepest_frame = frame

            if isinstance(curr, AsyncGeneratorType):
                curr = getattr(curr, 'ag_await', None)
            else:
                curr = getattr(curr, 'cr_await', None)

            # If this task is found to point to another task we're profiling,
            # then we will get the deepest frame later and should return
            # nothing here. This is specific to gathered futures, i.e., a
            # gather statement.
            if isinstance(curr, asyncio.Future):
                tasks = getattr(curr, '_children', [])
                if any(
                    ScaleneAsyncio._should_trace_task(task)
                    for task in tasks
                ):
                    return None

        return deepest_frame

    @staticmethod
    def _search_awaitable(awaitable):
        """Given an awaitable which is not a coroutine, assume it is a future
        and attempt to find references to further futures or async generators.
        """
        future = None
        if not isinstance(awaitable, asyncio.Future):
            # TODO: some wrappers such as _asyncio.FutureIter and
            # async_generator_asend get caught here; it is not clear whether
            # a more robust approach is necessary.

            # Can gc be avoided here?
            refs = gc.get_referents(awaitable)
            if refs:
                future = refs[0]

        return future

    @staticmethod
    def _should_trace_task(task) -> bool:
        """Return False if TASK is uninteresting to the user.
        A task is interesting if it is not the current task, if it has
        actually started executing, and if no child task originated from it.
        """
        if not isinstance(task, asyncio.Task):
            return False

        # The task is not idle.
        if task == ScaleneAsyncio.current_task:
            return False

        coro = task.get_coro()

        # The task hasn't even run yet.
        # This assumes that all started tasks are sitting at an await
        # statement; if that weren't the case, the associated coroutine
        # would be 'waiting' on the coroutine declaration. No! Bad!
        if getattr(coro, 'cr_frame', None) is None or \
           getattr(coro, 'cr_await', None) is None:
            return False

        frame = getattr(coro, 'cr_frame', None)

        return ScaleneAsyncio.should_trace(frame.f_code.co_filename,
                                           frame.f_code.co_name)
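Reviewer note: a runnable sketch of the `_run_once` trick that `_walk_back_until_loop` relies on. A running loop keeps `BaseEventLoop._run_once` on its thread's stack, so walking `f_back` from the thread's current frame and inspecting `f_locals['self']` recovers the loop object. The helper name `find_loop` and the sleep timings below are illustrative:

```python
import asyncio
import sys
import threading
import time


def find_loop(tident):
    # Walk back the thread's stack until we hit _run_once, whose 'self' is
    # the running event loop -- the same trick as _walk_back_until_loop.
    frame = sys._current_frames().get(tident)
    while frame is not None:
        if frame.f_code.co_name == '_run_once' and 'self' in frame.f_locals:
            candidate = frame.f_locals['self']
            if isinstance(candidate, asyncio.AbstractEventLoop):
                return candidate
        frame = frame.f_back
    return None


t = threading.Thread(target=lambda: asyncio.run(asyncio.sleep(0.5)))
t.start()
time.sleep(0.1)  # give the loop a moment to start running
print(find_loop(t.ident))  # e.g. <_UnixSelectorEventLoop running=True ...>
t.join()
```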
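And a sketch of the `cr_await` walk behind `_get_deepest_traceable_frame`: a suspended coroutine keeps a reference to whatever it is awaiting, so following `cr_await` from a task's outermost coroutine reaches the frame where the task is actually parked. The coroutine names here are made up for the demo:

```python
import asyncio


async def inner():
    await asyncio.sleep(0.1)


async def outer():
    await inner()


async def main():
    task = asyncio.ensure_future(outer())
    await asyncio.sleep(0)  # yield once so the task runs until it suspends

    # Follow the await chain: outer -> inner -> sleep. The walk stops at the
    # first awaitable that has no coroutine frame (here, sleep's future).
    curr = task.get_coro()
    while curr is not None and getattr(curr, 'cr_frame', None) is not None:
        frame = curr.cr_frame
        print(f"suspended in {frame.f_code.co_name} at line {frame.f_lineno}")
        curr = getattr(curr, 'cr_await', None)

    task.cancel()


asyncio.run(main())
```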
27 changes: 16 additions & 11 deletions scalene/scalene_profiler.py
@@ -77,6 +77,7 @@

import scalene.scalene_config
from scalene.scalene_arguments import ScaleneArguments
+from scalene.scalene_asyncio import ScaleneAsyncio
from scalene.scalene_client_timer import ScaleneClientTimer
from scalene.scalene_funcutils import ScaleneFuncUtils
from scalene.scalene_json import ScaleneJSON
@@ -163,7 +164,7 @@ def enable_profiling() -> Generator[None, None, None]:
    yield
    stop()


class Scalene:
    """The Scalene profiler itself."""

@@ -512,13 +513,13 @@ def malloc_signal_handler(
        ):
            Scalene.update_profiled()
        pywhere.set_last_profiled_invalidated_false()
        # In the setprofile callback, we rely on
        # __last_profiled always having the same memory address.
        # This is an optimization to not have to traverse the Scalene profiler
        # object's dictionary every time we want to update the last profiled line.
        #
        # A previous change to this code set Scalene.__last_profiled = [fname, lineno, lasti],
        # which created a new list object and set the __last_profiled attribute to the new list. This
        # made the object held in `pywhere.cpp` out of date, and caused the profiler to not update the last profiled line.
        Scalene.__last_profiled[:] = [
            Filename(f.f_code.co_filename),
@@ -756,6 +757,7 @@ def cpu_signal_handler(
        Scalene.process_cpu_sample(
            signum,
            Scalene.compute_frames_to_record(),
+            ScaleneAsyncio.compute_suspended_frames_to_record(Scalene.should_trace),
            now,
            gpu_load,
            gpu_mem_used,
@@ -906,6 +908,7 @@ def process_cpu_sample(
            None,
        ],
        new_frames: List[Tuple[FrameType, int, FrameType]],
+        async_frames: List[Tuple[FrameType, int, FrameType]],
        now: TimeInfo,
        gpu_load: float,
        gpu_mem_used: float,
@@ -1042,7 +1045,7 @@ def process_cpu_sample(
            Scalene.__stats.gpu_stats.gpu_mem_samples[fname][lineno].push(gpu_mem_used)

        # Now handle the rest of the threads.
-        for frame, tident, orig_frame in new_frames:
+        for frame, tident, orig_frame in new_frames + async_frames:
            if frame == main_thread_frame:
                continue
            add_stack(
@@ -1068,10 +1071,12 @@ def process_cpu_sample(
                # Ignore sleeping threads.
                continue
            # Check if the original caller is stuck inside a call.
-            if ScaleneFuncUtils.is_call_function(
-                orig_frame.f_code,
-                ByteCodeIndex(orig_frame.f_lasti),
-            ):
+            # TODO
+            if orig_frame is None or \
+               ScaleneFuncUtils.is_call_function(
+                   orig_frame.f_code,
+                   ByteCodeIndex(orig_frame.f_lasti),
+               ):
                # It is. Attribute time to native.
                Scalene.__stats.cpu_stats.cpu_samples_c[fname][lineno] += normalized_time
            else:
@@ -1225,7 +1230,7 @@ def alloc_sigqueue_processor(x: Optional[List[int]]) -> None:
        freed_last_trigger = 0
        for item in arr:
            is_malloc = item.action == Scalene.MALLOC_ACTION
            if item.count == scalene.scalene_config.NEWLINE_TRIGGER_LENGTH + 1:
                continue  # in previous implementations, we were adding NEWLINE to the footprint.
                # We should not account for this in the user-facing profile.
            count = item.count / Scalene.BYTES_PER_MB
@@ -1441,7 +1446,7 @@ def memcpy_sigqueue_processor(
                lineno=LineNumber(int(lineno)),
                bytecode_index=ByteCodeIndex(int(bytei)))
            arr.append(memcpy_profiling_sample)

        arr.sort()

        for item in arr: