#1772: Added support for remaining ttnn pipeline pybinds and refactored naming convention for ttnn pipeline passes #1777

Open
wants to merge 1 commit into main

Conversation

tapspatel
Contributor

tapspatel commented Jan 14, 2025

This allows the user to write custom MLIR modules through ttir_builder (or by hand) and run individual passes from Python.

Sample use:

def run_ttir_to_ttir_decomposition_pass(module, dump_module = False):
    ttir_to_ttir_decomposition_pass(module)

    if dump_module:
        print(module)

def run_ttir_load_system_desc(module, dump_module = False):
    ttir_load_system_desc(module)

    if dump_module:
        print(module)

def run_ttir_implicit_device(module, dump_module = False):
    ttir_implicit_device(module)

    if dump_module:
        print(module)

def run_ttnn_layout(module, dump_module = False):
    ttnn_layout(module)

    if dump_module:
        print(module)

def run_convert_ttir_to_ttnn_pass(module, dump_module = False):
    convert_ttir_to_ttnn_pass(module)

    if dump_module:
        print(module)

def run_remove_dead_values_pass(module, dump_module = False):
    remove_dead_values_pass(module)

    if dump_module:
        print(module)

def run_ttnn_workarounds(module, dump_module = False):
    ttnn_workarounds(module)

    if dump_module:
        print(module)

def run_canonicalizer_pass(module, dump_module = False):
    canonicalizer_pass(module)

    if dump_module:
        print(module)

def run_ttnn_decompose_layouts(module, dump_module = False):
    ttnn_decompose_layouts(module)

    if dump_module:
        print(module)

def run_ttnn_deallocate(module, dump_module = False):
    ttnn_deallocate(module)

    if dump_module:
        print(module)
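
# Each run_* wrapper above applies one pybound pass to the module in place and
# optionally prints the module so the intermediate IR can be inspected between
# passes. This snippet assumes the ttmlir Python bindings, the pybound pass
# entry points used above, torch, and inspect have already been imported.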


def test_relu_decomp():
    function_name = inspect.currentframe().f_code.co_name
    golden_map = {}

    with Context() as ctx, Location.name(f"{function_name}"):
        parent_location = Location.name(f"{function_name}")
        tt.register_dialect(ctx)
        ttir.register_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):

            input_shape_list = [(128, 128)]

            input_operands = []
            for shape in input_shape_list:
                input_operands.append(RankedTensorType.get(shape, F32Type.get()))

            golden_inputs = []
            for shape in input_shape_list:
                golden_inputs.append(torch.randn(shape, dtype=torch.float32))

            @func.func(*input_operands, name=f"{function_name}")
            def relu(inputs):

                ttir_op_res, golden_dict = create_relu(
                    inputs, [(128, 128)], golden_inputs
                )
                golden_map[golden_dict["location"]] = golden_dict["golden_output"]
                return ttir_op_res

        run_ttir_to_ttir_decomposition_pass(module, False)
        run_ttir_load_system_desc(module, False)
        run_ttir_implicit_device(module, False)
        run_ttnn_layout(module, True)
        run_convert_ttir_to_ttnn_pass(module, False)
        run_remove_dead_values_pass(module, False)
        run_ttnn_workarounds(module, False)
        run_canonicalizer_pass(module, False)
        run_ttnn_decompose_layouts(module, False)
        run_ttnn_deallocate(module, False)
        print(module)


def get_type(input: Operand):
    # Resolve the underlying type of an operand, whether it is given as a
    # Value, an OpView, or an Operation.
    if isinstance(input, Value):
        operand_type = input.type
    elif isinstance(input, OpView):
        operand_type = input.operation.result.type
    elif isinstance(input, Operation):
        operand_type = input.result.type
    else:
        raise TypeError(f"Invalid input {type(input)}")

    assert isinstance(operand_type, RankedTensorType), "Only ranked tensors are supported"

    return operand_type


def output_type_operands(output_shape_list):
    output_operands = []
    output_type_list = []

    for shape in output_shape_list:
        output_operands.append(create_empty(shape))

    for operand in output_operands:
        output_type_list.append(get_type(operand))

    return output_operands, output_type_list


empty_index = 0


def create_empty(shape):
    global empty_index

    res = tensor.EmptyOp(
        shape, F32Type.get(), loc=Location.name(f"empty_{empty_index}")
    )
    empty_index += 1
    return res

relu_index = 0


def create_relu(inputs, output_shape_list, golden_inputs):
    global relu_index
    location = f"relu_{relu_index}"

    relu_output_operands, relu_output_type_list = output_type_operands(
        output_shape_list
    )
    res = ttir.ReluOp(
        relu_output_type_list,
        [inputs],
        relu_output_operands,
        loc=Location.name(location),
    )
    relu_index += 1

    golden_output = torch.relu(*golden_inputs)

    return res, {"location": location, "golden_output": golden_output}
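
The per-pass wrappers can also be driven from a single ordered list, which keeps test bodies compact. A minimal sketch using only the helpers defined above (the names TTNN_PIPELINE and run_ttnn_pipeline are illustrative, not part of this PR):

TTNN_PIPELINE = [
    run_ttir_to_ttir_decomposition_pass,
    run_ttir_load_system_desc,
    run_ttir_implicit_device,
    run_ttnn_layout,
    run_convert_ttir_to_ttnn_pass,
    run_remove_dead_values_pass,
    run_ttnn_workarounds,
    run_canonicalizer_pass,
    run_ttnn_decompose_layouts,
    run_ttnn_deallocate,
]


def run_ttnn_pipeline(module, dump_module=False):
    # Apply each wrapper in order, mirroring the pass sequence in test_relu_decomp.
    for run_pass in TTNN_PIPELINE:
        run_pass(module, dump_module)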

@vprajapati-tt
Contributor

vprajapati-tt commented Jan 14, 2025

I think the PR completes the task it was assigned, but it seems quite verbose to me. I know this is a bit of an overarching goal, but in the future would it be possible to call the passes using string options (similar to ttmlir-opt)? We would hold a reference of pass IDs and invoke the compiler as:

from ttmlir import passes

...
module = <...>

passes.pipeline(module, from=passes.TTIR_DECOMPOSE, to=passes.TTNN_LAYOUT)

<or>

passes.pipeline(module, passes=[passes.TTIR_DECOMPOSE, ...])

These literals would store the passes' string representations, and the pass manager would then be invoked from these strings to construct the intended pipeline. Let me know what you think. This is just to get your thoughts; I think this change is probably needed for bring-up, so we should definitely prioritize it and look into it in the future.
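
A minimal sketch of what such a pass-ID interface could look like, assuming the ttmlir Python package exposes the upstream MLIR PassManager bindings as ttmlir.passmanager and that the string literals below correspond to registered pass names (both are assumptions, not the current API):

from ttmlir.passmanager import PassManager

# Hypothetical pass-ID literals; the real values would be the registered pass names.
TTIR_DECOMPOSE = "ttir-to-ttir-decomposition"
TTNN_LAYOUT = "ttnn-layout"


def pipeline(module, passes):
    # Build a textual pipeline string (the same form ttmlir-opt accepts) and run it.
    pm = PassManager.parse("builtin.module(" + ",".join(passes) + ")", context=module.context)
    pm.run(module.operation)

Usage would then look like pipeline(module, [TTIR_DECOMPOSE, TTNN_LAYOUT]).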

@tapspatel
Contributor Author

tapspatel commented Jan 15, 2025

I completely agree with your comment about having pass IDs:

passes.pipeline(module, passes=[passes.TTIR_DECOMPOSE, ...])

You would provide a list of passes that get executed in the order you provide (like ttmlir-opt).

The goal of this PR was to extend what currently exists without trying to modify the design; I would expect that to be a wider discussion.

I'm not sure whether there is already a naming convention for passes, but we should also have the ability to run other pipeline passes through the Python infra (not just the TTNN pipelines). This would let us test compiler passes from the Python infra.

@sdjordjevicTT what are your thoughts on this? Do you think we need a strict naming convention/design for how pipelines get registered and what gets pybound?

@nsmithtt
Contributor

Does this PR solve a particular problem that we were facing? I agree with the verbosity comments, and it might be overkill to expose all passes to Python; it's not clear that we really need this support.

@vcanicTT
Contributor

@tapspatel Can you clarify the idea behind this change? Why do we need this? When do we need this? What are the scenarios and use-cases?

In your comments and the description I see that we want to extend the module, but I still don't understand what the current and future needs are, especially this part: "... through ttir_builder (or personally) and run individual passes ...". Can you give me a bit more context on this?

Also, there is no description within the issue.
