Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/cosmosdb-preview/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

Release History
===============
0.21.0
* Add support for mongo data transfer jobs.

++++++
0.20.0
* Add support for Continuous mode restore with user provided identity.

Expand Down
17 changes: 16 additions & 1 deletion src/cosmosdb-preview/azext_cosmosdb_preview/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,14 +567,29 @@
Usage: --dest-sql-container database=XX container=XX'
database: Database name of CosmosDB Sql.
container: Container name of CosmosDB Sql.
- name: --source-mongo
short-summary: "Source mongo collection"
long-summary: |
Usage: --source-mongo database=XX collection=XX'
database: Database name of CosmosDB Mongo.
collection: Collection name of CosmosDB Mongo.
- name: --dest-mongo
short-summary: "Destination mongo collection"
long-summary: |
Usage: --dest-mongo database=XX collection=XX'
database: Database name of CosmosDB Mongo.
collection: Collection name of CosmosDB Mongo.

examples:
- name: Copy sql container
text: |-
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-sql-container database=db1 container=c1 --dest-sql-container database=db2 container=c2
- name: Copy cassandra table
text: |-
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-cassandra-table keyspace=k1 table=t1 --dest-cassandra-table keyspace=k1 table=t1
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-cassandra-table keyspace=k1 table=t1 --dest-cassandra-table keyspace=k2 table=t2
- name: Copy mongo collection
text: |-
az cosmosdb dts copy -g "rg1" --job-name "j1" --account-name "db1" --source-mongo database=d1 collection=c1 --dest-mongo database=d2 collection=c2
"""

helps['cosmosdb dts'] = """
Expand Down
3 changes: 3 additions & 0 deletions src/cosmosdb-preview/azext_cosmosdb_preview/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CreateGremlinDatabaseRestoreResource,
CreateTableRestoreResource,
AddCassandraTableAction,
AddMongoCollectionAction,
AddSqlContainerAction,
CreateTargetPhysicalPartitionThroughputInfoAction,
CreateSourcePhysicalPartitionThroughputInfoAction,
Expand Down Expand Up @@ -314,8 +315,10 @@ def load_arguments(self, _):
with self.argument_context('cosmosdb dts copy') as c:
c.argument('job_name', job_name_type)
c.argument('source_cassandra_table', nargs='+', action=AddCassandraTableAction, help='Source cassandra table')
c.argument('source_mongo', nargs='+', action=AddMongoCollectionAction, help='Source mongo collection')
c.argument('source_sql_container', nargs='+', action=AddSqlContainerAction, help='Source sql container')
c.argument('dest_cassandra_table', nargs='+', action=AddCassandraTableAction, help='Destination cassandra table')
c.argument('dest_mongo', nargs='+', action=AddMongoCollectionAction, help='Destination mongo collection')
c.argument('dest_sql_container', nargs='+', action=AddSqlContainerAction, help='Destination sql container')
c.argument('worker_count', type=int, help='Worker count')

Expand Down
40 changes: 40 additions & 0 deletions src/cosmosdb-preview/azext_cosmosdb_preview/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DatabaseRestoreResource,
GremlinDatabaseRestoreResource,
CosmosCassandraDataTransferDataSourceSink,
CosmosMongoDataTransferDataSourceSink,
CosmosSqlDataTransferDataSourceSink,
PhysicalPartitionThroughputInfoResource,
PhysicalPartitionId
Expand Down Expand Up @@ -138,6 +139,45 @@ def __call__(self, parser, namespace, values, option_string=None):
namespace.cassandra_table = cassandra_table


class AddMongoCollectionAction(argparse._AppendAction):
def __call__(self, parser, namespace, values, option_string=None):
if not values:
# pylint: disable=line-too-long
raise CLIError(f'usage error: {option_string} [KEY=VALUE ...]')

database_name = None
collection_name = None

for (k, v) in (x.split('=', 1) for x in values):
kl = k.lower()
if kl == 'database':
database_name = v

elif kl == 'collection':
collection_name = v

else:
raise CLIError(
f'Unsupported Key {k} is provided for {option_string} component. All'
' possible keys are: database, collection'
)

if database_name is None:
raise CLIError(f'usage error: missing key database in {option_string} component')

if collection_name is None:
raise CLIError(f'usage error: missing key table in {option_string} component')

mongo_collection = CosmosMongoDataTransferDataSourceSink(database_name=database_name, collection_name=collection_name)

if option_string == "--source-mongo":
namespace.source_mongo = mongo_collection
elif option_string == "--dest-mongo":
namespace.dest_mongo = mongo_collection
else:
namespace.mongo_collection = mongo_collection


class AddSqlContainerAction(argparse._AppendAction):
def __call__(self, parser, namespace, values, option_string=None):
if not values:
Expand Down
50 changes: 34 additions & 16 deletions src/cosmosdb-preview/azext_cosmosdb_preview/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,33 +1114,51 @@ def cosmosdb_data_transfer_copy_job(client,
dest_cassandra_table=None,
source_sql_container=None,
dest_sql_container=None,
source_mongo=None,
dest_mongo=None,
worker_count=0,
job_name=None):
if source_cassandra_table is None and source_sql_container is None:
raise CLIError('source component ismissing')

if source_cassandra_table is not None and source_sql_container is not None:
raise CLIError('Invalid input: multiple source components')

if dest_cassandra_table is None and dest_sql_container is None:
raise CLIError('destination component is missing')

if dest_cassandra_table is not None and dest_sql_container is not None:
raise CLIError('Invalid input: multiple destination components')

job_create_properties = {}

source = None
if source_cassandra_table is not None:
job_create_properties['source'] = source_cassandra_table
if source is not None:
raise CLIError('Invalid input: multiple source components')
source = source_cassandra_table

if source_sql_container is not None:
job_create_properties['source'] = source_sql_container
if source is not None:
raise CLIError('Invalid input: multiple source components')
source = source_sql_container

if source_mongo is not None:
if source is not None:
raise CLIError('Invalid input: multiple source components')
source = source_mongo

if source is None:
raise CLIError('source component is missing')
job_create_properties['source'] = source

destination = None
if dest_cassandra_table is not None:
job_create_properties['destination'] = dest_cassandra_table
if destination is not None:
raise CLIError('Invalid input: multiple destination components')
destination = dest_cassandra_table

if dest_sql_container is not None:
job_create_properties['destination'] = dest_sql_container
if destination is not None:
raise CLIError('Invalid input: multiple destination components')
destination = dest_sql_container

if dest_mongo is not None:
if destination is not None:
raise CLIError('Invalid input: multiple destination components')
destination = dest_mongo

if destination is None:
raise CLIError('destination component is missing')
job_create_properties['destination'] = destination

if worker_count > 0:
job_create_properties['worker_count'] = worker_count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast, overload
from urllib.parse import parse_qs, urljoin, urlparse
import urllib.parse

from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import (
Expand Down Expand Up @@ -106,10 +106,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -184,10 +191,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -1067,10 +1081,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast, overload
from urllib.parse import parse_qs, urljoin, urlparse
import urllib.parse

from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import (
Expand Down Expand Up @@ -108,10 +108,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast, overload
from urllib.parse import parse_qs, urljoin, urlparse
import urllib.parse

from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import (
Expand Down Expand Up @@ -129,10 +129,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -1177,10 +1184,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down Expand Up @@ -2266,10 +2280,17 @@ def prepare_request(next_link=None):

else:
# make call to next link with the client's api-version
_parsed_next_link = urlparse(next_link)
_next_request_params = case_insensitive_dict(parse_qs(_parsed_next_link.query))
_parsed_next_link = urllib.parse.urlparse(next_link)
_next_request_params = case_insensitive_dict(
{
key: [urllib.parse.quote(v) for v in value]
for key, value in urllib.parse.parse_qs(_parsed_next_link.query).items()
}
)
_next_request_params["api-version"] = self._config.api_version
request = HttpRequest("GET", urljoin(next_link, _parsed_next_link.path), params=_next_request_params)
request = HttpRequest(
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params
)
request = _convert_request(request)
request.url = self._client.format_url(request.url) # type: ignore
request.method = "GET"
Expand Down
Loading