Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions aiter/tuned_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
hipb_create_extension,
hipb_mm,
logger,
rocb_create_extension,
rocb_mm,
)
from aiter.jit.core import AITER_CONFIG_GEMM_BF16_FILE, AITER_LOG_TUNED_CONFIG
from aiter.jit.utils.chip_info import get_cu_num
Expand All @@ -41,7 +39,7 @@
this_dir = os.path.dirname(os.path.abspath(__file__))


solMap = ["torch", "hipblaslt", "rocblas", "skinny", "asm"]
solMap = ["torch", "hipblaslt", "skinny", "asm"]


def get_solfunc(soltype: int):
Expand All @@ -50,10 +48,8 @@ def get_solfunc(soltype: int):
elif soltype == 1:
return hipb_gemm
elif soltype == 2:
return rocb_gemm
elif soltype == 3:
return skinny_gemm
elif soltype == 4:
elif soltype == 3:
return asm_gemm


Expand Down Expand Up @@ -232,25 +228,6 @@ def hipb_gemm(
return hipb_mm(inp, weights.t(), solidx, bias, otype, scale_a, scale_b, scale_c)


def rocb_gemm(
    inp: Tensor,
    weights: Tensor,
    solidx: int,
    bias: Optional[Tensor] = None,
    otype: Optional[torch.dtype] = None,
    scale_a: Optional[Tensor] = None,
    scale_b: Optional[Tensor] = None,
    scale_c: Optional[Tensor] = None,
):
    """Run a GEMM through ``rocb_mm`` using the solution index ``solidx``.

    The weights are passed to ``rocb_mm`` transposed (``weights.t()``). When
    ``bias`` is given it is added to the result afterwards.

    Args:
        inp: Left-hand operand forwarded to ``rocb_mm``.
        weights: Right-hand operand; transposed before the call.
        solidx: Solution index forwarded to ``rocb_mm``.
        bias: Optional tensor added to the product.
        otype: Accepted for signature parity with the other ``*_gemm``
            helpers; not used by this function.
        scale_a: Must be ``None``.
        scale_b: Must be ``None``.
        scale_c: Must be ``None``.

    Returns:
        The result of ``rocb_mm``, plus ``bias`` when one is supplied.

    Raises:
        AssertionError: If any of ``scale_a``, ``scale_b`` or ``scale_c`` is
            not ``None``.
    """
    # Raised explicitly instead of via an ``assert`` statement: asserts are
    # stripped under ``python -O``, which would let scales be silently ignored.
    # The exception type and message are kept so existing callers see no change.
    if scale_a is not None or scale_b is not None or scale_c is not None:
        raise AssertionError("scale_a, scale_b, scale_c must be None for rocblas")
    out = rocb_mm(inp, weights.t(), solidx)
    if bias is not None:
        out = out + bias
    return out


def torch_gemm(
inp: Tensor,
weights: Tensor,
Expand Down Expand Up @@ -324,7 +301,6 @@ def mm(
scale_c: Optional[Tensor] = None,
):
if self.extensions_created == False:
rocb_create_extension()
hipb_create_extension()
self.extensions_created = True
out = gemm_a16w16(
Expand Down
144 changes: 8 additions & 136 deletions gradlib/gradlib/GemmTuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from aiter.jit.core import AITER_CONFIG_GEMM_BF16_FILE, get_asm_dir
from aiter.utility.base_tuner import GemmCommonTuner

aiter.rocb_create_extension()
aiter.hipb_create_extension()


Expand All @@ -41,11 +40,6 @@ def init_hipblas():
aiter.hipb_create_extension()


@lru_cache(maxsize=1)
def init_rocblas():
    """Create the rocBLAS extension by calling ``aiter.rocb_create_extension``.

    The function takes no arguments and is wrapped in ``lru_cache``, so once a
    call has completed, later calls are served from the cache and do not
    invoke ``aiter.rocb_create_extension`` again.
    """
    aiter.rocb_create_extension()


def call_hipb_mm(input, weight, bias, scale_a, scale_b, solidx, out_dtype):
init_hipblas()
return aiter.hipb_mm(
Expand All @@ -59,11 +53,6 @@ def call_hipb_mm(input, weight, bias, scale_a, scale_b, solidx, out_dtype):
)


def call_rocb_mm(inp, w, solidx):
    """Ensure the rocBLAS extension exists, then run ``aiter.rocb_mm``.

    Args:
        inp: First operand, forwarded unchanged.
        w: Second operand, forwarded unchanged.
        solidx: Solution index, forwarded unchanged.

    Returns:
        Whatever ``aiter.rocb_mm(inp, w, solidx)`` returns.
    """
    init_rocblas()
    product = aiter.rocb_mm(inp, w, solidx)
    return product


def run_gemm_bf16_asm(inp, w, out, bias=None, splitK=None, kernelName=None):
return aiter.gemm_a16w16_asm(
inp, w, out, bias=bias, splitK=splitK, kernelName=kernelName
Expand Down Expand Up @@ -142,7 +131,6 @@ def __init__(
indtype,
outdtype,
scaleAB=False,
rocblas_decode=False,
mp=1,
err_ratio=0.01,
profile_file="",
Expand All @@ -156,15 +144,13 @@ def __init__(
self.indtype = indtype
self.outdtype = outdtype
self.scaleAB = scaleAB
self.use_rocblas = indtype == outdtype and str(indtype) != "dtypes.fp8"
self.nb = CACHE_INVALIDATE_BUFFERS
(self.inp, self.weights, _, self.bias, _, scaleA) = generate_data(
m, n, k, indtype, outdtype, scaleAB, 0, bias
)
self.blob = torch.ones(128 * 1024 * 1024, dtype=dtypes.fp32, device="cuda")
self.topn = 20 # number of top solutions from each source
self.hipb_sols = []
self.rocb_sols = []
self.rtol = 1e-2
self.atol = 1e-2
# self.ref = self.get_gemm_ref()
Expand All @@ -176,7 +162,6 @@ def __init__(
# prefer hipblaslt unless rocblas time is less than this
# ratio of hipblaslt time
self.hipb_prefer_ratio = 0.995
self.rocblas_decode = rocblas_decode
self.mp = mp
# self.inbpe = self.inp.element_size()
# self.outbpe = self.ref.element_size()
Expand Down Expand Up @@ -416,111 +401,24 @@ def save_topn_result(self, rets, fast_mode, libtype):
print(gtimedf.head(self.topn), flush=True)
return gtimedf

def find_rocblas_sols(self):
    """Collect the rocBLAS solutions for this problem into ``self.rocb_sols``.

    No lookup is done when ``self.scaleAB`` is truthy or a bias is present;
    the solution list is then empty. Otherwise the list comes from
    ``aiter.rocb_findallsols`` on the input and the transposed weights. The
    problem size, dtypes and number of solutions are printed in both cases.
    """
    skip_lookup = self.scaleAB or self.bias is not None
    found = [] if skip_lookup else aiter.rocb_findallsols(self.inp, self.weights.t())
    print(
        "M N K dtype",
        self.m,
        self.n,
        self.k,
        self.indtype,
        self.outdtype,
        ">>> Total rocb solutions",
        len(found),
        flush=True,
    )
    self.rocb_sols = found

def rocb_time_all_sols(self, fast_mode=0, top_sols=0):
    """Build one tuning task per rocBLAS solution and run them via ``mp_tuner``.

    Args:
        fast_mode: When truthy, use 2 timed iterations and 5 warm-up
            iterations instead of 20 and 20. When equal to ``1``, the results
            are also passed to ``self.save_topn_result`` and stored on
            ``self.rocb_gtimedf``, and an empty list is returned.
        top_sols: When truthy, iterate ``self.rocb_top_sols`` instead of
            ``self.rocb_sols``.

    Returns:
        ``[]`` when ``fast_mode == 1``; otherwise the list returned by
        ``mp_tuner`` (or ``[]`` when there are no solutions to time).
    """
    num_iters = 20
    num_warmup = 20
    if fast_mode:
        num_iters = 2
        num_warmup = 5

    solutions = self.rocb_sols
    if top_sols:
        solutions = self.rocb_top_sols

    # These tuples do not depend on the solution index, so build them once
    # rather than once per task.
    shape_key = (
        self.m,
        self.n,
        self.k,
        False,
        str(self.indtype),
        str(self.outdtype),
        False,
    )
    data_args = (self.m, self.n, self.k, self.indtype, self.outdtype, False)

    task = []
    for solidx in solutions:
        info = (shape_key, solidx, 0, "rocblas", "rocblas")
        task.append(
            (
                info,
                generate_data,
                data_args,
                call_rocb_mm,
                # NOTE(review): presumably [0, 2] selects which generate_data
                # outputs are fed to call_rocb_mm; confirm against mp_tuner.
                (
                    [0, 2],
                    solidx,
                ),
                {
                    "num_warmup": num_warmup,
                    "num_iters": num_iters,
                },
                # A reference function is supplied only outside fast mode.
                get_gemm_ref if fast_mode == 0 else None,
                ([0, 1, 3, 4], self.indtype, self.outdtype),
                {},
                None,  # self.ref if fast_mode == 0 else None,
                self.rtol,
                self.atol,
            )
        )

    if task:
        in_data = [(len(solutions), ())]
        ret = mp_tuner(task, in_data, self.mp, fast_mode == 1)
    else:
        ret = []

    if fast_mode == 1:
        self.rocb_gtimedf = self.save_topn_result(ret, fast_mode, "rocblas")
        return []
    return ret

def warmup(self, warmi=500):
    """Rebind ``self.blob`` to ``self.blob + 0.00001`` ``warmi`` times.

    Args:
        warmi: Number of additions to perform (default 500).
    """
    for _ in range(warmi):
        self.blob = self.blob + 0.00001

def functional_get_topn_fastest(self):
    """Store the first ``self.topn`` solution indices of each timing table.

    Takes the leading ``self.topn`` entries of the ``"solidx"`` column of
    ``self.rocb_gtimedf`` and ``self.hipb_gtimedf`` and stores them as plain
    lists on ``self.rocb_top_sols`` and ``self.hipb_top_sols``.
    """
    self.rocb_top_sols = self.rocb_gtimedf["solidx"].head(self.topn).tolist()
    self.hipb_top_sols = self.hipb_gtimedf["solidx"].head(self.topn).tolist()

def run_fast_solutions(self):
if self.use_rocblas:
self.find_rocblas_sols()
if not (self.rocblas_decode and self.m == 1):
self.find_hipblas_sols()
self.warmup()
rets_rocb_fast = self.rocb_time_all_sols(fast_mode=1)
self.find_hipblas_sols()
self.warmup()
rets_hipb_fast = self.hipb_time_all_sols(fast_mode=1)

def run_best_solutions(self):
self.warmup()
rets_rocb = self.rocb_time_all_sols(fast_mode=0, top_sols=1)
self.warmup()
rets_hipb = self.hipb_time_all_sols(fast_mode=0, top_sols=1)
rets_asm = self.asm_gemm_all_solutions()
return rets_rocb + rets_hipb + rets_asm
return rets_hipb + rets_asm

def run_solutions(self):
self.run_fast_solutions()
Expand All @@ -539,17 +437,6 @@ def cleanup(self):
cpu_blob = self.blob.cpu()
del cpu_blob

def cleanup(self):
    """Drop the tensor attributes held by this instance, where present.

    ``inp`` and ``weights`` are deleted if they exist; ``bias`` is deleted only
    if it exists and is not ``None``. For ``blob``, a ``.cpu()`` copy is made
    and immediately discarded; ``self.blob`` itself is left in place.
    """
    for attr in ("inp", "weights"):
        if hasattr(self, attr):
            delattr(self, attr)
    if getattr(self, "bias", None) is not None:
        del self.bias
    if hasattr(self, "blob"):
        # NOTE(review): only the temporary host copy is released here, not
        # self.blob; presumably the copy exists to force a device sync. Confirm.
        host_copy = self.blob.cpu()
        del host_copy


class GemmTuner(GemmCommonTuner):
ARG_DEFAULTS = {
Expand Down Expand Up @@ -589,12 +476,6 @@ def _setup_specific_arguments(self):
help="dtype: f32 f16 bf16 fp8. Use to override the default value,"
" which is the same as indtype for each shape (see --indtype.)",
)
self.parser.add_argument(
"--rocblas-decode",
action="store_true",
default=False,
help="forces rocblas solution on decode N=1",
)

self.parser.add_argument(
"--all_bias",
Expand Down Expand Up @@ -748,7 +629,6 @@ def tune(self, untunedf, tunedf, args):
indtype=eval(indtype),
outdtype=eval(outdtype),
scaleAB=ds["scaleAB"],
rocblas_decode=args.rocblas_decode,
mp=args.mp,
err_ratio=args.errRatio,
profile_file=args.profile_file,
Expand Down Expand Up @@ -822,27 +702,19 @@ def post_process(self, rets, args, topk=-1, fast_mode=False):
best_gtimedfs = pd.DataFrame(columns=self.columns)
for key, df in gtimedf_dic.items():
gtimedf_dic[key] = df[df["err_ratio"] < args.errRatio]
gtimedf_dic[key]["gtimems"] = np.where(
df["libtype"] == "rocblas", df["us"], df["us"] * self.hipb_prefer_ratio
)
# get best solution
best_gtimedf = gtimedf_dic[key].sort_values(by="gtimems")
best_gtimedf = gtimedf_dic[key].sort_values(by="us")

if len(gtimedf_dic[key]) == 0:
print(">>> No rocblas or hipblas or asm solutions found!", flush=True)
print(">>> No hipblas or asm solutions found!", flush=True)
continue
robs_gtimedf = gtimedf_dic[key][gtimedf_dic[key]["libtype"] == "rocblas"]
asm_gtimedf = gtimedf_dic[key][gtimedf_dic[key]["libtype"] == "asm"]
hibs_gtimedf = gtimedf_dic[key][gtimedf_dic[key]["libtype"] == "hipblaslt"]
if len(robs_gtimedf) == 0 and len(hibs_gtimedf) == 0:
if len(hibs_gtimedf) == 0:
print(">>>Only asm solutions found!", flush=True)
elif len(robs_gtimedf) == 0:
print(">>> Only hipblas or asm solutions found!", flush=True)
elif len(hibs_gtimedf) == 0 and len(asm_gtimedf) == 0:
print(">>> Only rocblas solutions found!", flush=True)
resultdf1 = (
best_gtimedf.head(1).drop(["gtimems"], axis=1).reset_index(drop=True)
)
elif len(asm_gtimedf) == 0:
print(">>> no hipblas solutions found!", flush=True)
resultdf1 = best_gtimedf.head(1).reset_index(drop=True)
kernal_name = (
aiter.getHipblasltKernelName(int(resultdf1.iloc[0]["solidx"]))
if resultdf1.iloc[0]["libtype"] == "hipblaslt"
Expand Down
1 change: 0 additions & 1 deletion gradlib/gradlib/gemm_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import multiprocessing as mp
import gc

aiter.rocb_create_extension()
aiter.hipb_create_extension()


Expand Down