Skip to content

Commit 6f29ea9

Browse files
committed
KISS. With an emphasis on S.
1 parent d037047 commit 6f29ea9

File tree

4 files changed

+21
-115
lines changed

4 files changed

+21
-115
lines changed

.github/workflows/examples.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,8 @@ jobs:
2525
2626
- name: Run MLP From Module
2727
run: |-
28-
uv run python/examples/ingress/torch/mlp_from_model.py
28+
uv run python/examples/ingress/torch/mlp_from_model.py
29+
30+
- name: Run schedule application example
31+
run: |-
32+
uv run python/examples/schedule/transform_a_payload_according_to_a_schedule.py

python/examples/mlir/compile_and_run.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
# RUN: %PYTHON %s
2+
13
import torch
24
import argparse
35

46
from mlir import ir
57
from mlir.dialects import transform
68
from mlir.dialects.transform import structured
7-
from mlir.dialects.transform import interpreter
89
from mlir.execution_engine import ExecutionEngine
910
from mlir.passmanager import PassManager
1011

@@ -89,22 +90,6 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
8990
return schedule
9091

9192

92-
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
93-
"""
94-
Apply transformation schedule to a kernel module.
95-
The kernel is modified in-place.
96-
97-
Args:
98-
kernel: A module with payload function.
99-
schedule: A module with transform schedule.
100-
"""
101-
interpreter.apply_named_sequence(
102-
payload_root=kernel,
103-
transform_root=schedule.body.operations[0],
104-
transform_module=schedule,
105-
)
106-
107-
10893
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
10994
"""
11095
Create an MLIR pass pipeline.
@@ -141,9 +126,11 @@ def main(args):
141126
ctx = ir.Context()
142127
kernel = create_kernel(ctx)
143128

144-
# Create a transform schedule and apply initial lowering.
145-
schedule = create_schedule(ctx)
146-
apply_schedule(kernel, schedule)
129+
# Create a transform schedule and apply initial lowering to kernel.
130+
# The kernel is modified in-place.
131+
schedule_module = create_schedule(ctx)
132+
named_seq: transform.NamedSequenceOp = schedule_module.body.operations[0]
133+
named_seq.apply(kernel)
147134

148135
# Create a pass pipeline and lower the kernel to LLVM dialect.
149136
pm = create_pass_pipeline(ctx)
Lines changed: 9 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
1-
# RUN: %PYTHON %s | FileCheck %s
2-
31
# Simply demonstrates applying a schedule to a payload.
4-
#
52
# To do so generates a basic payload and a basic schedule, purely as an example.
6-
# Shows how to do it contained to a single Python process and how do invoke the
7-
# same functionality from the cmdline. The latter to facilitate the case where
8-
# the payload and schedule already exist as .mlir files. Run this file to see
9-
# the concrete schedule IR, pre-transform payload IR and transformed payload IR.
10-
11-
import subprocess
12-
from tempfile import NamedTemporaryFile
133

144
from mlir.ir import Context, Location, InsertionPoint, Operation, Module
15-
from mlir.ir import RankedTensorType, F32Type, FloatAttr, DenseElementsAttr, UnitAttr
5+
from mlir.ir import RankedTensorType, F32Type, UnitAttr
166
from mlir.dialects import arith, func, linalg, tensor, transform
177
from mlir.dialects.transform import structured
188

@@ -34,29 +24,16 @@ def example_payload() -> Module:
3424
with InsertionPoint(payload.body):
3525
matrixType = RankedTensorType.get([16, 16], F32Type.get())
3626

37-
# NB: Do the CHECKing on the transformed output:
38-
# CHECK-LABEL: result of applying schedule to payload
39-
# CHECK: func.func @fold_add_on_two_matmuls
40-
# CHECK-SAME: (%[[MATRIX_A:.*]]: {{.*}}, %[[MATRIX_B:.*]]: {{.*}})
41-
@func.func(matrixType, matrixType)
42-
def fold_add_on_two_matmuls(matrixA, matrixB):
43-
splat_float = FloatAttr.get(F32Type.get(), 1.111111)
44-
splat_attr = DenseElementsAttr.get_splat(matrixType, splat_float)
45-
# CHECK: %[[WEIGHTS:.*]] = arith.constant dense<1.11
46-
weights = arith.constant(matrixType, splat_attr)
47-
c0 = arith.constant(F32Type.get(), 0.0)
27+
@func.func(matrixType, matrixType, matrixType)
28+
def fold_add_on_two_matmuls(matrixA, matrixB, weights):
4829
empty = tensor.empty(matrixType.shape, matrixType.element_type)
49-
# CHECK: %[[ZERO_INIT:.*]] = linalg.fill
30+
c0 = arith.constant(F32Type.get(), 0.0)
5031
zero_init = linalg.fill(c0, outs=[empty])
51-
# CHECK: %[[A_X_WEIGHTS:.*]] = linalg.matmul ins(%[[MATRIX_A]], %[[WEIGHTS]]{{.*}}) outs(%[[ZERO_INIT]]
5232
A_x_weights = linalg.matmul(matrixA, weights, outs=[zero_init])
5333
empty2 = tensor.empty(matrixType.shape, matrixType.element_type)
5434
zero_init2 = linalg.fill(c0, outs=[empty2])
55-
# CHECK: %[[RES:.*]] = linalg.matmul ins(%[[MATRIX_B]], %[[WEIGHTS]]{{.*}}) outs(%[[A_X_WEIGHTS]]
5635
B_x_weights = linalg.matmul(matrixB, weights, outs=[zero_init2])
57-
# CHECK-NOT: linalg.add
5836
added = linalg.add(A_x_weights, B_x_weights, outs=[empty])
59-
# CHECK: return %[[RES]]
6037
return added
6138

6239
print(payload)
@@ -96,37 +73,11 @@ def example_schedule() -> Module:
9673
with Context(), Location.unknown():
9774
payload = example_payload()
9875
schedule_module = example_schedule()
99-
# Demonstrate applying a schedule to a payload, both generated in-process:
76+
# Actual schedule is defined by the contained transfomr.named_sequence:
10077
schedule: transform.NamedSequenceOp = schedule_module.body.operations[0]
78+
schedule_module.body
10179

102-
print(
103-
"NOTE: result of applying schedule to payload directly within Python process:"
104-
)
105-
schedule.apply(payload)
106-
print(payload)
80+
schedule.apply(payload) # The actual transformation happens here.
10781

108-
# Demonstrate applying a schedule to a payload, both as .mlir files, on the cmdline
109-
# (to facilitate the same functionality for out-of-process generated schedules and payloads):
110-
with (
111-
NamedTemporaryFile("w", prefix="payload_", suffix=".mlir") as payload_file,
112-
NamedTemporaryFile("w", prefix="schedule_", suffix=".mlir") as schedule_file,
113-
):
114-
print(payload, file=payload_file, flush=True)
115-
print("NOTE: Have dumped payload to temp file:", payload_file.name)
116-
print(schedule_module, file=schedule_file, flush=True)
117-
print("NOTE: Have dumped schedule to temp file:", schedule_file.name)
118-
119-
cmdline = [
120-
"python",
121-
"-m",
122-
"lighthouse.schedule",
123-
schedule_file.name,
124-
payload_file.name,
125-
]
126-
print(
127-
"NOTE: output of applying schedule to payload from commandline:", *cmdline
128-
)
129-
print(subprocess.run(cmdline, capture_output=True).stdout.decode())
130-
print(
131-
f"NOTE: cleaning-up temp files: {payload_file.name}, {schedule_file.name}"
132-
)
82+
print("NOTE: result of applying schedule to payload:")
83+
print(payload)

python/lighthouse/schedule/__main__.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)