Skip to content

Commit

Permalink
Merge pull request #9 from databricks/BugFixes
Browse files Browse the repository at this point in the history
Bug fixes
  • Loading branch information
HariGS-DB authored May 16, 2023
2 parents 7206358 + a601983 commit bd12648
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions GroupMigration/WSGroupMigration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,21 @@ def __init__(self, groupL : list, cloud : str, inventoryTableName : str, workspa
self.verbose = verbose
self.numThreads = numThreads


self.lastInventoryRun = None
self.checkAllDB = False
print(f'Clearing inventory table {self.inventoryTableName}')
spark.sql(f"drop table if exists {self.inventoryTableName}")
spark.sql(f"drop table if exists {self.inventoryTableName+'TableACL'}")


#Check if we should automatically generate list, and do it immediately.
#Implementers Note: Could change this section to a lazy calculation by setting groupL to nil or some sentinel value and adding checks before use.
res=requests.get(f"{self.workspace_url}/api/2.0/preview/scim/v2/Me", headers=self.headers)
#print(res.text)
if res.status_code == 403:
print("token not valid.")
return
if(autoGenerateList) :
print("autoGenerateList parameter is set to TRUE. Ignoring groupL parameter and instead will automatically generate list of migraiton groups.")
self.groupL = self.findMigrationEligibleGroups()
Expand Down Expand Up @@ -837,7 +844,7 @@ def getSingleFolderList(self, path:str, depth:int) -> dict:
return (path, subFolders, notebooks, files)

for c in resFolderJson['objects']:
if c['object_type']=="DIRECTORY" and c['path'].startswith('/Repos') == False and c['path'].startswith('/Shared') == False and c['path'].endswith('/Trash') == False:
if c['object_type']=="DIRECTORY" and c['path'].startswith('/Shared') == False and c['path'].endswith('/Trash') == False:
subFolders[c['object_id']] = c['path']
elif c['object_type']=="NOTEBOOK" and c['path'].startswith('/Repos') == False and c['path'].startswith('/Shared') == False:
notebooks[c['object_id']] = c['path']
Expand Down Expand Up @@ -1117,27 +1124,22 @@ def updateGroup2Permission(self, object:str, groupPermission : dict, level:str):
dataAcl=[]
for acl in aclList:
try:
if 'user_name' in acl.keys():
if acl['user_name']==self.userName:
addUser=False
gName=acl['group_name']
if gName=="ADMIN" and acl['permission_level']!='CAN_MANAGE':
dataAcl.append({'group_name': gName, 'permission_level': "CAN_MANAGE"})
if level=="Workspace":
if acl['group_name'] in self.WorkspaceGroupNames:
gName="db-temp-"+acl['group_name']
dataAcl.append({'group_name': gName, 'permission_level': acl['permission_level']})
elif level=="Account":
if acl['group_name'] in self.TempGroupNames:
gName=acl['group_name'][8:]
#dataAcl.append({'group_name': gName, 'permission_level': acl['permission_level']})
else:
gName=acl['group_name']
#acl['group_name']=gName
dataAcl.append(acl)
except KeyError:
dataAcl.append(acl)
continue
if addUser:
dataAcl.append({"user_name": self.userName,"permission_level": "CAN_MANAGE"})
data={"access_control_list":dataAcl}
resAppPerm=requests.post(f"{self.workspace_url}/api/2.0/preview/sql/permissions/{object}/{object_id}", headers=self.headers, data=json.dumps(data))
except Exception as e:
Expand Down Expand Up @@ -1187,15 +1189,16 @@ def getDBACL(self, db: str):
try:
aclList=[]
dbdf=self.getGrantsOnObjects(db, "DATABASE", db)
aclList+=dbdf.collect()
if not self.checkAllDB:
userListCollect=dbdf.filter(col('ObjectType')=="DATABASE").filter(array_contains(col('ActionTypes'),"USAGE")).select(col('Principal')).collect()
userListCollect=dbdf.filter(col('ObjectType')=="DATABASE").filter((array_contains(col('ActionTypes'),"USAGE") | array_contains(col('ActionTypes'),"OWN"))).select(col('Principal')).collect()
userList=[ p.Principal for p in userListCollect]
userList=list(set(userList))
if not self.checkPrincipalInGroupOrMember(userList, db):
#print(f'selected groups or members of the groups have no USAGE permission on database level. Skipping object level permission check for database {db}.')
#print(f'selected groups or members of the groups have no USAGE or OWN permission on database level. Skipping object level permission check for database {db}.')
return []

aclList+=dbdf.collect()

tables = self.runVerboseSql("show tables in spark_catalog.{}".format(db)).filter(col("isTemporary") == False)
for table in tables.collect():
try:
Expand All @@ -1221,15 +1224,15 @@ def getDBACL(self, db: str):
def checkPrincipalInGroupOrMember(self, principalList: str, name: str)->bool:
for p in principalList:
if p in self.groupGroupList:
print(f'Group {p} is given USAGE permission for {name}.')
print(f'Group {p} is given USAGE or OWN permission for {name}.')
return True
for p in principalList:
if p in self.groupUserList:
print(f'User {p} is given USAGE permission for {name}.')
print(f'User {p} is given USAGE or OWN permission for {name}.')
return True
for p in principalList:
if p in self.groupSPList:
print(f'SP {p} is given USAGE permission for {name}.')
print(f'SP {p} is given USAGE or OWN permission for {name}.')
return True
return False

Expand All @@ -1247,6 +1250,8 @@ def getTableACLs(self)-> list:
common_df = common_df.unionAll(self.getGrantsOnObjects(None, "ANY FILE", None))
# CATALOG
common_df = common_df.unionAll(self.getGrantsOnObjects(None, "CATALOG", None))
aclList = []
aclList = common_df.collect()
#check if any group is given permission at catalog level
userListCollect=common_df.filter(col('ObjectType')=="CATALOG$").filter(array_contains(col('ActionTypes'),"USAGE")).select(col('Principal')).collect()
userList=[ p.Principal for p in userListCollect]
Expand All @@ -1263,7 +1268,7 @@ def getTableACLs(self)-> list:
#database_names=['aaron_binns','hsdb']
currentCount=0
try:
aclList = []
#aclList = []
aclFinalList = []
with concurrent.futures.ThreadPoolExecutor(max_workers=self.numThreads) as executor:
future_db = [executor.submit(self.getDBACL, f"`{databaseName}`" ) for databaseName in database_names]
Expand All @@ -1280,12 +1285,14 @@ def getTableACLs(self)-> list:

def generate_table_acls_command(self, action_types, object_type, object_key, groupName):
lines = []
grant_privs = [ x for x in action_types if not x.startswith("DENIED_") ]
deny_privs = [ x[len("DENIED_"):] for x in action_types if x.startswith("DENIED_") ]
grant_privs = [ x for x in action_types if not x.startswith("DENIED_") and x != "OWN"]
deny_privs = [ x[len("DENIED_"):] for x in action_types if x.startswith("DENIED_") and x != "OWN"]
if grant_privs:
lines.append(f"GRANT {', '.join(grant_privs)} ON {object_type} {object_key} TO `{groupName}`;")
if deny_privs:
lines.append(f"DENY {', '.join(deny_privs)} ON {object_type} {object_key} TO `{groupName}`;")
if "OWN" in action_types:
lines.append(f"ALTER {object_type} {object_key} OWNER TO `{groupName}`;")
return lines

def updateDataObjectsPermission(self, aclList : list, level:str):
Expand All @@ -1297,7 +1304,16 @@ def updateDataObjectsPermission(self, aclList : list, level:str):
gName="db-temp-"+acl.Principal
elif level=="Account":
gName=acl.Principal[8:]
lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName))
if acl.ObjectType == "ANONYMOUS_FUNCTION":
lines.extend(self.generate_table_acls_command(acl.ActionTypes, 'ANONYMOUS FUNCTION', '', gName))
elif acl.ObjectType == "ANY_FILE":
lines.extend(self.generate_table_acls_command(acl.ActionTypes, 'ANY FILE', '', gName))
elif acl.ObjectType == "CATALOG$":
lines.extend(self.generate_table_acls_command(acl.ActionTypes, 'CATALOG', '', gName))
elif acl.ObjectType in ["DATABASE", "TABLE"]:
# DATABASE, TABLE, VIEW (view's seem to show up as tables)
lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName))
#lines.extend(self.generate_table_acls_command(acl.ActionTypes, acl.ObjectType, acl.ObjectKey, gName))
for aclQuery in lines:
#print(aclQuery)
self.runVerboseSql(aclQuery)
Expand Down

0 comments on commit bd12648

Please sign in to comment.