Module materialize.cli.optbench

Expand source code Browse git
# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.

import csv
import re
import tempfile
from contextlib import closing
from pathlib import Path
from typing import cast

import click
import numpy as np
import pandas as pd

from ..optbench import Scenario, scenarios, sql, util

# import logging
# logging.basicConfig(encoding='utf-8', level=logging.DEBUG)

# Click CLI Application
# ---------------------


# Root command group: the subcommands of this module attach to it through
# `@app.command()`. It carries no docstring on purpose, since click would
# surface a docstring as the group's help text.
@click.group()
def app() -> None:
    return None


class Arg:
    # Keyword-argument bundles for positional CLI parameters. Each one is
    # meant to be splatted into `click.argument(<name>, **Arg.<name>)`.

    # Scenario name: restricted to the values returned by `scenarios()`,
    # then wrapped in a `Scenario` by the callback.
    scenario = {
        "type": click.Choice(scenarios()),
        "callback": lambda _ctx, _param, raw: Scenario(raw),
    }

    # Baseline results file: must be an existing, readable regular file;
    # click resolves the path and the callback converts it to a `Path`.
    base = {
        "type": click.Path(
            exists=True,
            file_okay=True,
            dir_okay=False,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda _ctx, _param, raw: Path(raw),
    }

    # Results file to compare against the baseline: same constraints and
    # conversion as `base`.
    diff = {
        "type": click.Path(
            exists=True,
            file_okay=True,
            dir_okay=False,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda _ctx, _param, raw: Path(raw),
    }


class Opt:
    # Keyword-argument bundles for named CLI options. Each one is meant to
    # be splatted into `click.option("--<name>", **Opt.<name>)`.

    # Results folder: defaults to the system temp directory and must be an
    # existing directory that is both readable and writable; the callback
    # converts the resolved value to a `Path`.
    repository = {
        "default": Path(tempfile.gettempdir()),
        "type": click.Path(
            exists=True,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
        "help": "Experiment results folder.",
        "callback": lambda _ctx, _param, raw: Path(raw),
    }

    samples = {"default": 11, "help": "Samples per query."}

    # NOTE(review): neither of the next two options sets `is_flag=True`, so
    # click presumably parses them as value-taking booleans (for example
    # `--print-results true`) rather than bare switches -- confirm intended.
    print_results = {"default": False, "help": "Print the experiment results."}

    no_indexes = {
        "default": False,
        "help": "Skip CREATE [DEFAULT]/DROP INDEX DDL.",
    }

    # Connection settings; each may also be supplied through the named
    # environment variable.
    db_port = {"default": 6875, "help": "DB connection port.", "envvar": "PGPORT"}

    db_host = {
        "default": "localhost",
        "help": "DB connection host.",
        "envvar": "PGHOST",
    }

    db_user = {
        "default": "materialize",
        "help": "DB connection user.",
        "envvar": "PGUSER",
    }

    db_pass = {
        "default": None,
        "help": "DB connection password.",
        "envvar": "PGPASSWORD",
    }

    db_require_ssl = {
        "is_flag": True,
        "help": "DB connection requires SSL.",
        "envvar": "PGREQUIRESSL",
    }


@app.command()
@click.argument("scenario", **Arg.scenario)
@click.option("--no-indexes", **Opt.no_indexes)
@click.option("--db-port", **Opt.db_port)
@click.option("--db-host", **Opt.db_host)
@click.option("--db-user", **Opt.db_user)
@click.option("--db-pass", **Opt.db_pass)
@click.option("--db-require-ssl", **Opt.db_require_ssl)
def init(
    scenario: Scenario,
    no_indexes: bool,
    db_port: int,
    db_host: str,
    db_user: str,
    db_pass: str | None,
    db_require_ssl: bool,
) -> None:
    """Initialize the DB under test for the given scenario."""
    # The docstring above doubles as the command's help text, so further
    # documentation lives in comments.
    #
    # Steps: drop and re-create the scenario database, then execute the
    # statements parsed from the scenario's schema file, leaving out the
    # index DDL when `no_indexes` is set.
    #
    # Any failure is re-raised as a `click.ClickException` chained to the
    # original exception.

    info(f'Initializing "{scenario}" as the DB under test')

    # Settings shared by both connections below; only the target database
    # differs between them.
    connection = dict(
        port=db_port,
        host=db_host,
        user=db_user,
        password=db_pass,
        require_ssl=db_require_ssl,
    )

    try:
        # connect to the default database in order to re-create the
        # database for the selected scenario
        with closing(sql.Database(database=None, **connection)) as db:
            db.drop_database(scenario)
            db.create_database(scenario)

        # re-connect to the database for the selected scenario
        with closing(sql.Database(database=str(scenario), **connection)) as db:
            statements = sql.parse_from_file(scenario.schema_path())
            if no_indexes:
                # The leading `\s*` keeps the anchored match from missing
                # index DDL that is preceded by whitespace.
                idx_re = re.compile(
                    r"\s*(create|create\s+default|drop)\s+index\s+"
                )
                statements = [
                    statement
                    for statement in statements
                    if not idx_re.match(statement.lower())
                ]
            db.execute_all(statements)
    except Exception as e:
        # Chain the cause so the original traceback stays reachable.
        raise click.ClickException(f"init command failed: {e}") from e


@app.command()
@click.argument("scenario", **Arg.scenario)
@click.option("--samples", **Opt.samples)
@click.option("--repository", **Opt.repository)
@click.option("--print-results", **Opt.print_results)
@click.option("--db-port", **Opt.db_port)
@click.option("--db-host", **Opt.db_host)
@click.option("--db-user", **Opt.db_user)
@click.option("--db-pass", **Opt.db_pass)
@click.option("--db-require-ssl", **Opt.db_require_ssl)
def run(
    scenario: Scenario,
    samples: int,
    repository: Path,
    print_results: bool,
    db_port: int,
    db_host: str,
    db_user: str,
    db_pass: str | None,
    db_require_ssl: bool,
) -> None:
    """Run benchmark in the DB under test for a given scenario."""
    # The docstring above doubles as the command's help text, so further
    # documentation lives in comments.
    #
    # For each of `samples` rounds, every workload query is explained with
    # timing enabled and its optimization time recorded. The measurements
    # are pivoted into one row per sample and one column per query, then
    # written as CSV to the path given by `util.results_path`.
    #
    # Any failure is re-raised as a `click.ClickException` chained to the
    # original exception.

    info(f'Running "{scenario}" scenario')

    try:
        with closing(
            sql.Database(
                port=db_port,
                host=db_host,
                user=db_user,
                database=str(scenario),
                password=db_pass,
                require_ssl=db_require_ssl,
            )
        ) as db:
            db_version = db.mz_version() or db.version()

            # Parse the workload once up front; it does not change between
            # samples, so there is no need to re-read the file per sample.
            workload = [
                sql.Query(text)
                for text in sql.parse_from_file(scenario.workload_path())
            ]

            df = pd.DataFrame.from_records(
                [
                    (
                        query.name(),
                        sample,
                        cast(
                            np.timedelta64,
                            db.explain(query, timing=True).optimization_time(),
                        ).astype(int),
                    )
                    for sample in range(samples)
                    for query in workload
                ],
                columns=["query", "sample", "data"],
            ).pivot(index="sample", columns="query", values="data")

            if print_results:
                print(df.to_string())

        results_path = util.results_path(repository, scenario, db_version)
        info(f'Writing results to "{results_path}"')
        df.to_csv(results_path, index=False, quoting=csv.QUOTE_MINIMAL)
    except Exception as e:
        # Chain the cause so the original traceback stays reachable.
        raise click.ClickException(f"run command failed: {e}") from e


@app.command()
@click.argument("base", **Arg.base)
@click.argument("diff", **Arg.diff)
def compare(
    base: Path,
    diff: Path,
) -> None:
    """Compare the results of a base and diff benchmark runs."""
    # The docstring above doubles as the command's help text, so further
    # documentation lives in comments.
    #
    # Both CSV files are reduced to their per-column min/median/max, and
    # the diff/base quotient of those aggregates is printed along with the
    # aggregates themselves.
    #
    # Any failure is re-raised as a `click.ClickException` chained to the
    # original exception.

    info(f'Compare experiment results between "{base}" and "{diff}"')

    try:
        # Load and aggregate both result files the same way (base first).
        base_df, diff_df = (
            pd.read_csv(path, quoting=csv.QUOTE_MINIMAL).agg(
                ["min", "median", "max"]
            )
            for path in (base, diff)
        )

        # compute diff/base quotient for all (metric, query) pairs
        quot_df = diff_df / base_df
        # prepend the average quotient across all queries for each metric
        quot_df.insert(0, "Avg", quot_df.mean(axis=1))

        # TODO: use styler to color-code the cells
        print("base times")
        print("----------")
        print(base_df.to_string())
        print("")
        print("diff times")
        print("----------")
        print(diff_df.to_string())
        print("")
        print("diff/base ratio")
        print("---------------")
        print(quot_df.to_string())
    except Exception as e:
        # Chain the cause so the original traceback stays reachable.
        raise click.ClickException(f"compare command failed: {e}") from e


# Utility methods
# ---------------


def print_df(df: pd.DataFrame) -> None:
    """Print *df* with the pandas row and column display limits lifted."""
    # Option name/value pairs in the flat form `option_context` expects;
    # the previous settings come back when the `with` block exits.
    unlimited = ("display.max_rows", None, "display.max_columns", None)
    with pd.option_context(*unlimited):
        print(df)


def info(msg: str, fg: str = "green") -> None:
    """Echo *msg* through click, colored with *fg* (green by default)."""
    click.secho(message=msg, fg=fg)


def err(msg: str, fg: str = "red") -> None:
    """Echo *msg* to standard error, colored with *fg* (red by default)."""
    click.secho(message=msg, err=True, fg=fg)


# Script entry point: invoke the click command group when this module is
# executed directly rather than imported.
if __name__ == "__main__":
    app()

Functions

def err(msg: str, fg: str = 'red') ‑> None
Expand source code Browse git
def err(msg: str, fg: str = "red") -> None:
    """Echo *msg* to standard error, colored with *fg* (red by default)."""
    click.secho(message=msg, err=True, fg=fg)
def info(msg: str, fg: str = 'green') ‑> None
Expand source code Browse git
def info(msg: str, fg: str = "green") -> None:
    """Echo *msg* through click, colored with *fg* (green by default)."""
    click.secho(message=msg, fg=fg)
def print_df(df: pandas.core.frame.DataFrame) ‑> None
Expand source code Browse git
def print_df(df: pd.DataFrame) -> None:
    """Print *df* with the pandas row and column display limits lifted."""
    # Option name/value pairs in the flat form `option_context` expects;
    # the previous settings come back when the `with` block exits.
    unlimited = ("display.max_rows", None, "display.max_columns", None)
    with pd.option_context(*unlimited):
        print(df)

Classes

class Arg
Expand source code Browse git
class Arg:
    # Keyword-argument bundles for positional CLI parameters. Each one is
    # meant to be splatted into `click.argument(<name>, **Arg.<name>)`.

    # Scenario name: restricted to the values returned by `scenarios()`,
    # then wrapped in a `Scenario` by the callback.
    scenario = {
        "type": click.Choice(scenarios()),
        "callback": lambda _ctx, _param, raw: Scenario(raw),
    }

    # Baseline results file: must be an existing, readable regular file;
    # click resolves the path and the callback converts it to a `Path`.
    base = {
        "type": click.Path(
            exists=True,
            file_okay=True,
            dir_okay=False,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda _ctx, _param, raw: Path(raw),
    }

    # Results file to compare against the baseline: same constraints and
    # conversion as `base`.
    diff = {
        "type": click.Path(
            exists=True,
            file_okay=True,
            dir_okay=False,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda _ctx, _param, raw: Path(raw),
    }

Class variables

var base
var diff
var scenario
class Opt
Expand source code Browse git
class Opt:
    # Keyword-argument bundles for named CLI options. Each one is meant to
    # be splatted into `click.option("--<name>", **Opt.<name>)`.

    # Results folder: defaults to the system temp directory and must be an
    # existing directory that is both readable and writable; the callback
    # converts the resolved value to a `Path`.
    repository = {
        "default": Path(tempfile.gettempdir()),
        "type": click.Path(
            exists=True,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
        "help": "Experiment results folder.",
        "callback": lambda _ctx, _param, raw: Path(raw),
    }

    samples = {"default": 11, "help": "Samples per query."}

    # NOTE(review): neither of the next two options sets `is_flag=True`, so
    # click presumably parses them as value-taking booleans (for example
    # `--print-results true`) rather than bare switches -- confirm intended.
    print_results = {"default": False, "help": "Print the experiment results."}

    no_indexes = {
        "default": False,
        "help": "Skip CREATE [DEFAULT]/DROP INDEX DDL.",
    }

    # Connection settings; each may also be supplied through the named
    # environment variable.
    db_port = {"default": 6875, "help": "DB connection port.", "envvar": "PGPORT"}

    db_host = {
        "default": "localhost",
        "help": "DB connection host.",
        "envvar": "PGHOST",
    }

    db_user = {
        "default": "materialize",
        "help": "DB connection user.",
        "envvar": "PGUSER",
    }

    db_pass = {
        "default": None,
        "help": "DB connection password.",
        "envvar": "PGPASSWORD",
    }

    db_require_ssl = {
        "is_flag": True,
        "help": "DB connection requires SSL.",
        "envvar": "PGREQUIRESSL",
    }

Class variables

var db_host
var db_pass
var db_port
var db_require_ssl
var db_user
var no_indexes
var print_results
var repository
var samples