from tvm import te
from tvm.relay import transform
from tvm.contrib import cudnn
+from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import is_op, wildcard
from .te_target import lower_composite, relay_to_runtime
@@ -50,6 +51,8 @@ def partition_for_cudnn(
    tvm.IRModule
        The partitioned module.
    """
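+    # Bind params so weights appear as constants rather than free variables.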
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)

    seq = tvm.transform.Sequential(
        [
@@ -71,19 +74,78 @@ def softmax_pattern() -> relay.Pattern:
        """Create pattern for softmax."""
        return is_op("nn.softmax")(wildcard())

+    def log_softmax_pattern() -> relay.Pattern:
+        """Create pattern for log_softmax."""
+        return is_op("nn.log_softmax")(wildcard())
+
+    def conv2d_pattern() -> relay.Pattern:
+        """Create pattern for conv2d."""
+        return is_op("nn.conv2d")(wildcard(), wildcard())
+
    def check_softmax(matched: relay.Call) -> bool:
        """Check if softmax is supported by cuDNN."""
        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
            return False

        return True

+    def check_log_softmax(matched: relay.Call) -> bool:
+        """Check if log_softmax is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        # Only 2-D inputs reduced over the last axis are offloaded.
+        if len(matched.args[0].checked_type.shape) != 2:
+            return False
+
+        if matched.attrs["axis"] not in (1, -1):
+            return False
+
+        return True
+
+    def check_conv2d(matched: relay.Call) -> bool:
+        """Check if conv2d is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        if matched.attrs["data_layout"] != "NCHW" or matched.attrs["kernel_layout"] != "OIHW":
+            return False
+
+        # Relay padding is (top, left, bottom, right); cuDNN supports only
+        # symmetric padding, so top/bottom and left/right must match.
+        padding = matched.attrs["padding"]
+        if padding[0] != padding[2] or padding[1] != padding[3]:
+            return False
+
+        return True
+
    return [
        ("cudnn.softmax", softmax_pattern(), check_softmax),
+        ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
+        ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
    ]


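+# Each @lower_composite rule below maps a matched composite function directly
+# onto the corresponding tvm.contrib.cudnn call at TE lowering time.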
@lower_composite("cudnn.softmax")
def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
    """Lower a softmax using cuDNN."""
    return cudnn.softmax(inputs[0], axis=op.attrs["axis"])
+
+
+@lower_composite("cudnn.log_softmax")
+def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a log_softmax using cuDNN."""
+    return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])
+
+
+@lower_composite("cudnn.conv2d")
+def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a conv2d using cuDNN."""
+    return cudnn.conv_forward(
+        inputs[0],
+        inputs[1],
+        pad=op.attrs["padding"],
+        stride=op.attrs["strides"],
+        dilation=op.attrs["dilation"],
+        conv_mode=1,  # CUDNN_CROSS_CORRELATION
+        tensor_format=0,  # CUDNN_TENSOR_NCHW
+        algo=1,  # CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+        conv_dtype=op.checked_type.dtype,
+        groups=op.attrs["groups"],
+    )
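
As a quick sanity check of the new coverage, here is a minimal usage sketch (not taken from this commit; it assumes the module lives at tvm.relay.op.contrib.cudnn, that the signature is partition_for_cudnn(mod, params=None), and a TVM build with CUDA and cuDNN enabled):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.cudnn import partition_for_cudnn

# A single NCHW/OIHW conv2d with symmetric padding, so check_conv2d passes.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(
    data, weight, padding=(1, 1, 1, 1), data_layout="NCHW", kernel_layout="OIHW"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
params = {"weight": np.random.rand(8, 3, 3, 3).astype("float32")}

# The weight is bound as a constant, then the conv2d should be partitioned
# into a "cudnn.conv2d" composite and built for the CUDA target.
mod = partition_for_cudnn(mod, params)
lib = relay.build(mod, target="cuda")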