"""Compile placeholder Triton kernels against Babylon-generated MLIR.

Each kernel below is an empty ``@triton.jit`` stub. ``triton.compile`` is
called with a ``target_mlir`` path pointing at an MLIR file under the Babylon
``cr-examples/triton/target/mlir`` directory. ``target_mlir`` is presumably a
keyword added by the Triton fork this script targets (it would consume that
file instead of the stub's body) -- TODO confirm against the fork in use.
"""
import os
import shutil

import torch
import intel_extension_for_pytorch  # type: ignore # noqa: F401

import triton
import triton.language as tl

# Resolve the home directory portably: os.environ['HOME'] raises KeyError when
# HOME is unset, whereas expanduser returns HOME when it is set and otherwise
# falls back to the platform's notion of the user's home directory.
HOME = os.path.expanduser("~")
BABYLON_PATH = os.path.join(HOME, "babylon")


@triton.jit
def add_kernel():
    # Empty stub; the IR is expected to come from ADD_KERNEL_MLIR via
    # target_mlir below -- TODO confirm.
    pass


@triton.jit
def matmul_kernel():
    # Empty stub; the IR is expected to come from MATMUL_MLIR via
    # target_mlir below -- TODO confirm.
    pass


@triton.jit
def softmax_kernel():
    # Empty stub; the IR is expected to come from SOFTMAX_MLIR via
    # target_mlir below -- TODO confirm.
    pass


# Directory holding the MLIR file for each kernel.
_MLIR_DIR = os.path.join(BABYLON_PATH, "cr-examples", "triton", "target", "mlir")

ADD_KERNEL_MLIR = os.path.join(_MLIR_DIR, "add_kernel.mlir")
MATMUL_MLIR = os.path.join(_MLIR_DIR, "matmul_kernel.mlir")
SOFTMAX_MLIR = os.path.join(_MLIR_DIR, "softmax_kernel.mlir")


def _triton_cache_dir():
    """Return the directory Triton uses for its on-disk compilation cache.

    Honors the ``TRITON_CACHE_DIR`` environment override (stripped, as Triton
    does) and otherwise falls back to ``~/.triton/cache``, the path this script
    previously hard-coded.
    """
    override = os.environ.get("TRITON_CACHE_DIR", "").strip()
    return override or os.path.join(HOME, ".triton", "cache")


def _compile(kernel, mlir_path, options=None):
    """Compile the argument-less *kernel* with ``target_mlir=mlir_path``.

    Args:
        kernel: a ``@triton.jit`` function taking no arguments.
        mlir_path: path passed through as ``target_mlir``.
        options: optional dict forwarded as ``options``. It is passed only
            when given, so an option-less kernel yields exactly the same
            ``triton.compile`` call as before.

    Returns:
        Whatever ``triton.compile`` returns.
    """
    src = triton.compiler.ASTSource(fn=kernel, signature={}, constants={})
    if options is None:
        return triton.compile(src, target_mlir=mlir_path)
    return triton.compile(src, target_mlir=mlir_path, options=options)


# Drop any cached artifacts so the compiles below run from scratch instead of
# being served from the cache (presumably so edited MLIR is picked up).
_cache_dir = _triton_cache_dir()
if os.path.isdir(_cache_dir):
    shutil.rmtree(_cache_dir)

# Same order and options as the original hand-written calls.
_compile(add_kernel, ADD_KERNEL_MLIR)
_compile(softmax_kernel, SOFTMAX_MLIR, {"num_warps": 32})
_compile(matmul_kernel, MATMUL_MLIR, {"threads_per_warp": 16, "num_warps": 64})