Skip to content

Commit

Permalink
Fix (#2399)
Browse files Browse the repository at this point in the history
Co-authored-by: hzhou245 <[email protected]>
  • Loading branch information
zhr1201 and hzhou245 authored Mar 13, 2024
1 parent 696d161 commit c8084ef
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions examples/aishell/whisper/local/modify_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@

def main():
parser = argparse.ArgumentParser(description='filter out unused module')
parser.add_argument('--remove_list',
default="",
type=str,
help='list of name filter, comma-separated, e.g."name1, name2"')
parser.add_argument('--add_list',
default="",
type=str,
help='dict of name adder, e.g."{\"key1\": \"value1\"}"')
parser.add_argument(
'--remove_list',
default="",
type=str,
help='list of name filter, comma-separated, e.g."name1, name2"')
parser.add_argument(
'--add_list',
default="",
type=str,
help='dict of name adder, e.g."{\"key1\": \"value1\"}"')
parser.add_argument('--input_ckpt',
required=True,
type=str,
Expand All @@ -39,18 +41,22 @@ def main():

state = torch.load(args.input_ckpt, map_location="cpu")
new_state = {}

if args.remove_list:
remove_list = args.remove_list.split(',')
for k in state.keys():
found = False
for prefix in remove_list:
if prefix in k:
print("skip {}".format(k))
found = True
break
if found:
continue
new_state[k] = state[k]
else:
remove_list = []

for k in state.keys():
found = False
for prefix in remove_list:
if prefix in k:
print("skip {}".format(k))
found = True
break
if found:
continue
new_state[k] = state[k]

if args.add_list:
add_list = json.loads(args.add_list)
Expand Down

0 comments on commit c8084ef

Please sign in to comment.