1 diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py 2 index 3ec8ff32..bd59e1b6 100644 3 --- a/python/triton/compiler/compiler.py 4 +++ b/python/triton/compiler/compiler.py 5 @@ -223,7 +223,7 @@ def filter_traceback(e: BaseException): 6 e.__traceback__ = frames[0] 7 8 9 -def compile(src, target_mlir=None, target=None, options=None): 10 +def compile(src, target=None, options=None): 11 if target is None: 12 target = driver.active.get_current_target() 13 assert isinstance(target, GPUTarget), "target must be of GPUTarget type" 14 @@ -268,7 +268,7 @@ def compile(src, target_mlir=None, target=None, options=None): 15 } 16 # run compilation pipeline and populate metadata 17 stages = dict() 18 - backend.add_stages(stages, options, target_mlir) 19 + backend.add_stages(stages, options) 20 first_stage = list(stages.keys()).index(src.ext) 21 # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. 22 if ir_source: 23 diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py 24 index 9b4261c7..8dad5654 100644 25 --- a/third_party/intel/backend/compiler.py 26 +++ b/third_party/intel/backend/compiler.py 27 @@ -138,11 +138,7 @@ class XPUBackend(BaseBackend): 28 intel.load_dialects(ctx) 29 30 @staticmethod 31 - def make_ttir(mod, metadata, opt, target_mlir): 32 - if (target_mlir): 33 - context = mod.context 34 - mod = ir.parse_mlir_module(f"{target_mlir}", mod.context) 35 - mod.context = context 36 + def make_ttir(mod, metadata, opt): 37 pm = ir.pass_manager(mod.context) 38 pm.enable_debug() 39 passes.common.add_inliner(pm) 40 @@ -254,8 +250,8 @@ class XPUBackend(BaseBackend): 41 metadata["name"] = name 42 return ret 43 44 - def add_stages(self, stages, options, target_mlir): 45 - stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, target_mlir) 46 + def add_stages(self, stages, options): 47 + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) 48 stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties) 49 stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) 50 stages["spv"] = lambda src, metadata: self.make_spv(src, metadata)