# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import time
from collections import defaultdict
from multiprocessing.pool import Pool
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional
import typer
from rich.columns import Columns
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
import oumi.cli.cli_utils as cli_utils
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
from oumi.utils.git_utils import get_git_root_dir
from oumi.utils.logging import logger
from oumi.utils.version_utils import is_dev_build
if TYPE_CHECKING:
from oumi.core.launcher import BaseCluster, JobStatus
def _get_working_dir(current: Optional[str]) -> Optional[str]:
"""Prompts the user to select the working directory, if relevant."""
if not is_dev_build():
return current
oumi_root = get_git_root_dir()
if current and (not oumi_root or oumi_root == Path(current).resolve()):
return current
use_root = typer.confirm(
"You are using a dev build of oumi. "
f"Use oumi's root directory ({oumi_root}) as your working directory?",
abort=False,
default=True,
)
return str(oumi_root) if use_root else current
def _print_and_wait(
message: str, task: Callable[..., bool], asynchronous=True, **kwargs
) -> None:
"""Prints a message with a loading spinner until the provided task is done."""
with cli_utils.CONSOLE.status(message):
if asynchronous:
with Pool(processes=1) as worker_pool:
task_done = False
while not task_done:
worker_result = worker_pool.apply_async(task, kwds=kwargs)
worker_result.wait()
# Call get() to reraise any exceptions that occurred in the worker.
task_done = worker_result.get()
else:
# Synchronous tasks should be atomic and not block for a significant amount
# of time. If a task is blocking, it should be run asynchronously.
while not task(**kwargs):
sleep_duration = 0.1
time.sleep(sleep_duration)
def _is_job_done(id: str, cloud: str, cluster: str) -> bool:
"""Returns true IFF a job is no longer running."""
from oumi import launcher
running_cloud = launcher.get_cloud(cloud)
running_cluster = running_cloud.get_cluster(cluster)
if not running_cluster:
return True
status = running_cluster.get_job(id)
return status.done
def _cancel_worker(id: str, cloud: str, cluster: str) -> bool:
"""Cancels a job.
All workers must return a boolean to indicate whether the task is done.
Cancel has no intermediate states, so it always returns True.
"""
from oumi import launcher
if not cluster:
return True
if not id:
return True
if not cloud:
return True
launcher.cancel(id, cloud, cluster)
return True # Always return true to indicate that the task is done.
def _tail_logs(
log_stream: io.TextIOBase, job_id: str, output_filepath: Optional[str] = None
) -> None:
"""Tails logs with pretty CLI output.
This function reads from a log stream and displays the output with rich formatting
for CLI users. Optionally saves the output to a file.
Args:
log_stream: A LogStream object that can be read from.
job_id: The ID of the job being tailed.
output_filepath: Optional path to a file to save the logs to.
"""
if output_filepath:
cli_utils.CONSOLE.print(f"Tailed logs will be saved to: {output_filepath}")
else:
cli_utils.CONSOLE.print("Tailing logs to console...")
# Open output file if specified
file_handle = None
if output_filepath:
file_handle = open(output_filepath, "w", encoding="utf-8")
try:
if file_handle:
with cli_utils.CONSOLE.status(f"Tailing logs to {output_filepath}..."):
for line in iter(log_stream.readline, ""):
file_handle.write(line)
file_handle.flush()
else:
for line in iter(log_stream.readline, ""):
# Because Rich is rendering markup/styled text and
# escapes control characters like \r, we need to handle it specially.
if "\r" in line:
cli_utils.CONSOLE.file.write("\r")
cli_utils.CONSOLE.print(line.strip(), end="", markup=False)
cli_utils.CONSOLE.file.flush()
else:
cli_utils.CONSOLE.print(line, end="", markup=False)
except KeyboardInterrupt:
logger.info(f"Stopped tailing logs for job {job_id}")
except Exception as e:
logger.exception(f"Failed while tailing logs for job {job_id}: {e}")
raise
finally:
if file_handle:
file_handle.close()
log_stream.close()
def _down_worker(cluster: str, cloud: Optional[str]) -> bool:
"""Turns down a cluster.
All workers must return a boolean to indicate whether the task is done.
Down has no intermediate states, so it always returns True.
"""
from oumi import launcher
if cloud:
target_cloud = launcher.get_cloud(cloud)
target_cluster = target_cloud.get_cluster(cluster)
if target_cluster:
target_cluster.down()
else:
cli_utils.CONSOLE.print(
f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
)
return True
# Make a best effort to find a single cluster to turn down without a cloud.
clusters = []
for name in launcher.which_clouds():
target_cloud = launcher.get_cloud(name)
target_cluster = target_cloud.get_cluster(cluster)
if target_cluster:
clusters.append(target_cluster)
if len(clusters) == 0:
cli_utils.CONSOLE.print(
f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
)
return True
if len(clusters) == 1:
clusters[0].down()
else:
cli_utils.CONSOLE.print(
f"[red]Multiple clusters found with name [yellow]{cluster}[/yellow]. "
"Specify a cloud to turn down with `--cloud`.[/red]"
)
return True # Always return true to indicate that the task is done.
def _stop_worker(cluster: str, cloud: Optional[str]) -> bool:
"""Stops a cluster.
All workers must return a boolean to indicate whether the task is done.
Stop has no intermediate states, so it always returns True.
"""
from oumi import launcher
if cloud:
target_cloud = launcher.get_cloud(cloud)
target_cluster = target_cloud.get_cluster(cluster)
if target_cluster:
target_cluster.stop()
else:
cli_utils.CONSOLE.print(
f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
)
return True
# Make a best effort to find a single cluster to stop without a cloud.
clusters = []
for name in launcher.which_clouds():
target_cloud = launcher.get_cloud(name)
target_cluster = target_cloud.get_cluster(cluster)
if target_cluster:
clusters.append(target_cluster)
if len(clusters) == 0:
cli_utils.CONSOLE.print(
f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
)
return True
if len(clusters) == 1:
clusters[0].stop()
else:
cli_utils.CONSOLE.print(
f"[red]Multiple clusters found with name [yellow]{cluster}[/yellow]. "
"Specify a cloud to stop with `--cloud`.[/red]"
)
return True # Always return true to indicate that the task is done.
def _poll_job(
job_status: "JobStatus",
detach: bool,
cloud: str,
running_cluster: Optional["BaseCluster"] = None,
output_filepath: Optional[str] = None,
) -> None:
"""Polls a job until it is complete.
If the job is running in detached mode and the job is not on the local cloud,
the function returns immediately.
"""
from oumi import launcher
is_local = cloud == "local"
if detach and not is_local:
cli_utils.CONSOLE.print(
f"Running job [yellow]{job_status.id}[/yellow] in detached mode."
)
return
if detach and is_local:
cli_utils.CONSOLE.print("Cannot detach from jobs in local mode.")
if not running_cluster:
running_cloud = launcher.get_cloud(cloud)
running_cluster = running_cloud.get_cluster(job_status.cluster)
assert running_cluster
try:
log_stream = running_cluster.get_logs_stream(job_status.id, job_status.cluster)
_tail_logs(log_stream, job_status.id, output_filepath)
except NotImplementedError:
if output_filepath:
cli_utils.CONSOLE.print(
"Cluster does not have support for streaming to a file."
)
_print_and_wait(
f"Running job [yellow]{job_status.id}[/yellow]",
_is_job_done,
asynchronous=not is_local,
id=job_status.id,
cloud=cloud,
cluster=job_status.cluster,
)
final_status = running_cluster.get_job(job_status.id)
if final_status:
cli_utils.CONSOLE.print(
f"Job [yellow]{final_status.id}[/yellow] finished with "
f"status [yellow]{final_status.status}[/yellow]"
)
cli_utils.CONSOLE.print("Job metadata:")
cli_utils.CONSOLE.print(f"[yellow]{final_status.metadata}[/yellow]")
# ----------------------------
# Launch CLI subcommands
# ----------------------------
[docs]
def cancel(
cloud: Annotated[str, typer.Option(help="Filter results by this cloud.")],
cluster: Annotated[
str,
typer.Option(help="Filter results by clusters matching this name."),
],
id: Annotated[
str, typer.Option(help="Filter results by jobs matching this job ID.")
],
level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
"""Cancels a job.
Args:
cloud: Filter results by this cloud.
cluster: Filter results by clusters matching this name.
id: Filter results by jobs matching this job ID.
level: The logging level for the specified command.
"""
_print_and_wait(
f"Canceling job [yellow]{id}[/yellow]",
_cancel_worker,
id=id,
cloud=cloud,
cluster=cluster,
)
[docs]
def down(
cluster: Annotated[str, typer.Option(help="The cluster to turn down.")],
cloud: Annotated[
Optional[str],
typer.Option(
help="If specified, only clusters on this cloud will be affected."
),
] = None,
level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
"""Turns down a cluster.
Args:
cluster: The cluster to turn down.
cloud: If specified, only clusters on this cloud will be affected.
level: The logging level for the specified command.
"""
_print_and_wait(
f"Turning down cluster [yellow]{cluster}[/yellow]",
_down_worker,
cluster=cluster,
cloud=cloud,
)
cli_utils.CONSOLE.print(f"Cluster [yellow]{cluster}[/yellow] turned down!")
[docs]
def run(
ctx: typer.Context,
config: Annotated[
str,
typer.Option(
*cli_utils.CONFIG_FLAGS, help="Path to the configuration file for the job."
),
],
cluster: Annotated[
Optional[str],
typer.Option(
help=(
"The cluster to use for this job. If unspecified, a new cluster will "
"be created."
)
),
] = None,
detach: Annotated[
bool, typer.Option(help="Run the job in the background.")
] = False,
output_filepath: Annotated[
Optional[str], typer.Option(help="Path to save job logs to a file.")
] = None,
level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
"""Runs a job on the target cluster.
Args:
ctx: The Typer context object.
config: Path to the configuration file for the job.
cluster: The cluster to use for this job. If no such cluster exists, a new
cluster will be created. If unspecified, a new cluster will be created with
a unique name.
detach: Run the job in the background.
output_filepath: Path to save job logs to a file.
level: The logging level for the specified command.
"""
extra_args = cli_utils.parse_extra_cli_args(ctx)
config = str(
cli_utils.resolve_and_fetch_config(
try_get_config_name_for_alias(config, AliasType.JOB),
)
)
# Delayed imports
from oumi import launcher
# End imports
parsed_config: launcher.JobConfig = launcher.JobConfig.from_yaml_and_arg_list(
config, extra_args, logger=logger
)
parsed_config.finalize_and_validate()
parsed_config.working_dir = _get_working_dir(parsed_config.working_dir)
if not cluster:
raise ValueError("No cluster specified for the `run` action.")
job_status = launcher.run(parsed_config, cluster)
cli_utils.CONSOLE.print(
f"Job [yellow]{job_status.id}[/yellow] queued on cluster "
f"[yellow]{cluster}[/yellow]."
)
_poll_job(
job_status=job_status,
detach=detach,
cloud=parsed_config.resources.cloud,
output_filepath=output_filepath,
)
[docs]
def status(
cloud: Annotated[
Optional[str], typer.Option(help="Filter results by this cloud.")
] = None,
cluster: Annotated[
Optional[str],
typer.Option(help="Filter results by clusters matching this name."),
] = None,
id: Annotated[
Optional[str], typer.Option(help="Filter results by jobs matching this job ID.")
] = None,
level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
"""Prints the status of jobs launched from Oumi.
Optionally, the caller may specify a job id, cluster, or cloud to further filter
results.
Args:
cloud: Filter results by this cloud.
cluster: Filter results by clusters matching this name.
id: Filter results by jobs matching this job ID.
level: The logging level for the specified command.
"""
# Delayed imports
from oumi import launcher
# End imports
filtered_jobs = launcher.status(cloud=cloud, cluster=cluster, id=id)
num_jobs = sum(len(cloud_jobs) for cloud_jobs in filtered_jobs.keys())
# Print the filtered jobs.
if num_jobs == 0 and (cloud or cluster or id):
cli_utils.CONSOLE.print(
"[red]No jobs found for the specified filter criteria: [/red]"
)
if cloud:
cli_utils.CONSOLE.print(f"Cloud: [yellow]{cloud}[/yellow]")
if cluster:
cli_utils.CONSOLE.print(f"Cluster: [yellow]{cluster}[/yellow]")
if id:
cli_utils.CONSOLE.print(f"Job ID: [yellow]{id}[/yellow]")
for target_cloud, job_list in filtered_jobs.items():
cli_utils.section_header(f"Cloud: [yellow]{target_cloud}[/yellow]")
cluster_name_list = [
c.name() for c in launcher.get_cloud(target_cloud).list_clusters()
]
if len(cluster_name_list) == 0:
cli_utils.CONSOLE.print("[red]No matching clusters found.[/red]")
continue
# Organize all jobs by cluster.
jobs_by_cluster: dict[str, list[JobStatus]] = defaultdict(list)
# List all clusters, even if they don't have jobs.
for cluster_name in cluster_name_list:
if not cluster or cluster == cluster_name:
jobs_by_cluster[cluster_name] = []
for job in job_list:
jobs_by_cluster[job.cluster].append(job)
for target_cluster, jobs in jobs_by_cluster.items():
title = f"[cyan]Cluster: [yellow]{target_cluster}[/yellow][/cyan]"
if not jobs:
body = Text("[red]No matching jobs found.[/red]")
else:
jobs_table = Table(show_header=True, show_lines=False)
jobs_table.add_column("Job", justify="left", style="green")
jobs_table.add_column("Status", justify="left", style="yellow")
for job in jobs:
jobs_table.add_row(job.id, job.status)
body = jobs_table
cli_utils.CONSOLE.print(Panel(body, title=title, border_style="blue"))
[docs]
def stop(
cluster: Annotated[str, typer.Option(help="The cluster to stop.")],
cloud: Annotated[
Optional[str],
typer.Option(
help="If specified, only clusters on this cloud will be affected."
),
] = None,
level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
"""Stops a cluster.
Args:
cluster: The cluster to stop.
cloud: If specified, only clusters on this cloud will be affected.
level: The logging level for the specified command.
"""
_print_and_wait(
f"Stopping cluster [yellow]{cluster}[/yellow]",
_stop_worker,
cluster=cluster,
cloud=cloud,
)
cli_utils.CONSOLE.print(
f"Cluster [yellow]{cluster}[/yellow] stopped!\n"
"Use [green]oumi launch down[/green] to turn it down."
)
[docs]
def up(
ctx: typer.Context,
config: Annotated[
str,
typer.Option(
*cli_utils.CONFIG_FLAGS, help="Path to the configuration file for the job."
),
],
cluster: Annotated[
Optional[str],
typer.Option(
help=(
"The cluster to use for this job. If unspecified, a new cluster will "
"be created."
)
),
] = None,
detach: Annotated[
bool, typer.Option(help="Run the job in the background.")
] = False,
output_filepath: Annotated[
Optional[str], typer.Option(help="Path to save job logs to a file.")
] = None,
level: cli_utils.LOG_LEVEL_TYPE = None,
):
"""Launches a job.
Args:
ctx: The Typer context object.
config: Path to the configuration file for the job.
cluster: The cluster to use for this job. If no such cluster exists, a new
cluster will be created. If unspecified, a new cluster will be created with
a unique name.
detach: Run the job in the background.
output_filepath: Path to save job logs to a file.
level: The logging level for the specified command.
"""
# Delayed imports
from oumi import launcher
# End imports
extra_args = cli_utils.parse_extra_cli_args(ctx)
config = str(
cli_utils.resolve_and_fetch_config(
try_get_config_name_for_alias(config, AliasType.JOB),
)
)
parsed_config: launcher.JobConfig = launcher.JobConfig.from_yaml_and_arg_list(
config, extra_args, logger=logger
)
parsed_config.finalize_and_validate()
if cluster:
target_cloud = launcher.get_cloud(parsed_config.resources.cloud)
target_cluster = target_cloud.get_cluster(cluster)
if target_cluster:
cli_utils.CONSOLE.print(
f"Found an existing cluster: [yellow]{target_cluster.name()}[/yellow]."
)
run(ctx, config, cluster, detach, output_filepath)
return
parsed_config.working_dir = _get_working_dir(parsed_config.working_dir)
# Start the job
running_cluster, job_status = launcher.up(parsed_config, cluster)
cli_utils.CONSOLE.print(
f"Job [yellow]{job_status.id}[/yellow] queued on cluster "
f"[yellow]{running_cluster.name()}[/yellow]."
)
_poll_job(
job_status=job_status,
detach=detach,
cloud=parsed_config.resources.cloud,
running_cluster=running_cluster,
output_filepath=output_filepath,
)
[docs]
def which(level: cli_utils.LOG_LEVEL_TYPE = None) -> None:
"""Prints the available clouds."""
# Delayed imports
from oumi import launcher
# End imports
clouds = launcher.which_clouds()
cloud_options = [Text(f"{cloud}", style="bold cyan") for cloud in clouds]
cli_utils.CONSOLE.print(
Panel(
Columns(cloud_options, equal=True, expand=True, padding=(0, 2)),
title="[yellow]Available Clouds[/yellow]",
border_style="blue",
)
)