@@ -147,6 +147,8 @@ def setUpClass(cls):
147147 "I" : np .random .rand (2 , 2 ),
148148 "J" : np .random .rand (1 , 3 , 5 ),
149149 "K" : np .random .rand (1 , 2 , 3 , 4 ),
150+ "L" : np .random .rand (2 , 0 , 13 ),
151+ "M" : np .random .rand (13 ),
150152 }
151153
152154 def _get_place (self , force_to_use_cpu = False ):
@@ -320,6 +322,42 @@ def setUp(self):
320322 self .sample = {"paradigm" : "blq,bhlk->bhlqk" , "data" : ["J" , "K" ]}
321323
322324
class TestEinsumZeroSizeTensor(TestEinsum):
    """Einsum tests exercising a zero-size operand.

    Sample "L" has shape (2, 0, 13) — a 0-length dimension — and "M" has
    shape (13,); the equation "...i, ...i" contracts their trailing axis.
    """

    def setUp(self):
        self.sample = {"paradigm": "...i, ...i", "data": ["L", "M"]}

    def test_backward(self):
        """Check forward output against numpy and that backward yields
        gradients whose shapes match the inputs, on both the default
        place and the CPU place.
        """
        operands = [
            TestEinsum.TEST_SAMPLES[operand] for operand in self.sample["data"]
        ]
        expected_result = np.einsum(self.sample["paradigm"], *operands)
        equation = self.sample["paradigm"]

        # Run the identical forward + backward check on the default place
        # first, then forced CPU — previously duplicated verbatim.
        for force_to_use_cpu in (False, True):
            with paddle.base.dygraph.guard(
                self._get_place(force_to_use_cpu=force_to_use_cpu)
            ):
                pd_operands = [
                    paddle.to_tensor(operand, stop_gradient=False)
                    for operand in operands
                ]
                result = paddle.einsum(equation, *pd_operands)
                self.check_output_equal(result.numpy(), expected_result)
                loss = result.sum()
                loss.backward()
                for x in pd_operands:
                    # Even for the zero-size operand, the gradient must
                    # exist and have the same shape as the input.
                    np.testing.assert_allclose(x.grad.shape, x.shape)
360+
323361class TestNumpyTests (unittest .TestCase ):
324362 def setUp (self ):
325363 pass
0 commit comments