Skip to content
142 changes: 141 additions & 1 deletion vast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import getpass
import subprocess
from subprocess import PIPE
from croniter import croniter, CroniterBadCronError

try:
from urllib import quote_plus # Python 2.X
Expand Down Expand Up @@ -52,6 +53,64 @@
class Object(object):
pass

def convert_cron_to_millis(schedule):
"""Convert cron schedule to a time interval in milliseconds."""
try:
cron = croniter(schedule) # Validate the cron schedule
except CroniterBadCronError:
raise argparse.ArgumentTypeError(f"Invalid cron schedule: '{schedule}'. "
"Make sure it follows the cron syntax format:\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this throw a error since its essentially a bunch of separate strings instead of one big one? I think you're looking for a block string with ''' abcdefg '''

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems okay to me when I tested, seems like croniter is built to handle this

" ┌───────────── minute (0 - 59)\n"
" │ ┌───────────── hour (0 - 23)\n"
" │ │ ┌───────────── day of month (1 - 31)\n"
" │ │ │ ┌───────────── month (1 - 12)\n"
" │ │ │ │ ┌───────────── day of week (0 - 6) (Sunday=0)\n"
" │ │ │ │ │\n"
" │ │ │ │ │\n"
" * * * * * command to execute\n\n"
"For example, '0 */2 * * *' runs every 2 hours.")

first_time = cron.get_next(datetime)
second_time = cron.get_next(datetime)

time_interval = int((second_time - first_time).total_seconds() * 1000)

return time_interval

def validate_millis(value):
"""Validate that the input value is a valid number for milliseconds between yesterday and Jan 1, 2100."""
try:
val = int(value)

# Calculate min_millis as the start of yesterday in millis
yesterday = datetime.now() - timedelta(days=1)
min_millis = int(yesterday.timestamp() * 1000)

# Calculate max_millis for Jan 1st, 2100 in millis
max_date = datetime(2100, 1, 1, 0, 0, 0)
max_millis = int(max_date.timestamp() * 1000)

if not (min_millis <= val <= max_millis):
raise argparse.ArgumentTypeError(f"{value} is not a valid millisecond timestamp.")
return val
except ValueError:
raise argparse.ArgumentTypeError(f"{value} is not a valid integer.")

def validate_schedule_values(args):
"""Validate start and end times."""
# Validate start_time and end_time
args.start_time = validate_millis(args.start_time)
args.end_time = validate_millis(args.end_time)

if args.start_time >= args.end_time:
raise argparse.ArgumentTypeError("--start_time must be less than --end_time.")

# Get the time interval in milliseconds
time_interval = convert_cron_to_millis(args.schedule)

print(f"Time Interval (in milliseconds): {time_interval}")
return args.start_time, args.end_time, time_interval

def strip_strings(value):
if isinstance(value, str):
return value.strip()
Expand Down Expand Up @@ -924,6 +983,12 @@ def change__bid(args: argparse.Namespace):
r.raise_for_status()
print("Per gpu bid price changed".format(r.json()))

if (args.schedule):
cli_command = "change bid"
api_endpoint = "/api/v0/instances/bid_price/{id}/".format(id=args.id)
json_blob["instance_id"] = args.id
add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT")




Expand Down Expand Up @@ -978,6 +1043,7 @@ def copy(args: argparse.Namespace):
if (args.explain):
print("request json: ")
print(req_json)

r = http_put(args, url, headers=headers,json=req_json)
r.raise_for_status()
if (r.status_code == 200):
Expand Down Expand Up @@ -1091,17 +1157,60 @@ def cloud__copy(args: argparse.Namespace):
if (args.explain):
print("request json: ")
print(req_json)

r = http_post(args, url, headers=headers,json=req_json)
r.raise_for_status()
if (r.status_code == 200):
print("Cloud Copy Started - check instance status bar for progress updates (~30 seconds delayed).")
print("When the operation is finished you should see 'Cloud Cody Operation Finished' in the instance status bar.")
if (args.schedule):
cli_command = "cloud copy"
api_endpoint = "/api/v0/commands/rclone/"
add_scheduled_job(args, req_json, cli_command, api_endpoint, "POST")
else:
print(r.text);
print("failed with error {r.status_code}".format(**locals()));




def add_scheduled_job(args, req_json, cli_command, api_endpoint, request_method):
start_time, end_time, time_interval = validate_schedule_values(args)


schedule_job_url = apiurl(args, f"/commands/schedule_job/")
start_date = millis_to_date(start_time)
end_date = millis_to_date(end_time)
request_body = {
"start_time": start_time,
"end_time": end_time,
"time_interval": time_interval,
"api_endpoint": api_endpoint,
"request_method": request_method,
"request_body": req_json
}
# Send a POST request
response = requests.post(schedule_job_url, headers=headers, json=request_body)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why ask for request_method if only POST requests are called?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scheduled_jobs table is intended to store GET, POST, PUT, or DELETE requests. but call to scheduled_jobs api endpoint will be post if inserting a new record and PUT if updating


print(f"url: {schedule_job_url}")
print(f"headers: {headers}")
print(f"request_body: {request_body}")

# Raise an exception for HTTP errors
response.raise_for_status()

# Handle the response based on the status code
if response.status_code == 200:
time_interval_hours = millis_to_hours(time_interval)
print(f"Scheduling job to {cli_command} from {start_date} to {end_date} every {time_interval_hours} hours")
print(response.json())
elif response.status_code == 401:
print(f"Failed with error {response.status_code}. It could be because you aren't using a valid api_key.")
else:
# print(r.text)
print(f"Failed with error {response.status_code}.")


@parser.command(
argument("--name", help="name of the api-key", type=str),
argument("--permission_file", help="file path for json encoded permissions, see https://vast.ai/docs/cli/roles-and-permissions for more information", type=str),
Expand Down Expand Up @@ -1384,6 +1493,7 @@ def create__instance(args: argparse.Namespace):
else:
print("Started. {}".format(r.json()))


@parser.command(
argument("--email", help="email address to use for login", type=str),
argument("--username", help="username to use for login", type=str),
Expand Down Expand Up @@ -1767,6 +1877,11 @@ def execute(args):
if (r.status_code == 200):
filtered_text = r.text.replace(rj["writeable_path"], '');
print(filtered_text)
if (args.schedule):
cli_command = "execute"
api_endpoint = "/api/v0/instances/command/{id}/".format(id=args.ID)
json_blob["instance_id"] = args.ID
add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT")
break
else:
print(rj);
Expand Down Expand Up @@ -2179,6 +2294,11 @@ def reboot__instance(args):
rj = r.json();
if (rj["success"]):
print("Rebooting instance {args.ID}.".format(**(locals())));
if (args.schedule):
cli_command = "reboot instance"
api_endpoint = "/api/v0/instances/reboot/{id}/".format(id=args.ID)
json_blob = {"instance_id": args.ID}
add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT")
else:
print(rj["msg"]);
else:
Expand Down Expand Up @@ -4495,6 +4615,21 @@ def schedule__maint(args):
r.raise_for_status()
print(f"Maintenance window scheduled for {dt} success".format(r.json()))


def millis_to_date(milliseconds):
# Convert milliseconds to seconds
seconds = milliseconds / 1000.0
# Create a datetime object from the epoch (January 1, 1970)
return datetime(1970, 1, 1) + timedelta(seconds=seconds)

def millis_to_hours(milliseconds):
hours = milliseconds / (1000 * 60 * 60)
return hours


def hours_to_millis(hours):
return hours * 60 * 60 * 1000

@parser.command(
argument("ID", help="id of machine to display", type=int),
argument("-q", "--quiet", action="store_true", help="only display numeric ids"),
Expand Down Expand Up @@ -4597,11 +4732,16 @@ def login(args):
print(login_deprecated_message)
"""



def main():
parser.add_argument("--url", help="server REST api url", default=server_url_default)
parser.add_argument("--retry", help="retry limit", default=3)
parser.add_argument("--raw", action="store_true", help="output machine-readable json")
parser.add_argument("--explain", action="store_true", help="output verbose explanation of mapping of CLI calls to HTTPS API endpoints")
parser.add_argument("--schedule", help="try to schedule a command to run every x mins, hours, etc. by passing in time interval in cron syntax to --schedule option. Can also choose to have --start_time and --end_time options with valid values. For ex. --schedule \"0 */2 * * *\"")
parser.add_argument("--start_time", help="the start time for your scheduled job in millis since unix epoch. Default will be current time. For ex. --start_time 1728510298144", default=(time.time() * 1000))
parser.add_argument("--end_time", help="the end time for your scheduled job in millis since unix epoch. Default will be 7 days from now. For ex. --end_time 1729115232881", default=(time.time() * 1000 + 7 * 24 * 60 * 60 * 1000))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yknow you can enforce type with argparse arguments by just adding type=int
also if you want it so that the user can choose --schedule OR (start_time and end_time) I recommend checking out argparse's mutually exclusive groups :D up to you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not mutually exclusive

parser.add_argument("--api-key", help="api key. defaults to using the one stored in {}".format(api_key_file_base), type=str, required=False, default=os.getenv("VAST_API_KEY", api_key_guard))


Expand Down