Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions qiskit_ibm_runtime/api/clients/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,14 @@ def job_metadata(self, job_id: str) -> Dict[str, Any]:
"""
return self._api.program_job(job_id).metadata()

def create_session(self, mode: str = None) -> Dict[str, Any]:
"""Create a session.

Args:
mode: Execution mode.
"""
return self._api.runtime_session(session_id=None).create(mode=mode)

def cancel_session(self, session_id: str) -> None:
"""Close all jobs in the runtime session.

Expand Down
2 changes: 1 addition & 1 deletion qiskit_ibm_runtime/api/rest/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def program_job(self, job_id: str) -> "ProgramJob":
"""
return ProgramJob(self.session, job_id)

def runtime_session(self, session_id: str) -> "RuntimeSession":
def runtime_session(self, session_id: str = None) -> "RuntimeSession":
"""Return an adapter for the session.

Args:
Expand Down
12 changes: 11 additions & 1 deletion qiskit_ibm_runtime/api/rest/runtime_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

"""Runtime Session REST adapter."""

import json
from typing import Dict, Any
from .base import RestAdapterBase
from ..session import RetrySession
Expand All @@ -35,7 +36,16 @@ def __init__(self, session: RetrySession, session_id: str, url_prefix: str = "")
session_id: Job ID of the first job in a runtime session.
url_prefix: Prefix to use in the URL.
"""
super().__init__(session, "{}/sessions/{}".format(url_prefix, session_id))
if not session_id:
super().__init__(session, "{}/sessions".format(url_prefix))
else:
super().__init__(session, "{}/sessions/{}".format(url_prefix, session_id))

def create(self, mode: str = None) -> Dict[str, Any]:
"""Create a session"""
url = self.get_url("self")
payload = json.dumps({"mode": mode})
return self.session.post(url, data=payload).json()

def cancel(self) -> None:
"""Cancel all jobs in the session."""
Expand Down
17 changes: 16 additions & 1 deletion qiskit_ibm_runtime/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,25 @@

"""Qiskit Runtime batch mode."""

from typing import Optional, Union
from qiskit_ibm_runtime import QiskitRuntimeService
from .ibm_backend import IBMBackend

from .session import Session


class Batch(Session):
"""Class for creating a batch mode in Qiskit Runtime."""

pass
def __init__(
self,
service: Optional[QiskitRuntimeService] = None,
backend: Optional[Union[str, IBMBackend]] = None,
max_time: Optional[Union[int, str]] = None,
):
super().__init__(service=service, backend=backend, max_time=max_time)

def create_session(self) -> None:
Comment thread
kt474 marked this conversation as resolved.
Outdated
"""Create a session."""
session = self._service._api_client.create_session(mode="batch")
return session.get("id")
41 changes: 19 additions & 22 deletions qiskit_ibm_runtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Dict, Optional, Type, Union, Callable, Any
from types import TracebackType
from functools import wraps
from threading import Lock

from qiskit_ibm_runtime import QiskitRuntimeService
from .runtime_job import RuntimeJob
Expand Down Expand Up @@ -110,10 +109,11 @@ def __init__(
if QiskitRuntimeService.global_service is None
else QiskitRuntimeService.global_service
)

else:
self._service = service

self._session_id = self.create_session()

if self._service.channel == "ibm_quantum" and not backend:
raise ValueError('"backend" is required for ``ibm_quantum`` channel.')

Expand All @@ -123,15 +123,18 @@ def __init__(
backend = backend.name
self._backend = backend

self._setup_lock = Lock()
self._session_id: Optional[str] = None
self._active = True
self._max_time = (
max_time
if max_time is None or isinstance(max_time, int)
else hms_to_seconds(max_time, "Invalid max_time value: ")
)

def create_session(self) -> str:
"""Create a session."""
session = self._service._api_client.create_session()
return session.get("id")

@_active_session
def run(
self,
Expand Down Expand Up @@ -163,28 +166,22 @@ def run(
options["backend"] = self._backend

if not self._session_id:
# Make sure only one thread can send the session starter job.
self._setup_lock.acquire()
# TODO: What happens if session max time != first job max time?
# Use session max time if this is first job.
options["session_time"] = self._max_time
Comment thread
kt474 marked this conversation as resolved.
Outdated

try:
job = self._service.run(
program_id=program_id,
options=options,
inputs=inputs,
session_id=self._session_id,
start_session=self._session_id is None,
callback=callback,
result_decoder=result_decoder,
)

if self._session_id is None:
self._session_id = job.job_id()
finally:
if self._setup_lock.locked():
self._setup_lock.release()
job = self._service.run(
program_id=program_id,
options=options,
inputs=inputs,
session_id=self._session_id,
start_session=self._session_id is None,
Comment thread
kt474 marked this conversation as resolved.
Outdated
callback=callback,
result_decoder=result_decoder,
)

if self._session_id is None:
self._session_id = job.job_id()
Comment thread
kt474 marked this conversation as resolved.
Outdated

if self._backend is None:
self._backend = job.backend().name
Expand Down
2 changes: 1 addition & 1 deletion qiskit_ibm_runtime/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def default(self, obj: Any) -> Any: # pylint: disable=arguments-differ
if hasattr(obj, "to_json"):
return {"__type__": "to_json", "__value__": obj.to_json()}
if isinstance(obj, QuantumCircuit):
kwargs: dict[str, object] = {"use_symengine": bool(optionals.HAS_SYMENGINE)}
kwargs: Dict[str, object] = {"use_symengine": bool(optionals.HAS_SYMENGINE)}
if _TERRA_VERSION[0] >= 1:
# NOTE: This can be updated only after the server side has
# updated to a newer qiskit version.
Expand Down
29 changes: 1 addition & 28 deletions test/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@

"""Tests for Session classession."""

import sys
import time
from concurrent.futures import ThreadPoolExecutor, wait

from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, patch

from qiskit_ibm_runtime.fake_provider import FakeManila
from qiskit_ibm_runtime import Session
Expand Down Expand Up @@ -128,29 +124,6 @@ def test_run(self):
self.assertEqual(session.session_id, job.job_id())
self.assertEqual(session.backend(), backend)

def test_run_is_thread_safe(self):
"""Test the session sends a session starter job once, and only once."""
service = MagicMock()
api = MagicMock()
service._api_client = api

def _wait_a_bit(*args, **kwargs):
# pylint: disable=unused-argument
switchinterval = sys.getswitchinterval()
time.sleep(switchinterval * 2)
return MagicMock()

service.run = Mock(side_effect=_wait_a_bit)

session = Session(service=service, backend="ibm_gotham")
with ThreadPoolExecutor(max_workers=2) as executor:
results = list(map(lambda _: executor.submit(session.run, "", {}), range(5)))
wait(results)

calls = service.run.call_args_list
session_starters = list(filter(lambda c: c.kwargs["start_session"] is True, calls))
self.assertEqual(len(session_starters), 1)

def test_close_without_run(self):
"""Test closing without run."""
service = MagicMock()
Expand Down