import datetime
import random
import inspect
import logging
import sys
import traceback
import typing
import os
from contextlib import ExitStack, redirect_stderr, redirect_stdout
from pathlib import Path
import yaml
from ._stream_utils import NotSavingIO, PrintingStringIO, StreamWithTime
from .benchmark_db import BenchmarkDb
from .db.json_serializer import to_json
from .fingerprint import fingerprint
from .log_capture import JsonLogCapture, JsonLogHandler
from .utils import Timer
# Benchmark: runs, saves, and loads benchmark entries (see class docstring).
class Benchmark:
    """
    This is the heart of the library. It allows you to run, save, and load
    a benchmark.

    The function `add` will run a configuration if it is not
    already in the database. You can also split this into `exists` and
    `run`. This may be advisable if you want to distribute the execution.

    The following functions are thread-safe:

    - exists
    - run
    - add
    - insert
    - front
    - capture_logger
    - unlink_logger
    - __iter__

    Don't call any of the other functions while the benchmark is
    running. It could lead to data loss.
    """

    def __init__(
        self,
        path: str,
        save_output: bool = True,
        hide_output: bool = True,
        save_output_with_time: bool = True,
    ) -> None:
        """
        Just specify the path of where to put the
        database and everything else happens magically.
        Make sure not to use the same path for different
        databases, as they will get mixed.

        :param path: The path to the database.
        :param save_output: If true, all output (stdout and stderr) will be
            saved. If set to false, the output will be discarded. This is
            useful if you have a lot of output and don't want to waste disk
            space. However, you will not be able to see the output of the
            algorithm afterwards. Note that the output can only be saved if
            the code acquires the Python sys.stdout and sys.stderr streams
            during the execution, as the corresponding streams are replaced
            by the benchmark. Normal ``print`` statements do so, but
            ``logging.StreamHandler`` does not. For the latter, use
            ``Benchmark.capture_logger``.
        :param hide_output: If true, all output (stdout and stderr) will be
            hidden. This is useful if you have a lot of output and don't want
            to clutter your console. However, you will not be able to see the
            output of the algorithm while it is running. Code that acquired
            handles to the Python sys.stdout and sys.stderr streams earlier
            will still be able to print to the console, as it circumvents
            the replacement.
        :param save_output_with_time: If true, all output (stdout and stderr)
            will be saved with the time it was written. This gives you more
            insights on the runtime of the algorithm, but also increases the
            size of the database. This option is ignored if `save_output` is
            set to false.
        """
        self._db = BenchmarkDb(path)
        self._save_output = save_output
        self._hide_output = hide_output
        self._save_output_with_time = save_output_with_time
        # Maps logger name -> level to capture with during `run`.
        self._log_captures: typing.Dict[str, int] = {}

    def capture_logger(self, logger_name: str, level: int = logging.NOTSET) -> None:
        """
        Capture the logs of a logger of the Python logging module.
        This allows you to precisely control which logs you want to
        capture. Prefer logging to stdout/stderr, as just using ``print``
        will not allow you to control the output of sub-algorithms.
        The logging module also allows you to search more easily for
        specific log entries, if used correctly. However, it is
        more expensive than just using ``print`` as more metadata
        is created. Don't overuse it but only log important events
        in the algorithm.

        Registering the same logger name again replaces its level.

        :param logger_name: The name of the logger to capture.
        :param level: The level of the logger to capture. The logger
            will automatically be set to this level while capturing, but
            will be reset afterwards. NOTSET will not change the level.
        :return: None
        """
        self._log_captures[logger_name] = level

    def unlink_logger(self, logger_name: str) -> None:
        """
        Stop capturing the logs of a logger of the Python logging module
        while the benchmark is running.

        Unlinking a logger that is not (or no longer) captured is a no-op.

        :param logger_name: The name of the logger to stop capturing.
        :return: None
        """
        # `pop` with a default instead of `del`: repeated or concurrent
        # unlinking of the same logger must not raise KeyError.
        self._log_captures.pop(logger_name, None)

    def _get_arg_data(self, func, args, kwargs) -> typing.Tuple[str, typing.Dict]:
        """
        Build the identifying data of a call ``func(*args, **kwargs)``.

        The data consists of the function's ``__name__`` and its arguments:
        declared default values, overridden by the explicitly bound
        arguments. Arguments whose name starts with an underscore are
        left out.

        :param func: The callable that would be executed.
        :param args: The positional arguments of the call.
        :param kwargs: The keyword arguments of the call.
        :return: The fingerprint of the data and its JSON representation.
        :raises TypeError: If the arguments do not match the signature
            (raised by ``inspect.Signature.bind``).
        """
        sig = inspect.signature(func)
        # Start from the defaults, so arguments that were not passed
        # explicitly still become part of the identifying data.
        func_args = {
            name: param.default
            for name, param in sig.parameters.items()
            if param.default is not inspect.Parameter.empty
        }
        func_args.update(sig.bind(*args, **kwargs).arguments)
        data = {
            "func": func.__name__,
            "args": {
                key: value
                for key, value in func_args.items()
                if not key.startswith("_")
            },
        }
        json_data = to_json(data)
        assert isinstance(json_data, dict)
        return fingerprint(data), json_data

    def exists(self, func: typing.Callable, *args, **kwargs) -> bool:
        """
        Use this function to check if an entry already exists and thus
        does not have to be run again. If you want to have multiple
        samples, add a sample index argument.

        Caveat: This function may have false negatives, i.e., say that an
        entry does not exist despite it existing (only for fresh data).
        """
        fingp, _ = self._get_arg_data(func, args, kwargs)
        return self._db.contains_fingerprint(fingp)

    def _get_stream_obj(self, forward_stream):
        """
        Create the replacement object for stdout/stderr during `run`.

        :param forward_stream: The stream handed to the wrapper, or None
            (`run` passes None when the output should be hidden).
        :return: A stream object offering ``getvalue``; its concrete type
            depends on `save_output` and `save_output_with_time`.
        """
        if not self._save_output:
            # This wrapper just adds a ``getvalue`` method to the stream,
            # so it can be used drop-in for StringIO.
            return NotSavingIO(forward_stream)
        if self._save_output_with_time:
            # StreamWithTime is a wrapper around StringIO.
            # It stores the time of each line.
            # getvalue() returns a list of tuples (time, line).
            return StreamWithTime(forward_stream)
        return PrintingStringIO(forward_stream)

    def run(self, func: typing.Callable, *args, **kwargs):
        """
        Will run the function call with the arguments and add it
        to the benchmark, regardless of whether it is already contained.

        While the function runs, stdout and stderr are replaced. Whether
        the output is stored and whether it is additionally printed to
        the console depends on `save_output` and `hide_output`. The
        loggers registered via `capture_logger` are captured, too.

        If anything fails, a report with the arguments and the traceback
        is printed and the exception is re-raised.
        """
        fingp, arg_data = self._get_arg_data(func, args, kwargs)
        try:
            stdout = self._get_stream_obj(None if self._hide_output else sys.stdout)
            stderr = self._get_stream_obj(None if self._hide_output else sys.stderr)
            # Iterate a snapshot: `capture_logger`/`unlink_logger` may be
            # called from other threads and must not break the iteration
            # ("dictionary changed size during iteration").
            log_captures = list(self._log_captures.items())
            with ExitStack() as logging_stack:
                log_handler = JsonLogHandler()
                for logger_name, level in log_captures:
                    logging_stack.enter_context(
                        JsonLogCapture(logger_name, level, log_handler)
                    )
                with redirect_stdout(stdout), redirect_stderr(stderr):
                    timestamp = datetime.datetime.now().isoformat()
                    timer = Timer()
                    result = func(*args, **kwargs)
                    runtime = timer.time()
            self._db.add(
                arg_fingerprint=fingp,
                arg_data=arg_data,
                result={
                    "result": result,
                    "timestamp": timestamp,
                    "runtime": runtime,
                    "stdout": stdout.getvalue(),
                    "stderr": stderr.getvalue(),
                    "logging": log_handler.get_entries(),
                },
            )
            print(".", end="")  # flake8: noqa T201
        except Exception as e:
            print()  # flake8: noqa T201
            print("Exception while running benchmark.")  # flake8: noqa T201
            print("=====================================")  # flake8: noqa T201
            print(yaml.dump(arg_data))  # flake8: noqa T201
            print("-------------------------------------")  # flake8: noqa T201
            print("ERROR:", e, f"({type(e)})")  # flake8: noqa T201
            print(traceback.format_exc())  # flake8: noqa T201
            print("-------------------------------------")  # flake8: noqa T201
            raise

    def add(self, func: typing.Callable, *args, **kwargs):
        """
        Will add the function call with the arguments
        to the benchmark if not yet contained.

        Combination of `exists` and `run`.
        Will only call `run` if the arguments are not
        yet in the benchmark.
        """
        if not self.exists(func, *args, **kwargs):
            self.run(func, *args, **kwargs)

    def insert(self, entry: typing.Dict):
        """
        Insert a raw entry, as returned by `__iter__` or `front`.
        """
        self._db.insert(entry)

    def compress(self):
        """
        Compress the data of the benchmark to take less disk space.

        NOT THREAD-SAFE!
        """
        self._db.compress()

    def repair(self):
        """
        Repairs the benchmark in case it has some broken entries.

        Implemented by rebuilding the database via `delete_if` with a
        condition that never deletes anything.

        NOT THREAD-SAFE!
        """
        self.delete_if(lambda data: False)

    def __iter__(self) -> typing.Generator[typing.Dict, None, None]:
        """
        Iterate over all entries in the benchmark.
        Use `front` to get a preview on how an entry looks like.

        Every yielded entry is a (shallow) copy of the stored entry.
        """
        for entry in self._db:
            yield entry.copy()

    def delete(self):
        """
        Delete the benchmark and all its files. Do not use it afterwards,
        there are no files left to write results into.
        If you just want to delete the content, use `clear`.

        NOT THREAD-SAFE!
        """
        self._db.delete()

    def front(self) -> typing.Optional[typing.Dict]:
        """
        Return the first entry of the benchmark.
        Useful for checking its content.
        """
        return self._db.front()

    def clear(self):
        """
        Clears all entries of the benchmark, without deleting
        the benchmark itself. You can continue to use it afterwards.

        NOT THREAD-SAFE!
        """
        self._db.clear()

    def delete_if(self, condition: typing.Callable[[typing.Dict], bool]):
        """
        Delete entries if a specific condition is met (return True).
        The database is rebuilt via `apply` for this purpose.
        Use `front` to get a preview on how an entry that is
        passed to the condition looks like.

        NOT THREAD-SAFE!
        """

        def keep_or_drop(entry) -> typing.Optional[typing.Dict]:
            # Returning None tells `apply` to drop the entry.
            return None if condition(entry) else entry

        self.apply(keep_or_drop)

    def apply(self, func: typing.Callable[[typing.Dict], typing.Optional[typing.Dict]]):
        """
        Allows to modify all entries (in place!) inside this benchmark,
        based on the provided callable. It is being called for every
        entry inside the database, and the returned entry will be stored
        instead. If None (or an empty entry) is returned, the provided
        entry will be deleted from the database.

        The existing database is first moved to a sibling directory named
        ``<timestamp>-<counter>``. A fresh database is then created at the
        original path and filled with the returned entries. Finally, the
        moved database is deleted and the new one is compressed.

        NOT THREAD-SAFE, execute this while no other instance is accessing
        the file system.
        """
        old_db = self._db
        original_path = Path(old_db.path)
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
        # The timestamp only has minute resolution; the counter makes the
        # temporary location unique.
        i = 0
        new_path = original_path.parent / f"{timestamp}-{i}"
        while new_path.exists():
            i += 1
            new_path = original_path.parent / f"{timestamp}-{i}"
        old_db.move_database(str(new_path))
        self._db = BenchmarkDb(original_path)
        for entry in old_db:
            new_entry = func(entry)
            if new_entry:
                self.insert(new_entry)
        old_db.delete()
        self.compress()

    def __len__(self) -> int:
        """
        Return the number of fingerprints in the database.
        It is possible that this does not correspond to the
        number of entries. Use `__iter__` to iterate over
        all entries and count them to get the number of entries.
        However, this is not recommended, as it is slow.
        """
        return len(self._db)

    def empty(self) -> bool:
        """
        Return True if the database is empty, False otherwise.
        """
        return len(self) == 0

    def fingerprint(self):
        """
        Returns a fingerprint over all data contained in this benchmark.
        Two fingerprints should be matching exactly if the benchmark contains the same
        data, including timestamps etc., no matter the internal structure like order
        of entries and possible compression.
        """
        # Inside the method body, `fingerprint` resolves to the module-level
        # helper, not to this method. Sorting the per-entry fingerprints
        # makes the result independent of the iteration order.
        return fingerprint(sorted(fingerprint(entry) for entry in self))