1- # RUN: %PYTHON %s | FileCheck %s
2-
31# Simply demonstrates applying a schedule to a payload.
4- #
52# To do so generates a basic payload and a basic schedule, purely as an example.
6- # Shows how to do it contained to a single Python process and how do invoke the
7- # same functionality from the cmdline. The latter to facilitate the case where
8- # the payload and schedule already exist as .mlir files. Run this file to see
9- # the concrete schedule IR, pre-transform payload IR and transformed payload IR.
10-
11- import subprocess
12- from tempfile import NamedTemporaryFile
133
144from mlir .ir import Context , Location , InsertionPoint , Operation , Module
15- from mlir .ir import RankedTensorType , F32Type , FloatAttr , DenseElementsAttr , UnitAttr
5+ from mlir .ir import RankedTensorType , F32Type , UnitAttr
166from mlir .dialects import arith , func , linalg , tensor , transform
177from mlir .dialects .transform import structured
188
@@ -34,29 +24,16 @@ def example_payload() -> Module:
3424 with InsertionPoint (payload .body ):
3525 matrixType = RankedTensorType .get ([16 , 16 ], F32Type .get ())
3626
37- # NB: Do the CHECKing on the transformed output:
38- # CHECK-LABEL: result of applying schedule to payload
39- # CHECK: func.func @fold_add_on_two_matmuls
40- # CHECK-SAME: (%[[MATRIX_A:.*]]: {{.*}}, %[[MATRIX_B:.*]]: {{.*}})
41- @func .func (matrixType , matrixType )
42- def fold_add_on_two_matmuls (matrixA , matrixB ):
43- splat_float = FloatAttr .get (F32Type .get (), 1.111111 )
44- splat_attr = DenseElementsAttr .get_splat (matrixType , splat_float )
45- # CHECK: %[[WEIGHTS:.*]] = arith.constant dense<1.11
46- weights = arith .constant (matrixType , splat_attr )
47- c0 = arith .constant (F32Type .get (), 0.0 )
27+ @func .func (matrixType , matrixType , matrixType )
28+ def fold_add_on_two_matmuls (matrixA , matrixB , weights ):
4829 empty = tensor .empty (matrixType .shape , matrixType .element_type )
49- # CHECK: %[[ZERO_INIT:.*]] = linalg.fill
30+ c0 = arith . constant ( F32Type . get (), 0.0 )
5031 zero_init = linalg .fill (c0 , outs = [empty ])
51- # CHECK: %[[A_X_WEIGHTS:.*]] = linalg.matmul ins(%[[MATRIX_A]], %[[WEIGHTS]]{{.*}}) outs(%[[ZERO_INIT]]
5232 A_x_weights = linalg .matmul (matrixA , weights , outs = [zero_init ])
5333 empty2 = tensor .empty (matrixType .shape , matrixType .element_type )
5434 zero_init2 = linalg .fill (c0 , outs = [empty2 ])
55- # CHECK: %[[RES:.*]] = linalg.matmul ins(%[[MATRIX_B]], %[[WEIGHTS]]{{.*}}) outs(%[[A_X_WEIGHTS]]
5635 B_x_weights = linalg .matmul (matrixB , weights , outs = [zero_init2 ])
57- # CHECK-NOT: linalg.add
5836 added = linalg .add (A_x_weights , B_x_weights , outs = [empty ])
59- # CHECK: return %[[RES]]
6037 return added
6138
6239 print (payload )
@@ -96,37 +73,11 @@ def example_schedule() -> Module:
9673with Context (), Location .unknown ():
9774 payload = example_payload ()
9875 schedule_module = example_schedule ()
99- # Demonstrate applying a schedule to a payload, both generated in-process :
76+ # Actual schedule is defined by the contained transfomr.named_sequence :
10077 schedule : transform .NamedSequenceOp = schedule_module .body .operations [0 ]
78+ schedule_module .body
10179
102- print (
103- "NOTE: result of applying schedule to payload directly within Python process:"
104- )
105- schedule .apply (payload )
106- print (payload )
80+ schedule .apply (payload ) # The actual transformation happens here.
10781
108- # Demonstrate applying a schedule to a payload, both as .mlir files, on the cmdline
109- # (to facilitate the same functionality for out-of-process generated schedules and payloads):
110- with (
111- NamedTemporaryFile ("w" , prefix = "payload_" , suffix = ".mlir" ) as payload_file ,
112- NamedTemporaryFile ("w" , prefix = "schedule_" , suffix = ".mlir" ) as schedule_file ,
113- ):
114- print (payload , file = payload_file , flush = True )
115- print ("NOTE: Have dumped payload to temp file:" , payload_file .name )
116- print (schedule_module , file = schedule_file , flush = True )
117- print ("NOTE: Have dumped schedule to temp file:" , schedule_file .name )
118-
119- cmdline = [
120- "python" ,
121- "-m" ,
122- "lighthouse.schedule" ,
123- schedule_file .name ,
124- payload_file .name ,
125- ]
126- print (
127- "NOTE: output of applying schedule to payload from commandline:" , * cmdline
128- )
129- print (subprocess .run (cmdline , capture_output = True ).stdout .decode ())
130- print (
131- f"NOTE: cleaning-up temp files: { payload_file .name } , { schedule_file .name } "
132- )
82+ print ("NOTE: result of applying schedule to payload:" )
83+ print (payload )
0 commit comments