
Conversation

LeiWang1999

  1. First, this PR optimizes the mask generator:

For m:n 1-D sparsity, we first need to generate the list of candidate masks. The original implementation builds the list by deduplicating permutations with a set:

""" return all possible m:n patterns in a 1d vector """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m,n):
    # Early exit if patterns was already created.
    global valid_m4n2_1d_patterns

    if m==4  and n==2 and valid_m4n2_1d_patterns  is not None: return valid_m4n2_1d_patterns
    patterns = torch.zeros(m)
    patterns[:n] = 1
    valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
    if m == 4  and n == 2: valid_m4n2_1d_patterns  = valid_patterns       
    return valid_patterns

However, the permutation approach enumerates m! candidates. When I take m = 16, we would need to generate 16! = 20,922,789,888,000 candidates first, which is unreachable. So I redesigned the algorithm to enumerate only the valid index combinations, reducing the candidate count from m! to C(m, n), so we can compute the best mask for n:m sparsity with a large m.
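The size of that reduction can be checked directly with standard-library arithmetic (a quick sanity check, not part of the PR):

```python
import math

# Candidates the permutation-based generator must enumerate for m = 16.
perm_candidates = math.factorial(16)

# Candidates the combination-based generator enumerates for 2:16 sparsity.
comb_candidates = math.comb(16, 2)

print(perm_candidates)  # 20922789888000
print(comb_candidates)  # 120
```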

""" return all possible m:n patterns in a 1d vector """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m,n):
    # Early exit if patterns was already created.
    global valid_m4n2_1d_patterns

    if m==4  and n==2 and valid_m4n2_1d_patterns  is not None: return valid_m4n2_1d_patterns
    valid_patterns = []
    for i in list(combinations(range(0, m), n)):
        cur_pattern = np.zeros(m, dtype=np.int32)
        cur_pattern[list(i)] = 1
        valid_patterns.append(cur_pattern)
    valid_patterns = torch.Tensor(valid_patterns)
    if m == 4  and n == 2: valid_m4n2_1d_patterns  = valid_patterns       
    return valid_patterns
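As a sanity check (not part of the PR), the two enumeration strategies can be compared for small m using only the standard library; both should produce the same set of 0/1 patterns:

```python
from itertools import combinations, permutations

def patterns_via_permutations(m, n):
    # Old approach: deduplicate all m! permutations of a base pattern.
    base = [1] * n + [0] * (m - n)
    return set(permutations(base))

def patterns_via_combinations(m, n):
    # New approach: place ones at each of the C(m, n) index sets.
    result = set()
    for idx in combinations(range(m), n):
        pattern = [0] * m
        for i in idx:
            pattern[i] = 1
        result.add(tuple(pattern))
    return result

# Both strategies yield identical pattern sets for small m.
for m, n in [(4, 2), (6, 3), (8, 2)]:
    assert patterns_via_permutations(m, n) == patterns_via_combinations(m, n)
```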
  2. Second, this PR fixes a bug:
""" m:n 2d structured pruning: greedy method to select mask """
def mn_2d_greedy(matrix, m, n):
    # Convert to numpy
    mat = matrix.cpu().detach().numpy()
    mask = np.ones(mat.shape, dtype=int)

    rowCount = int(mat.shape[0]/m) * m
    colCount = int(mat.shape[1]/m) * m
    for rowStartIdx in range(0, rowCount, m):
        rowEndIdx = rowStartIdx + m
        for colStartIdx in range(0, colCount, m):
            colEndIdx = colStartIdx + m
            matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]))
            maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])
            maskSub.fill(0.0)
            matrixVecView = matrixSub.reshape(-1)
            maskVecView   = maskSub.reshape(-1)
            linearIdx = np.argsort(matrixVecView)
            matrixIdx = [(int(x/m), x % m) for x in linearIdx]
            rowCounter = collections.Counter()
            colCounter = collections.Counter()
            for currIdx in range(len(linearIdx) - 1, -1, -1):
                currMatrixEntry = matrixIdx[currIdx]
                if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n):
                    continue
                #end if
                maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0
                rowCounter[currMatrixEntry[0]] += 1
                colCounter[currMatrixEntry[1]] += 1

    return torch.tensor(mask.cuda())

The last line should be `torch.tensor(mask).cuda()`; otherwise an exception is thrown (an ndarray has no attribute named `cuda`).
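A minimal illustration of the call-order issue (assuming numpy is available; the `.cuda()` call itself additionally requires torch and a GPU, so it is only shown as a comment):

```python
import numpy as np

mask = np.ones((4, 4), dtype=int)

# An ndarray has no `cuda` attribute, so mask.cuda() raises AttributeError.
assert not hasattr(mask, "cuda")

# The fix converts to a torch tensor first, then moves it to the GPU:
#   torch.tensor(mask).cuda()
```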

@crcrpar
Collaborator

crcrpar commented Oct 6, 2022

cc: @ChongyuNVIDIA @jpool-nv

@jpool-nv
Contributor

Thanks for submitting this, @LeiWang1999 .

I have one small modification for the first issue. It probably doesn't matter, but I'll be more comfortable with your (much faster!) version if the output it gives matches the old, slower version exactly. This can be done with a small change to your code:

valid_patterns = torch.Tensor(valid_patterns)
# becomes
valid_patterns = torch.Tensor(list(set([tuple(vp) for vp in valid_patterns])))

This takes the list of valid_patterns, creates a set of tuples (to "sort" the entries into the same order as the old version), and gets a Tensor from that list.

The results are clearly superior to the existing version. With a small test app:

if __name__ == "__main__":
    for m in range(4,18,2):
        for n in range(1,m):
            start_time = time.clock_gettime_ns(time.CLOCK_REALTIME)
            old_results = compute_valid_1d_patterns(m,n)
            mid_time = time.clock_gettime_ns(time.CLOCK_REALTIME)
            new_results = compute_valid_1d_patterns_new(m,n)
            end_time = time.clock_gettime_ns(time.CLOCK_REALTIME)
            old_time_ms = (mid_time - start_time) / 1000000.
            new_time_ms = (end_time - mid_time) / 1000000.

            all_match = all(old_results.flatten() == new_results.flatten())
            print(f"{old_time_ms:.2f}ms + {new_time_ms:.2f}ms to get {len(old_results)} {n}:{m} equal patterns? {all_match}")

We can see the results all match before I give up waiting for the old version to compute the larger m results :)

0.27ms + 0.05ms to get 4 1:4 equal patterns? True
0.06ms + 0.05ms to get 6 2:4 equal patterns? True
0.03ms + 0.04ms to get 4 3:4 equal patterns? True
0.20ms + 0.04ms to get 6 1:6 equal patterns? True
0.21ms + 0.08ms to get 15 2:6 equal patterns? True
0.24ms + 0.10ms to get 20 3:6 equal patterns? True
0.23ms + 0.08ms to get 15 4:6 equal patterns? True
0.21ms + 0.04ms to get 6 5:6 equal patterns? True
9.31ms + 0.04ms to get 8 1:8 equal patterns? True
9.76ms + 0.12ms to get 28 2:8 equal patterns? True
10.00ms + 0.21ms to get 56 3:8 equal patterns? True
10.20ms + 0.26ms to get 70 4:8 equal patterns? True
10.12ms + 0.22ms to get 56 5:8 equal patterns? True
10.36ms + 0.13ms to get 28 6:8 equal patterns? True
9.64ms + 0.05ms to get 8 7:8 equal patterns? True
971.58ms + 0.06ms to get 10 1:10 equal patterns? True
996.87ms + 0.18ms to get 45 2:10 equal patterns? True
1029.76ms + 0.45ms to get 120 3:10 equal patterns? True
1047.53ms + 0.77ms to get 210 4:10 equal patterns? True
1064.14ms + 1.03ms to get 252 5:10 equal patterns? True
1061.38ms + 0.89ms to get 210 6:10 equal patterns? True
1065.26ms + 0.49ms to get 120 7:10 equal patterns? True
1038.28ms + 0.20ms to get 45 8:10 equal patterns? True
1025.07ms + 0.07ms to get 10 9:10 equal patterns? True
150627.41ms + 0.10ms to get 12 1:12 equal patterns? True

If you expect to re-use the patterns regularly, it'd make sense to cache them into a generalized valid_1d_patterns dict with (m,n) tuples as keys (similar to the existing valid_m4n2_1d_patterns), but that's not necessary for this PR, especially given the new performance.
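A sketch of that caching idea, with a dict keyed by `(m, n)` tuples as suggested (hypothetical names, standard library only; the real helper would return a torch Tensor as above):

```python
from itertools import combinations

# Hypothetical generalized cache keyed by (m, n), replacing the
# single-purpose valid_m4n2_1d_patterns global.
valid_1d_patterns = {}

def cached_valid_1d_patterns(m, n):
    key = (m, n)
    if key not in valid_1d_patterns:
        patterns = []
        for idx in combinations(range(m), n):
            pattern = [0] * m
            for i in idx:
                pattern[i] = 1
            patterns.append(tuple(pattern))
        valid_1d_patterns[key] = patterns
    return valid_1d_patterns[key]

# Repeated lookups for the same (m, n) reuse the cached list.
first = cached_valid_1d_patterns(4, 2)
second = cached_valid_1d_patterns(4, 2)
assert first is second
```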

@LeiWang1999
Author

Hi, actually I have a deeper modification coming for this PR. We are preparing a research work that needs n:m sparsity where m can be 32 * 64 in the worst case, and even the C(m, n) candidate count is intolerable there. We have optimized the code to provide a more efficient way to generate the mask with n^2 complexity, and we will also provide vector-wise and block-wise sparsity patterns. So this PR should probably not be merged for now; I will update it after our paper is published.

Thanks.
