Module materialize.cli.mzexplore

Expand source code Browse git
# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.

from pathlib import Path
from typing import Any

import click

import materialize.mzexplore as api
import materialize.mzexplore.common as common

# import logging
# logging.basicConfig(encoding='utf-8', level=logging.DEBUG)

# Click CLI Application
# ---------------------


@click.group()
def app() -> None:
    """Extract, analyze, and clone definitions and EXPLAIN plans of Materialize objects."""
    # Root command group: the `extract`, `analyze`, and `clone` groups attach
    # to it via `@app.group()`. It performs no work of its own; the docstring
    # above gives the top-level CLI a description.


class Arg:
    # Reusable keyword arguments for `click.argument(...)`; each attribute is
    # unpacked with `**` at the call site.

    # A directory path that does not have to exist yet. Click resolves it and
    # the callback hands the command a `pathlib.Path` instead of a `str`.
    repository: dict[str, Any] = {
        "type": click.Path(
            exists=False,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda ctx, param, value: Path(value),  # type: ignore
    }

    # Same as `repository`, but for a single file rather than a directory.
    output_file: dict[str, Any] = {
        "type": click.Path(
            exists=False,
            file_okay=True,
            dir_okay=False,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda ctx, param, value: Path(value),  # type: ignore
    }

    # Free-form suffix strings; the metavars label them in usage output.
    base_suffix: dict[str, Any] = {"type": str, "metavar": "BASE"}

    diff_suffix: dict[str, Any] = {"type": str, "metavar": "DIFF"}


class Opt:
    # Reusable keyword arguments for `click.option(...)`; each attribute is
    # unpacked with `**` at the call site.

    # Database connection settings. Each can also be supplied through the
    # environment variable named under "envvar".
    db_port: dict[str, Any] = {
        "default": 6877,
        "help": "DB connection port.",
        "envvar": "PGPORT",
    }

    db_host: dict[str, Any] = {
        "default": "localhost",
        "help": "DB connection host.",
        "envvar": "PGHOST",
    }

    db_user: dict[str, Any] = {
        "default": "mz_support",
        "help": "DB connection user.",
        "envvar": "PGUSER",
    }

    db_pass: dict[str, Any] = {
        "default": None,
        "help": "DB connection password.",
        "envvar": "PGPASSWORD",
    }

    db_require_ssl: dict[str, Any] = {
        "is_flag": True,
        "help": "DB connection requires SSL.",
        "envvar": "PGREQUIRESSL",
    }

    # Used with "--mzfmt/--no-mzfmt"; formatting stays on unless disabled.
    mzfmt: dict[str, Any] = {
        "default": True,
        "help": "Format SQL statements with `mzfmt` if present.",
    }

    # EXPLAIN settings. Choices are the lower-cased enum member names, and the
    # callbacks map what the user typed back to the enum member.
    explainee_type: dict[str, Any] = {
        "type": click.Choice([member.name.lower() for member in api.ExplaineeType]),
        "default": "catalog_item",
        "callback": lambda ctx, param, v: api.ExplaineeType[v.upper()],  # type: ignore
        "help": "EXPLAIN mode.",
        "metavar": "MODE",
    }

    explain_options: dict[str, Any] = {
        "type": api.ExplainOptionType(),
        "multiple": True,
        "help": "WITH key=val pairs to pass to the EXPLAIN command.",
        "metavar": "KEY=VAL",
    }

    explain_stage: dict[str, Any] = {
        "type": click.Choice([member.name.lower() for member in api.ExplainStage]),
        "multiple": True,
        # Most of the time only the optimized plan is needed.
        "default": ["optimized_plan"],
        # Repeatable option: convert every selected stage, preserving order.
        "callback": lambda ctx, param, vals: [api.ExplainStage[v.upper()] for v in vals],  # type: ignore
        "help": "EXPLAIN stage to export.",
        "metavar": "STAGE",
    }

    explain_suffix: dict[str, Any] = {
        "type": str,
        "default": "",
        "help": "A suffix to append to the EXPLAIN output files.",
        "metavar": "SUFFIX",
    }

    explain_format: dict[str, Any] = {
        "type": click.Choice([member.name.lower() for member in api.ExplainFormat]),
        "default": "text",
        "callback": lambda ctx, param, v: api.ExplainFormat[v.upper()],  # type: ignore
        "help": "AS [FORMAT] clause to pass to the EXPLAIN command.",
        "metavar": "FORMAT",
    }

    # Used with "--system/--user"; False (user objects) unless --system.
    system: dict[str, Any] = {
        "is_flag": True,
        "show_default": True,
        "default": False,
        "help": "Inspect system or user tables.",
    }


def is_documented_by(original: Any) -> Any:
    """Build a decorator that copies ``original.__doc__`` onto its target.

    In this module it gives each CLI command or group the docstring of the
    corresponding ``api`` object.

    Args:
        original: Any object whose ``__doc__`` should be reused.

    Returns:
        A decorator that sets ``target.__doc__`` and returns ``target``
        itself, unwrapped.
    """

    def _copy_doc(target: Any) -> Any:
        # Mutate in place and return the same object, so the decorator can sit
        # anywhere in a decorator stack.
        target.__doc__ = original.__doc__
        return target

    return _copy_doc


@app.group()
@is_documented_by(api.extract)
def extract() -> None:
    # Container for the `extract ...` subcommands; it does no work itself.
    # Its help text is copied from `api.extract` by `is_documented_by`.
    ...


@extract.command(name="defs")
@click.argument("target", **Arg.repository)
@click.argument("database", type=str)
@click.argument("schema", type=str)
@click.argument("name", type=str)
@click.option("--db-port", **Opt.db_port)
@click.option("--db-host", **Opt.db_host)
@click.option("--db-user", **Opt.db_user)
@click.option("--db-pass", **Opt.db_pass)
@click.option("--db-require-ssl", **Opt.db_require_ssl)
@click.option("--mzfmt/--no-mzfmt", **Opt.mzfmt)
@is_documented_by(api.extract.defs)
def extract_defs(
    target: Path,
    database: str,
    schema: str,
    name: str,
    db_port: int,
    db_host: str,
    db_user: str,
    db_pass: str | None,
    db_require_ssl: bool,
    mzfmt: bool,
) -> None:
    # CLI wrapper around `api.extract.defs`: every argument is forwarded
    # unchanged. No docstring here, because `is_documented_by` replaces it
    # with `api.extract.defs.__doc__`.
    try:
        api.extract.defs(
            target=target,
            database=database,
            schema=schema,
            name=name,
            db_port=db_port,
            db_host=db_host,
            db_user=db_user,
            db_pass=db_pass,
            db_require_ssl=db_require_ssl,
            mzfmt=mzfmt,
        )
    except Exception as e:
        import traceback

        # The ClickException below carries only a one-line message, so dump the
        # stack of the original failure first.
        traceback.print_tb(e.__traceback__)
        msg = f"extract defs command failed: {e=}, {type(e)=}"
        # Chain the original error so it stays reachable as `__cause__`.
        raise click.ClickException(msg) from e


@extract.command(name="plans")
@click.argument("target", **Arg.repository)
@click.argument("database", type=str)
@click.argument("schema", type=str)
@click.argument("name", type=str)
@click.option("--db-port", **Opt.db_port)
@click.option("--db-host", **Opt.db_host)
@click.option("--db-user", **Opt.db_user)
@click.option("--db-pass", **Opt.db_pass)
@click.option("--db-require-ssl", **Opt.db_require_ssl)
@click.option("--explainee-type", "-t", **Opt.explainee_type)
@click.option("--with", "-w", "explain_options", **Opt.explain_options)
@click.option("--stage", "-s", "explain_stages", **Opt.explain_stage)
@click.option("--format", "-f", "explain_format", **Opt.explain_format)
@click.option("--system/--user", "system", **Opt.system)
@click.option("--suffix", **Opt.explain_suffix)
@is_documented_by(api.extract.plans)
def extract_plans(
    target: Path,
    database: str,
    schema: str,
    name: str,
    db_port: int,
    db_host: str,
    db_user: str,
    db_pass: str | None,
    db_require_ssl: bool,
    explainee_type: api.ExplaineeType,
    explain_options: list[api.ExplainOption],
    # NOTE(review): the `--stage` callback in `Opt.explain_stage` builds a
    # list, not a set; confirm which one `api.extract.plans` expects.
    explain_stages: set[api.ExplainStage],
    explain_format: api.ExplainFormat,
    system: bool,
    suffix: str | None = None,
) -> None:
    # CLI wrapper around `api.extract.plans`: every argument is forwarded
    # unchanged. The enum-valued arguments were already converted by the
    # option callbacks declared on `Opt`. No docstring here, because
    # `is_documented_by` replaces it with `api.extract.plans.__doc__`.
    try:
        api.extract.plans(
            target=target,
            database=database,
            schema=schema,
            name=name,
            db_port=db_port,
            db_host=db_host,
            db_user=db_user,
            db_pass=db_pass,
            db_require_ssl=db_require_ssl,
            explainee_type=explainee_type,
            explain_options=explain_options,
            explain_stages=explain_stages,
            explain_format=explain_format,
            system=system,
            suffix=suffix,
        )
    except Exception as e:
        import traceback

        # The ClickException below carries only a one-line message, so dump the
        # stack of the original failure first.
        traceback.print_tb(e.__traceback__)
        msg = f"extract plans command failed: {e=}, {type(e)=}"
        # Chain the original error so it stays reachable as `__cause__`.
        raise click.ClickException(msg) from e


@extract.command(name="arrangement-sizes")
@click.argument("target", **Arg.repository)
@click.argument("cluster", type=str)
@click.argument("cluster_replica", type=str)
@click.argument("database", type=str)
@click.argument("schema", type=str)
@click.argument("name", type=str)
@click.option("--db-port", **Opt.db_port)
@click.option("--db-host", **Opt.db_host)
@click.option("--db-user", **Opt.db_user)
@click.option("--db-pass", **Opt.db_pass)
@click.option("--db-require-ssl", **Opt.db_require_ssl)
@click.option("--print-results", is_flag=True, default=False)
@is_documented_by(api.extract.arrangement_sizes)
def extract_arrangement_sizes(
    target: Path,
    cluster: str,
    cluster_replica: str,
    database: str,
    schema: str,
    name: str,
    db_port: int,
    db_host: str,
    db_user: str,
    db_pass: str | None,
    db_require_ssl: bool,
    print_results: bool,
) -> None:
    # CLI wrapper around `api.extract.arrangement_sizes`: every argument is
    # forwarded unchanged. No docstring here, because `is_documented_by`
    # replaces it with the API function's docstring.
    try:
        api.extract.arrangement_sizes(
            target=target,
            cluster=cluster,
            cluster_replica=cluster_replica,
            database=database,
            schema=schema,
            name=name,
            db_port=db_port,
            db_host=db_host,
            db_user=db_user,
            db_pass=db_pass,
            db_require_ssl=db_require_ssl,
            print_results=print_results,
        )
    except Exception as e:
        import traceback

        # The ClickException below carries only a one-line message, so dump the
        # stack of the original failure first.
        traceback.print_tb(e.__traceback__)
        msg = f"extract arrangement-sizes command failed: {e=}, {type(e)=}"
        # Chain the original error so it stays reachable as `__cause__`.
        raise click.ClickException(msg) from e


@app.group()
@is_documented_by(api.analyze)
def analyze() -> None:
    # Container for the `analyze ...` subcommands; it does no work itself.
    # Its help text is copied from `api.analyze` by `is_documented_by`.
    ...


@analyze.command(name="changes")
@click.argument("target", **Arg.repository)  # type: ignore
@click.argument("summary_file", **Arg.output_file)  # type: ignore
@click.argument("base_suffix", **Arg.base_suffix)
@click.argument("diff_suffix", **Arg.diff_suffix)
@is_documented_by(api.analyze.changes)
def analyze_changes(
    target: Path,
    summary_file: Path,
    base_suffix: str,
    diff_suffix: str,
) -> None:
    # Prepare an analysis report of plan changes as a Markdown document.
    #
    # This is a comment rather than a docstring: `is_documented_by` replaces
    # this function's docstring with `api.analyze.changes.__doc__`.
    #
    # The report is appended to SUMMARY_FILE (opened in "a+" mode), and
    # str(target) is passed as the header name.
    try:
        with summary_file.open("a+", encoding="utf-8") as report:
            api.analyze.changes(
                out=report,
                target=target,
                header_name=str(target),
                base_suffix=base_suffix,
                diff_suffix=diff_suffix,
            )
            common.info(f"Summary written to {summary_file}")
    except Exception as e:
        import traceback

        traceback.print_tb(e.__traceback__)
        raise click.ClickException(
            f"analyze changes command failed: {e=}, {type(e)=}"
        ) from e


@app.group()
@is_documented_by(api.clone)
def clone() -> None:
    # Container for the `clone ...` subcommands; it does no work itself.
    # Its help text is copied from `api.clone` by `is_documented_by`.
    ...


@clone.command(name="defs")
@click.argument("database", type=str)
@click.argument("schema", type=str)
@click.argument("cluster", type=str)
@click.argument("object_ids", type=str, nargs=-1)
@click.argument("ddl_file", **Arg.output_file)  # type: ignore
@click.option("--db-port", **Opt.db_port)
@click.option("--db-host", **Opt.db_host)
@click.option("--db-user", **Opt.db_user)
@click.option("--db-pass", **Opt.db_pass)
@click.option("--db-require-ssl", **Opt.db_require_ssl)
@click.option("--mzfmt/--no-mzfmt", **Opt.mzfmt)
@is_documented_by(api.clone.defs)
def clone_defs(
    database: str,
    schema: str,
    cluster: str,
    object_ids: list[str],
    ddl_file: Path,
    db_port: int,
    db_host: str,
    db_user: str,
    db_pass: str | None,
    db_require_ssl: bool,
    mzfmt: bool,
) -> None:
    # `api.clone.defs` writes two outputs:
    #   - DDL_FILE: the modified DDL;
    #   - "__compare__<DDL_FILE name>" next to it: the original DDL.
    # Both files are truncated ("w"), and the user is asked to diff them.
    try:
        cmp_file = ddl_file.parent / f"__compare__{ddl_file.name}"
        with (
            ddl_file.open("w", encoding="utf-8") as ddl_out,
            cmp_file.open("w", encoding="utf-8") as cmp_out,
        ):
            api.clone.defs(
                ddl_out=ddl_out,
                cmp_out=cmp_out,
                database=database,
                schema=schema,
                cluster=cluster,
                object_ids=object_ids,
                db_port=db_port,
                db_host=db_host,
                db_user=db_user,
                db_pass=db_pass,
                db_require_ssl=db_require_ssl,
                mzfmt=mzfmt,
            )
            common.info(f"Modified DDL written to {ddl_file}")
            common.info(f"Original DDL written to {cmp_file}")
            common.warn("Please inspect the diff between the two!!!")
    except Exception as e:
        import traceback

        traceback.print_tb(e.__traceback__)
        raise click.ClickException(
            f"clone defs command failed: {e=}, {type(e)=}"
        ) from e


# Entrypoint
# ----------

if __name__ == "__main__":
    # Run the `app` click group when this module is executed directly.
    app()

Functions

def is_documented_by(original: Any) -> Any
Expand source code Browse git
def is_documented_by(original: Any) -> Any:
    """Build a decorator that copies ``original.__doc__`` onto its target.

    In this module it gives each CLI command or group the docstring of the
    corresponding ``api`` object.

    Args:
        original: Any object whose ``__doc__`` should be reused.

    Returns:
        A decorator that sets ``target.__doc__`` and returns ``target``
        itself, unwrapped.
    """

    def _copy_doc(target: Any) -> Any:
        # Mutate in place and return the same object, so the decorator can sit
        # anywhere in a decorator stack.
        target.__doc__ = original.__doc__
        return target

    return _copy_doc

Classes

class Arg
Expand source code Browse git
class Arg:
    # Reusable keyword arguments for `click.argument(...)`; each attribute is
    # unpacked with `**` at the call site.

    # A directory path that does not have to exist yet. Click resolves it and
    # the callback hands the command a `pathlib.Path` instead of a `str`.
    repository: dict[str, Any] = {
        "type": click.Path(
            exists=False,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda ctx, param, value: Path(value),  # type: ignore
    }

    # Same as `repository`, but for a single file rather than a directory.
    output_file: dict[str, Any] = {
        "type": click.Path(
            exists=False,
            file_okay=True,
            dir_okay=False,
            writable=True,
            readable=True,
            resolve_path=True,
        ),
        "callback": lambda ctx, param, value: Path(value),  # type: ignore
    }

    # Free-form suffix strings; the metavars label them in usage output.
    base_suffix: dict[str, Any] = {"type": str, "metavar": "BASE"}

    diff_suffix: dict[str, Any] = {"type": str, "metavar": "DIFF"}

Class variables

var base_suffix : dict[str, typing.Any]
var diff_suffix : dict[str, typing.Any]
var output_file : dict[str, typing.Any]
var repository : dict[str, typing.Any]
class Opt
Expand source code Browse git
class Opt:
    # Reusable keyword arguments for `click.option(...)`; each attribute is
    # unpacked with `**` at the call site.

    # Database connection settings. Each can also be supplied through the
    # environment variable named under "envvar".
    db_port: dict[str, Any] = {
        "default": 6877,
        "help": "DB connection port.",
        "envvar": "PGPORT",
    }

    db_host: dict[str, Any] = {
        "default": "localhost",
        "help": "DB connection host.",
        "envvar": "PGHOST",
    }

    db_user: dict[str, Any] = {
        "default": "mz_support",
        "help": "DB connection user.",
        "envvar": "PGUSER",
    }

    db_pass: dict[str, Any] = {
        "default": None,
        "help": "DB connection password.",
        "envvar": "PGPASSWORD",
    }

    db_require_ssl: dict[str, Any] = {
        "is_flag": True,
        "help": "DB connection requires SSL.",
        "envvar": "PGREQUIRESSL",
    }

    # Used with "--mzfmt/--no-mzfmt"; formatting stays on unless disabled.
    mzfmt: dict[str, Any] = {
        "default": True,
        "help": "Format SQL statements with `mzfmt` if present.",
    }

    # EXPLAIN settings. Choices are the lower-cased enum member names, and the
    # callbacks map what the user typed back to the enum member.
    explainee_type: dict[str, Any] = {
        "type": click.Choice([member.name.lower() for member in api.ExplaineeType]),
        "default": "catalog_item",
        "callback": lambda ctx, param, v: api.ExplaineeType[v.upper()],  # type: ignore
        "help": "EXPLAIN mode.",
        "metavar": "MODE",
    }

    explain_options: dict[str, Any] = {
        "type": api.ExplainOptionType(),
        "multiple": True,
        "help": "WITH key=val pairs to pass to the EXPLAIN command.",
        "metavar": "KEY=VAL",
    }

    explain_stage: dict[str, Any] = {
        "type": click.Choice([member.name.lower() for member in api.ExplainStage]),
        "multiple": True,
        # Most of the time only the optimized plan is needed.
        "default": ["optimized_plan"],
        # Repeatable option: convert every selected stage, preserving order.
        "callback": lambda ctx, param, vals: [api.ExplainStage[v.upper()] for v in vals],  # type: ignore
        "help": "EXPLAIN stage to export.",
        "metavar": "STAGE",
    }

    explain_suffix: dict[str, Any] = {
        "type": str,
        "default": "",
        "help": "A suffix to append to the EXPLAIN output files.",
        "metavar": "SUFFIX",
    }

    explain_format: dict[str, Any] = {
        "type": click.Choice([member.name.lower() for member in api.ExplainFormat]),
        "default": "text",
        "callback": lambda ctx, param, v: api.ExplainFormat[v.upper()],  # type: ignore
        "help": "AS [FORMAT] clause to pass to the EXPLAIN command.",
        "metavar": "FORMAT",
    }

    # Used with "--system/--user"; False (user objects) unless --system.
    system: dict[str, Any] = {
        "is_flag": True,
        "show_default": True,
        "default": False,
        "help": "Inspect system or user tables.",
    }

Class variables

var db_host : dict[str, typing.Any]
var db_pass : dict[str, typing.Any]
var db_port : dict[str, typing.Any]
var db_require_ssl : dict[str, typing.Any]
var db_user : dict[str, typing.Any]
var explain_format : dict[str, typing.Any]
var explain_options : dict[str, typing.Any]
var explain_stage : dict[str, typing.Any]
var explain_suffix : dict[str, typing.Any]
var explainee_type : dict[str, typing.Any]
var mzfmt : dict[str, typing.Any]
var system : dict[str, typing.Any]