Skip to content

Commit 79ad3df

Browse files
authored
[PPML] refined read/write APIs and data-key management (#7489)
* [PPML] refined read/write APIs and data-key management * refine * refactor examples with new APIs * add python api * Update PPMLContext.scala * crypto mode set/get * fix * separate key transactions * refine * implement meta format serializer * fix crypto issue * fix ut * fix * fix crypto mode bug * fix style * fix python ut
1 parent 663798f commit 79ad3df

32 files changed

+565
-847
lines changed

python/ppml/src/bigdl/ppml/ppml_context.py

+34-40
Original file line number | Diff line number | Diff line change
@@ -38,34 +38,32 @@ def __init__(self, app_name, ppml_args=None, spark_conf=None):
3838
for (k, v) in spark_conf.getAll():
3939
conf[k] = v
4040
if ppml_args:
41-
kms_type = ppml_args.get("kms_type", "SimpleKeyManagementService")
42-
conf["spark.bigdl.kms.type"] = kms_type
41+
kms_type = ppml_args.get("kms_type", "")
42+
conf["spark.bigdl.primaryKey.defaultKey.kms.type"] = kms_type
4343
if kms_type == "SimpleKeyManagementService":
44-
conf["spark.bigdl.kms.appId"] = check(ppml_args, "app_id")
45-
conf["spark.bigdl.kms.apiKey"] = check(ppml_args, "api_key")
46-
conf["spark.bigdl.kms.primaryKey"] = check(ppml_args, "primary_key")
47-
conf["spark.bigdl.kms.dataKey"] = check(ppml_args, "data_key")
44+
conf["spark.bigdl.primaryKey.defaultKey.kms.appId"] = check(ppml_args, "app_id")
45+
conf["spark.bigdl.primaryKey.defaultKey.kms.apiKey"] = check(ppml_args, "api_key")
46+
conf["spark.bigdl.primaryKey.defaultKey.material"] = check(ppml_args, "primary_key_material")
4847
elif kms_type == "EHSMKeyManagementService":
49-
conf["spark.bigdl.kms.ip"] = check(ppml_args, "kms_server_ip")
50-
conf["spark.bigdl.kms.port"] = check(ppml_args, "kms_server_port")
51-
conf["spark.bigdl.kms.id"] = check(ppml_args, "app_id")
52-
conf["spark.bigdl.kms.apiKey"] = check(ppml_args, "api_key")
53-
conf["spark.bigdl.kms.primaryKey"] = check(ppml_args, "primary_key")
54-
conf["spark.bigdl.kms.dataKey"] = check(ppml_args, "data_key")
48+
conf["spark.bigdl.primaryKey.defaultKey.kms.ip"] = check(ppml_args, "kms_server_ip")
49+
conf["spark.bigdl.primaryKey.defaultKey.kms.port"] = check(ppml_args, "kms_server_port")
50+
conf["spark.bigdl.primaryKey.defaultKey.kms.id"] = check(ppml_args, "app_id")
51+
conf["spark.bigdl.primaryKey.defaultKey.kms.apiKey"] = check(ppml_args, "api_key")
52+
conf["spark.bigdl.primaryKey.defaultKey.material"] = check(ppml_args, "primary_key_material")
5553
elif kms_type == "AzureKeyManagementService":
56-
conf["spark.bigdl.kms.vault"] = check(ppml_args, "vault")
57-
conf["spark.bigdl.kms.clientId"] = ppml_args.get("client_id", "")
58-
conf["spark.bigdl.kms.primaryKey"] = check(ppml_args, "primary_key")
59-
conf["spark.bigdl.kms.dataKey"] = check(ppml_args, "data_key")
54+
conf["spark.bigdl.primaryKey.defaultKey.kms.vault"] = check(ppml_args, "vault")
55+
conf["spark.bigdl.primaryKey.defaultKey.kms.clientId"] = ppml_args.get("client_id", "")
56+
conf["spark.bigdl.primaryKey.defaultKey.material"] = check(ppml_args, "primary_key_material")
6057
elif kms_type == "BigDLKeyManagementService":
61-
conf["spark.bigdl.kms.ip"] = check(ppml_args, "kms_server_ip")
62-
conf["spark.bigdl.kms.port"] = check(ppml_args, "kms_server_port")
63-
conf["spark.bigdl.kms.user"] = check(ppml_args, "kms_user_name")
64-
conf["spark.bigdl.kms.token"] = check(ppml_args, "kms_user_token")
65-
conf["spark.bigdl.kms.primaryKey"] = check(ppml_args, "primary_key")
66-
conf["spark.bigdl.kms.dataKey"] = check(ppml_args, "data_key")
58+
conf["spark.bigdl.primaryKey.defaultKey.kms.ip"] = check(ppml_args, "kms_server_ip")
59+
conf["spark.bigdl.primaryKey.defaultKey.kms.port"] = check(ppml_args, "kms_server_port")
60+
conf["spark.bigdl.primaryKey.defaultKey.kms.user"] = check(ppml_args, "kms_user_name")
61+
conf["spark.bigdl.primaryKey.defaultKey.kms.token"] = check(ppml_args, "kms_user_token")
62+
conf["spark.bigdl.primaryKey.defaultKey.material"] = check(ppml_args, "primary_key_material")
63+
elif kms_type == "":
64+
conf["spark.bigdl.primaryKey.defaultKey.plainText"] = check(ppml_args, "primary_key_plaintext")
6765
else:
68-
invalidInputError(False, "invalid KMS type")
66+
invalidInputError(False, "invalid KMS type.")
6967

7068
conf["spark.hadoop.io.compression.codecs"] = "com.intel.analytics.bigdl.ppml.crypto.CryptoCodec"
7169
spark_conf = init_spark_conf(conf)
@@ -76,32 +74,29 @@ def __init__(self, app_name, ppml_args=None, spark_conf=None):
7674
args = [self.spark._jsparkSession]
7775
super().__init__(None, self.bigdl_type, *args)
7876

79-
def load_keys(self, primary_key_path, data_key_path):
80-
self.value = callBigDlFunc(self.bigdl_type, "loadKeys", self.value, primary_key_path, data_key_path)
81-
82-
def read(self, crypto_mode, kms_name = "", primary_key = "", data_key = ""):
77+
def read(self, crypto_mode, primary_key_name = "defaultKey"):
8378
if isinstance(crypto_mode, CryptoMode):
8479
crypto_mode = crypto_mode.value
85-
df_reader = callBigDlFunc(self.bigdl_type, "read", self.value, crypto_mode,
86-
kms_name, primary_key, data_key)
80+
df_reader = callBigDlFunc(self.bigdl_type, "read",
81+
self.value, crypto_mode, primary_key_name)
8782
return EncryptedDataFrameReader(self.bigdl_type, df_reader)
8883

89-
def write(self, dataframe, crypto_mode, kms_name = "", primary_key = "", data_key = ""):
84+
def write(self, dataframe, crypto_mode, primary_key_name = "defaultKey"):
9085
if isinstance(crypto_mode, CryptoMode):
9186
crypto_mode = crypto_mode.value
92-
df_writer = callBigDlFunc(self.bigdl_type, "write", self.value, dataframe, crypto_mode,
93-
kms_name, primary_key, data_key)
87+
df_writer = callBigDlFunc(self.bigdl_type, "write", self.value,
88+
dataframe, crypto_mode, primary_key_name)
9489
return EncryptedDataFrameWriter(self.bigdl_type, df_writer)
9590

96-
def textfile(self, path, min_partitions=None, crypto_mode="plain_text",
97-
kms_name = "", primary_key = "", data_key = ""):
91+
def textfile(self, path, min_partitions=None,
92+
crypto_mode="plain_text", primary_key_name = "defaultKey"):
9893
if min_partitions is None:
9994
min_partitions = self.spark.sparkContext.defaultMinPartitions
10095
if isinstance(crypto_mode, CryptoMode):
10196
crypto_mode = crypto_mode.value
102-
return callBigDlFunc(self.bigdl_type, "textFile", self.value,
103-
path, min_partitions, crypto_mode,
104-
kms_name, primary_key, data_key)
97+
return callBigDlFunc(self.bigdl_type, "textFile",
98+
self.value, path, min_partitions,
99+
crypto_mode, primary_key_name)
105100

106101

107102
class EncryptedDataFrameReader:
@@ -168,9 +163,8 @@ class CryptoMode(Enum):
168163
# CryptoMode AES_GCM_CTR_V1 for parquet only
169164
AES_GCM_CTR_V1 = "AES_GCM_CTR_V1"
170165

171-
172-
def init_keys(app_id, api_key, primary_key_path, data_key_path):
173-
return callBigDlFunc("float", "initKeys", app_id, api_key, primary_key_path, data_key_path)
166+
def init_keys(app_id, api_key, primary_key_path):
167+
return callBigDlFunc("float", "initKeys", app_id, api_key, primary_key_path)
174168

175169

176170
def generate_encrypted_file(kms, primary_key_path, data_key_path, input_path, output_path):

python/ppml/test/bigdl/ppml/test_ppml_context.py

+7-11
Original file line number | Diff line number | Diff line change
@@ -40,28 +40,24 @@ def setUpClass(cls) -> None:
4040

4141
# set key path
4242
primary_key_path = os.path.join(resource_path, "primaryKey")
43-
data_key_path = os.path.join(resource_path, "dataKey")
4443

4544
# init a SparkContext
4645
conf = {"spark.app.name": "PPML TEST",
4746
"spark.hadoop.io.compression.codecs": "com.intel.analytics.bigdl.ppml.crypto.CryptoCodec",
48-
"spark.bigdl.kms.type": "SimpleKeyManagementService",
49-
"spark.bigdl.kms.appId": cls.app_id,
50-
"spark.bigdl.kms.apiKey": cls.app_key,
51-
"spark.bigdl.kms.primaryKey": primary_key_path,
52-
"spark.bigdl.kms.dataKey": data_key_path
47+
"spark.bigdl.primaryKey.defaultKey.kms.type": "SimpleKeyManagementService",
48+
"spark.bigdl.primaryKey.defaultKey.kms.appId": cls.app_id,
49+
"spark.bigdl.primaryKey.defaultKey.kms.apiKey": cls.app_key,
50+
"spark.bigdl.primaryKey.defaultKey.material": primary_key_path
5351
}
5452
init_spark_on_local(conf=conf)
5553

56-
# generate primaryKey and dataKey
57-
init_keys(cls.app_id, cls.app_key,
58-
primary_key_path, data_key_path)
54+
# generate primaryKey
55+
init_keys(cls.app_id, cls.app_key, primary_key_path)
5956

6057
args = {"kms_type": "SimpleKeyManagementService",
6158
"app_id": cls.app_id,
6259
"api_key": cls.app_key,
63-
"primary_key": primary_key_path,
64-
"data_key": data_key_path
60+
"primary_key_material": primary_key_path
6561
}
6662

6763
# init a PPMLContext

0 commit comments

Comments
 (0)