1 import java.io.File;
  2 import java.io.FileInputStream;
  3 import java.util.List;
  4 import java.util.Map;
  5 import java.util.HashMap;
  6 import java.util.ArrayList;
  7 import java.util.Deque;
  8 import java.util.ArrayDeque;
  9 import java.util.Set;
 10 import java.util.HashSet;
 11 import java.util.Iterator;
 12 import java.util.Arrays;
 13 import java.util.Queue;
 14 import java.util.Optional;
 15 import java.util.stream.Stream;
 16 import java.util.function.Consumer;
 17 import java.util.function.Function;
 18 import java.util.function.Supplier;
 19 import java.time.*;
 20 import java.nio.file.Files;
 21 import java.nio.file.Paths;
 22 import java.io.IOException;
 23 import java.io.BufferedReader;
 24 import java.io.FileWriter;
 25 import java.io.FileReader;
 26 import java.lang.reflect.Method;
 27 import java.lang.ref.Cleaner;
 28 import java.lang.foreign.MemorySegment;
 29 import java.lang.foreign.ValueLayout;
 30 import java.lang.foreign.AddressLayout;
 31 import java.lang.foreign.Arena;
 32 import java.lang.foreign.MemorySegment.Scope;
 33 import static java.lang.foreign.ValueLayout.*;
 34 import jdk.incubator.vector.VectorSpecies;
 35 import jdk.incubator.vector.FloatVector;
 36 import static oneapi.levelzero.ze_api_h.*;
 37 import oneapi.levelzero.ze_api_h;
 38 import oneapi.levelzero.ze_context_desc_t;
 39 import oneapi.levelzero.ze_kernel_desc_t;
 40 import oneapi.levelzero.ze_command_queue_desc_t;
 41 import oneapi.levelzero.ze_command_list_desc_t;
 42 import oneapi.levelzero.ze_command_queue_group_properties_t;
 43 import oneapi.levelzero.ze_event_pool_desc_t;
 44 import oneapi.levelzero.ze_event_desc_t;
 45 import oneapi.levelzero.ze_fence_desc_t;
 46 import oneapi.levelzero.ze_module_desc_t;
 47 import oneapi.levelzero.ze_group_count_t;
 48 import oneapi.levelzero.ze_host_mem_alloc_desc_t;
 49 import oneapi.levelzero.ze_device_mem_alloc_desc_t;
 50 import oneapi.levelzero.ze_device_properties_t;
 51 import oneapi.levelzero.ze_device_compute_properties_t;
 52 import oneapi.levelzero.ze_driver_properties_t;
 53 import oneapi.levelzero.ze_driver_extension_properties_t;
 54 import org.json.JSONArray;
 55 import org.json.JSONObject;
 56 
 57 import java.util.Random;
 58 
 59 public class LevelZero {
    /** Layout of one driver handle; used to size and index the handle array filled by zeDriverGet. */
    public static final AddressLayout driver_handle_t = AddressLayout.ADDRESS;
    // Shared arena created in the constructor; backs descriptors, handle cells and loaded modules.
    private final Arena arena;
    // First driver reported by zeDriverGet.
    private final MemorySegment driverHandle;
    // Context created on driverHandle.
    private final MemorySegment contextHandle;
    // First device reported by zeDeviceGet for driverHandle.
    private final MemorySegment deviceHandle;
    // Synchronous command queue (ordinal 0, index 0) created on deviceHandle.
    private final MemorySegment queueHandle;
    // Pre-filled ze_event_pool_desc_t (count 20, host visible); no pool is created from it in the constructor.
    private final MemorySegment eventPoolDescription;
    private final String homeDir = System.getProperty("user.home");
    // Root directory under which the per-kernel cache directories below are looked up.
    private final String cacheDir = homeDir + "/.triton/cache/";
    // Directory names (under cacheDir) that hold the <kernel>.json / <kernel>.spv pair of each kernel.
    private final String addKernelCache = "7961f2e8b433c656051d8638d6a3bb65f43f6cb885525c05d611100dd905aa31";
    private final String softmaxKernelCache = "f0c32acd1173759227ef8e0e8d197c94493b90ebf8d1fc254399ffac6b527d6a";
    private final String matmulKernelCache = "07e17c2833c9c9efea8ccd782af1c3ee05dcac3efb2cb75f2f8a6eecffe381ef";
    private final static VectorSpecies<Float> SPECIES = FloatVector.SPECIES_256;
    // Accumulated wall-clock milliseconds of the first and of the repeated submission in executeCommandList.
    private Double timeElapsedForRun, timeElapsedForRerun;

    // Load the native library named "ze_loader" before any binding is invoked.
    static {
        System.loadLibrary("ze_loader");
    }
 78 
 79     static void debug(String format, Object... args) {
 80         System.out.printf(format + "%n", args);
 81     }
 82 
 83     private static void check(int result) {
 84         if (result != ZE_RESULT_SUCCESS()) {
 85             throw new RuntimeException(String.format("Call failed: 0x%x (%d)", result, result));
 86         }
 87     }
 88 
 89     MemorySegment contextHandle() {
 90         return contextHandle;
 91     }
 92 
 93     MemorySegment deviceHandle() {
 94         return deviceHandle;
 95     }
 96 
 97     public LevelZero() {
 98         arena = Arena.ofShared();
 99 
100         // get driver
101         check(zeInit(ZE_INIT_FLAG_GPU_ONLY()));
102         MemorySegment driverCount = arena.allocate(Integer.BYTES);
103         check(zeDriverGet(driverCount, MemorySegment.NULL));
104         debug("driverCount = %d", driverCount.get(JAVA_INT, 0));
105         MemorySegment driverHandles = arena.allocate(driverCount.get(JAVA_INT, 0) * driver_handle_t.byteSize(), 8);
106         check(zeDriverGet(driverCount, driverHandles));
107         driverHandle = driverHandles.get(ADDRESS, 0);
108 
109         // create context
110         MemorySegment pContextDesc = arena.allocate(ze_context_desc_t.layout());
111         ze_context_desc_t.stype(pContextDesc, ZE_STRUCTURE_TYPE_CONTEXT_DESC());
112         MemorySegment pContextHandle = arena.allocate(ze_context_handle_t);
113         check(zeContextCreate(driverHandle, pContextDesc, pContextHandle));
114         contextHandle = pContextHandle.get(ADDRESS, 0);
115 
116         // get device
117         MemorySegment pDeviceCount = arena.allocate(Integer.BYTES);
118         check(zeDeviceGet(driverHandle, pDeviceCount, MemorySegment.NULL));
119         int deviceCount = pDeviceCount.get(JAVA_INT, 0);
120         assert deviceCount > 0;
121         debug("deviceCount = %d", deviceCount);
122         MemorySegment deviceHandles = arena.allocate(deviceCount * ze_device_handle_t.byteSize(), 8);
123         check(zeDeviceGet(driverHandle, pDeviceCount, deviceHandles));
124         for (int i = 0; i < deviceCount; i++) {
125             debug("device #%d: %s", i, deviceHandles.get(ze_device_handle_t, i * ze_device_handle_t.byteSize()));
126         }
127         deviceHandle = deviceHandles.get(ze_device_handle_t, 0 * ze_device_handle_t.byteSize());
128         MemorySegment pDeviceProperties = arena.allocate(ze_device_properties_t.layout());
129         ze_device_properties_t.stype(pDeviceProperties, ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES());
130         check(zeDeviceGetProperties(deviceHandle, pDeviceProperties));
131         debug("deviceProperties:\n\ttype = %d\n\tvendorId = %d\n\tmaxMemAllocSize = %d\n\tdeviceId = %d\n\tcoreClockRate = %d",
132             ze_device_properties_t.type(pDeviceProperties),
133             ze_device_properties_t.vendorId(pDeviceProperties),
134             ze_device_properties_t.maxMemAllocSize(pDeviceProperties),
135             ze_device_properties_t.deviceId(pDeviceProperties),
136             ze_device_properties_t.coreClockRate(pDeviceProperties));
137 
138         MemorySegment pDeviceComputeProperties = arena.allocate(ze_device_compute_properties_t.layout());
139         ze_device_compute_properties_t.stype(pDeviceComputeProperties, ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES());
140         check(zeDeviceGetComputeProperties(deviceHandle, pDeviceComputeProperties));
141         debug("deviceProperties:\n\tshared = %d\n\tmaxTotalGroupSize = %d",
142             ze_device_compute_properties_t.maxSharedLocalMemory(pDeviceComputeProperties),
143             ze_device_compute_properties_t.maxTotalGroupSize(pDeviceComputeProperties));
144 
145         // create queue
146         MemorySegment pNumQueueGroups = arena.allocate(JAVA_INT, 1);
147         check(zeDeviceGetCommandQueueGroupProperties(deviceHandle, pNumQueueGroups, MemorySegment.NULL));
148         debug("#Queue Groups: %d", pNumQueueGroups.get(JAVA_INT, 0));
149         MemorySegment pGroupProperties = arena.allocate(ze_command_queue_group_properties_t.layout(), pNumQueueGroups.get(JAVA_INT, 0));
150         check(zeDeviceGetCommandQueueGroupProperties(deviceHandle, pNumQueueGroups, pGroupProperties));
151 
152         MemorySegment pQueueDesc = arena.allocate(ze_command_queue_desc_t.layout());
153         ze_command_queue_desc_t.stype(pQueueDesc, ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC());
154         ze_command_queue_desc_t.index(pQueueDesc, 0);
155         ze_command_queue_desc_t.mode(pQueueDesc, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS());
156         ze_command_queue_desc_t.ordinal(pQueueDesc, 0);
157         MemorySegment pQueueHandle = arena.allocate(ze_command_queue_handle_t);
158         check(zeCommandQueueCreate(contextHandle, deviceHandle, pQueueDesc, pQueueHandle));
159         queueHandle = pQueueHandle.get(ADDRESS, 0);
160 
161         eventPoolDescription = arena.allocate(ze_event_pool_desc_t.layout());
162         ze_event_pool_desc_t.stype(eventPoolDescription, ZE_STRUCTURE_TYPE_EVENT_POOL_DESC());
163         ze_event_pool_desc_t.count(eventPoolDescription, 20);
164         ze_event_pool_desc_t.flags(eventPoolDescription, ZE_EVENT_POOL_FLAG_HOST_VISIBLE());
165 
166         timeElapsedForRun = timeElapsedForRerun = 0.0;
167     }
168 
169     public void clear() {
170         check(zeCommandQueueDestroy(queueHandle));
171         check(zeContextDestroy(contextHandle));
172     }
173 
174     public void test(String testName) {
175         Object[] args = {};
176         Random rand = new Random();
177         if (testName.equals("add")) {
178             String jsonFileName = cacheDir + addKernelCache + "/add_kernel.json";
179             String moduleName = cacheDir + addKernelCache + "/add_kernel.spv";
180 
181             int BLOCK_SIZE = 64;
182             int elementSize = 4096;
183             int gridSize = (elementSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
184 
185             JSONObject jsonObject = loadJson(jsonFileName);
186             String kernelName = jsonObject.getString("name");
187             int threads_per_warp = jsonObject.getInt("threads_per_warp");
188             int num_warps = jsonObject.getInt("num_warps");
189             int shared = jsonObject.getInt("shared");
190 
191             float[] input1 = new float[elementSize];
192             float[] input2 = new float[elementSize];
193             float[] output = new float[elementSize];
194             for (int i = 0; i < elementSize; i++) {
195                 input1[i] = rand.nextFloat();
196                 input2[i] = rand.nextFloat();
197             }
198             args = new Object[] {input1, input2, output, elementSize};
199             run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
200 
201             float[] expected = Test.add(input1, input2, elementSize);
202             Test.check(expected, output);
203         } else if (testName.equals("softmax")) {
204             String jsonFileName = cacheDir + softmaxKernelCache + "/softmax_kernel.json";
205             String moduleName = cacheDir + softmaxKernelCache + "/softmax_kernel.spv";
206 
207             JSONObject jsonObject = loadJson(jsonFileName);
208             String kernelName = jsonObject.getString("name");
209             int threads_per_warp = jsonObject.getInt("threads_per_warp");
210             int num_warps = jsonObject.getInt("num_warps");
211             int shared = jsonObject.getInt("shared");
212 
213             int elementSizeX = 4096, elementSizeY = 64;
214             int gridSize = elementSizeX;
215             float[] input = new float[elementSizeX * elementSizeY];
216             float[] output = new float[elementSizeX * elementSizeY];
217             byte[] sharedMemory = new byte[shared]; // use for storing temporary value of max element and sum of exp
218             for (int i = 0; i < elementSizeX * elementSizeY; i++) {
219                 input[i] = rand.nextFloat();
220             }
221             args = new Object[] {output, input, elementSizeY, elementSizeY, elementSizeY, sharedMemory};
222             run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
223 
224             float[] expected = Test.softmax(input, elementSizeX, elementSizeY);
225             Test.check(expected, output);
226         } else if (testName.equals("matmul")) {
227             String jsonFileName = cacheDir + matmulKernelCache + "/matmul_kernel.json";
228             String moduleName = cacheDir + matmulKernelCache + "/matmul_kernel.spv";
229 
230             JSONObject jsonObject = loadJson(jsonFileName);
231             String kernelName = jsonObject.getString("name");
232             int threads_per_warp = jsonObject.getInt("threads_per_warp");
233             int num_warps = jsonObject.getInt("num_warps");
234             int shared = jsonObject.getInt("shared");
235 
236             int M = 1024, N = 1024, K = 1024;
237             int BLOCK_SIZE_M = 32, BLOCK_SIZE_N = 64;
238             int gridSize = ((M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N);
239             float[] a = new float[M * K];
240             float[] b = new float[K * N];
241             float[] c = new float[M * N];
242             byte[] sharedMemory = new byte[shared];
243 
244             for (int i = 0; i < M * K; i++) {
245                 a[i] = rand.nextFloat();
246             }
247             for (int i = 0; i < K * N; i++) {
248                 b[i] = rand.nextFloat();
249             }
250             args = new Object[] {a, b, c, M, N, K, K, N, N, sharedMemory};
251             run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
252 
253             float[] expected = Test.matmul(a, b, M, N, K);
254             Test.check(expected, c);
255         } else {
256             throw new RuntimeException("Unsupported test: " + testName);
257         }
258     }
259 
260     public void run(String kernelName, String fileName, Object[] args, int threads_per_warp, int num_warps, int shared, int gridSize) {
261         debug("=========== run %s ===========", kernelName);
262         MemorySegment spirvBinary = loadModule(fileName);
263         List<Arg> kernelArgs = collectArgs(args);
264         int[] globalSizes = new int[] {gridSize * threads_per_warp * num_warps, 1, 1};
265         int[] localSizes = new int[] {threads_per_warp * num_warps, 1, 1};
266         KernelGeometry geometry = new KernelGeometry(globalSizes, localSizes);
267         debug("geometry = %s", geometry);
268         MemorySegment commandListHandle = createCommandList(spirvBinary, kernelName, geometry, kernelArgs, shared, false);
269         executeCommandList(commandListHandle);
270         check(zeCommandQueueSynchronize(queueHandle, -1L));
271         for (int i = 0; i < kernelArgs.size(); i++) {
272             copyArgToHost(kernelArgs.get(i), contextHandle);
273         }
274 
275         for (int i = 0; i < kernelArgs.size(); i++) {
276             Arg arg = kernelArgs.get(i);
277             MemorySegment dataSegment = arg.dataSegment();
278             if (dataSegment != null) {
279                 check(zeMemFree(contextHandle, dataSegment));
280             }
281         }
282         check(zeCommandListDestroy(commandListHandle));
283     }
284 
285     public void runRefMatmul(String kernelName, String fileName, Object[] args, int size) {
286         debug("=========== run %s ===========", kernelName);
287         MemorySegment spirvBinary = loadModule(fileName);
288         List<Arg> kernelArgs = collectArgs(args);
289         int[] globalSizes = new int[] {size, size, 1};
290         int[] localSizes = new int[] {512, 1, 1};
291         KernelGeometry geometry = new KernelGeometry(globalSizes, localSizes);
292         debug("geometry = %s", geometry);
293         MemorySegment commandListHandle = createCommandList(spirvBinary, kernelName, geometry, kernelArgs, 0, true);
294         executeCommandList(commandListHandle);
295         check(zeCommandQueueSynchronize(queueHandle, -1L));
296         for (int i = 0; i < kernelArgs.size(); i++) {
297             copyArgToHost(kernelArgs.get(i), contextHandle);
298         }
299 
300         for (int i = 0; i < kernelArgs.size(); i++) {
301             Arg arg = kernelArgs.get(i);
302             MemorySegment dataSegment = arg.dataSegment();
303             if (dataSegment != null) {
304                 check(zeMemFree(contextHandle, dataSegment));
305             }
306         }
307         check(zeCommandListDestroy(commandListHandle));
308     }
309 
310     private List<Arg> collectArgs(Object[] values) {
311         List<Arg> args = new ArrayList<>();
312         for (int i = 0; i < values.length; i++) {
313             args.add(Arg.createArg(this, "arg" + i, values[i]));
314         }
315         debug("args = %s", args);
316         return args;
317     }
318 
319     MemorySegment loadModule(String fileName) {
320         byte[] data = readBytes(fileName);
321         MemorySegment segment = arena.allocate(data.length);
322         segment.copyFrom(MemorySegment.ofArray(data));
323         return segment;
324     }
325 
326     byte[] readBytes(String filename) {
327         File file = new File(filename);
328         try (FileInputStream fis = new FileInputStream(file)) {
329             byte[] data = new byte[(int) file.length()];
330             fis.read(data);
331             return data;
332         } catch (IOException e) {
333             throw new RuntimeException(e);
334         }
335     }
336 
337     void provisionArg(Arg arg) {
338         if (arg.cls() == byte[].class) {
339             byte[] array = (byte[])arg.value();
340             int segmentSize = array.length;
341             arg.setDataSegment(allocateSharedSegment(segmentSize));
342             arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
343             arg.setSize(8);
344             arg.setNeedsCleanup(true);
345         }
346         else if (arg.cls() == short[].class) {
347             short[] array = (short[])arg.value();
348             int segmentSize = array.length * Short.BYTES;
349             arg.setDataSegment(allocateSharedSegment(segmentSize));
350             arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
351             arg.setSize(8);
352             arg.setNeedsCleanup(true);
353         }
354         else if (arg.cls() == int[].class) {
355             int[] array = (int[])arg.value();
356             int segmentSize = array.length * Integer.BYTES;
357             arg.setDataSegment(allocateSharedSegment(segmentSize));
358             arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
359             arg.setSize(8);
360             arg.setNeedsCleanup(true);
361         }
362         else if (arg.cls() == float[].class) {
363             float[] array = (float[])arg.value();
364             int segmentSize = array.length * Float.BYTES;
365             arg.setDataSegment(allocateSharedSegment(segmentSize));
366             arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
367             arg.setSize(8);
368             arg.setNeedsCleanup(true);
369         }
370         else if (VectorSpecies.class.isAssignableFrom(arg.cls())) {
371             arg.setSize(4);
372         }
373         else if (arg.cls() == Short.class) {
374             arg.setSize(2);
375         }
376         else if (arg.cls() == Integer.class || arg.cls() == Float.class || arg.cls() == Boolean.class) {
377             arg.setSize(4);
378         }
379         else if (arg.cls() == Long.class) {
380             arg.setSize(8);
381         }
382         else if (arg.cls() == GPU.Index.class) {
383             MemorySegment pBuffer = arena.allocate(ADDRESS);
384             arg.setDataSegment(allocateSharedSegment(24));
385             arg.setSize(24);
386         }
387         else throw new RuntimeException("unsupported type: " + arg.cls());
388     }
389 
390     void copyArgToHost(Arg arg, MemorySegment contextHandle) {
391         if (arg.cls() == short[].class) {
392             short[] array = (short[])arg.value();
393             MemorySegment arraySegment = MemorySegment.ofArray(array);
394             arraySegment.copyFrom(arg.dataSegment());
395         }
396         else if (arg.cls() == int[].class) {
397             int[] array = (int[])arg.value();
398             MemorySegment arraySegment = MemorySegment.ofArray(array);
399             arraySegment.copyFrom(arg.dataSegment());
400         }
401         else if (arg.cls() == float[].class) {
402             float[] array = (float[])arg.value();
403             MemorySegment arraySegment = MemorySegment.ofArray(array);
404             arraySegment.copyFrom(arg.dataSegment());
405         }
406         // else nothing to do
407     }
408 
409     private MemorySegment createCommandList(MemorySegment spirvModule, String kernelName, KernelGeometry geometry, List<Arg> args, int shared, boolean suggested) {
410         Arena arena = Arena.ofShared();
411         MemorySegment pCommandListHandle = arena.allocate(ze_command_list_handle_t);
412         MemorySegment commandListDesc = arena.allocate(ze_command_list_desc_t.layout());
413         ze_command_list_desc_t.stype(eventPoolDescription, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC());
414         ze_command_list_desc_t.commandQueueGroupOrdinal(commandListDesc, 0);
415         MemorySegment moduleHandle = createModule(kernelName, spirvModule);
416         check(zeCommandListCreate(contextHandle, deviceHandle, commandListDesc, pCommandListHandle));
417         MemorySegment commandListHandle = pCommandListHandle.get(ADDRESS, 0);
418         MemorySegment kernelHandle = createKernel(moduleHandle, kernelName, geometry, suggested);
419         for (int i = 0; i < args.size(); i++) {
420             Arg arg = args.get(i);
421             setKernelArg(arg, i, commandListHandle, kernelHandle, (shared != 0) && (i == args.size() - 1));
422         }
423         MemorySegment groupCount = arena.allocate(ze_group_count_t.layout());
424         ze_group_count_t.groupCountX(groupCount, (geometry.globalSizes()[0] + geometry.localSizes()[0] - 1) / geometry.localSizes()[0]);
425         ze_group_count_t.groupCountY(groupCount, (geometry.globalSizes()[1] + geometry.localSizes()[1] - 1) / geometry.localSizes()[1]);
426         ze_group_count_t.groupCountZ(groupCount, (geometry.globalSizes()[2] + geometry.localSizes()[2] - 1) / geometry.localSizes()[2]);
427         MemorySegment pKernelWaitHandles = MemorySegment.NULL;
428         check(zeCommandListAppendLaunchKernel(commandListHandle, kernelHandle, groupCount, MemorySegment.NULL, 0, pKernelWaitHandles));
429         check(zeCommandListClose(commandListHandle));
430         return commandListHandle;
431     }
432 
433     private MemorySegment executeCommandList(MemorySegment commandListHandle) {
434         MemorySegment fenceDesc = arena.allocate(ze_fence_desc_t.layout());
435         ze_module_desc_t.stype(fenceDesc, ZE_STRUCTURE_TYPE_FENCE_DESC());
436         ze_fence_desc_t.flags(fenceDesc, ZE_FENCE_FLAG_SIGNALED());
437         MemorySegment pFenceHandle = arena.allocate(ze_fence_handle_t);
438         check(zeFenceCreate(queueHandle, fenceDesc, pFenceHandle));
439         MemorySegment fenceHandle = pFenceHandle.get(ADDRESS, 0);
440         MemorySegment pCommandListHandle = arena.allocate(ze_command_list_handle_t);
441         pCommandListHandle.set(ADDRESS, 0, commandListHandle);
442         Instant start = Instant.now();
443         check(zeCommandQueueExecuteCommandLists(queueHandle, 1, pCommandListHandle, fenceHandle));
444         check(zeCommandQueueSynchronize(queueHandle, -1L));
445         Instant finish = Instant.now();
446         Double timeElapsed = Duration.between(start, finish).toNanos() * 1e-6;
447         timeElapsedForRun += timeElapsed;
448         debug("time: %f %f\n", timeElapsed, timeElapsedForRun);
449 
450         start = Instant.now();
451         check(zeCommandQueueExecuteCommandLists(queueHandle, 1, pCommandListHandle, fenceHandle));
452         check(zeCommandQueueSynchronize(queueHandle, -1L));
453         finish = Instant.now();
454         timeElapsed = Duration.between(start, finish).toNanos() * 1e-6;
455         timeElapsedForRerun += timeElapsed;
456         debug("time for rerun: %f %f\n", timeElapsed, timeElapsedForRerun);
457         return fenceHandle;
458     }
459 
460     private MemorySegment createKernel(MemorySegment moduleHandle, String kernelNameString, KernelGeometry geometry, boolean suggested) {
461         MemorySegment kernelDesc = arena.allocate(ze_kernel_desc_t.layout());
462         MemorySegment kernelName = arena.allocateFrom(kernelNameString);
463         ze_kernel_desc_t.stype(kernelDesc, ZE_STRUCTURE_TYPE_KERNEL_DESC());
464         ze_kernel_desc_t.pKernelName(kernelDesc, kernelName);
465         debug("name = %s", kernelNameString);
466         MemorySegment pKernelHandle = arena.allocate(ze_kernel_handle_t);
467         check(zeKernelCreate(moduleHandle, kernelDesc, pKernelHandle));
468         int[] globalSizes = geometry.globalSizes();
469         int[] localSizes = geometry.localSizes();
470         MemorySegment kernelHandle = pKernelHandle.get(ADDRESS, 0);
471         if (suggested) {
472             MemorySegment pGroupSizeX = arena.allocate(JAVA_INT, localSizes[0]);
473             MemorySegment pGroupSizeY = arena.allocate(JAVA_INT, localSizes[1]);
474             MemorySegment pGroupSizeZ = arena.allocate(JAVA_INT, localSizes[2]);
475             check(zeKernelSuggestGroupSize(kernelHandle, globalSizes[0], globalSizes[1], globalSizes[2], pGroupSizeX, pGroupSizeY, pGroupSizeZ));
476             geometry.localSizes()[0] = pGroupSizeX.get(JAVA_INT, 0);
477             geometry.localSizes()[1] = pGroupSizeY.get(JAVA_INT, 0);
478             geometry.localSizes()[2] = pGroupSizeZ.get(JAVA_INT, 0);
479             debug("use suggested group size", geometry.toString());
480             check(zeKernelSetGroupSize(kernelHandle, pGroupSizeX.get(JAVA_INT, 0), pGroupSizeY.get(JAVA_INT, 0), pGroupSizeZ.get(JAVA_INT, 0)));
481         } else {
482             debug("use localSizes", geometry.toString());
483             check(zeKernelSetGroupSize(kernelHandle, localSizes[0], localSizes[1], localSizes[2]));
484         }
485         return kernelHandle;
486     }
487 
488     private void setKernelArg(Arg arg, int ordinal, MemorySegment commandListHandle, MemorySegment kernelHandle, boolean shared) {
489         MemorySegment dataSegment = arg.dataSegment();
490         Class<?> cls = arg.cls();
491         debug("ordinal = %d, cls = %s, data = %s", ordinal, cls.getSimpleName(), dataSegment);
492         if (shared) { // shared memory
493             check(zeKernelSetArgumentValue(kernelHandle, ordinal, dataSegment.byteSize(), dataSegment));
494         }
495         else if (cls == byte[].class || cls == short[].class || cls == int[].class || cls == float[].class || cls.getSimpleName().equals("NativeMemorySegmentImpl")) {
496             check(zeCommandListAppendMemoryPrefetch(commandListHandle, dataSegment, dataSegment.byteSize()));
497             check(zeCommandListAppendMemAdvise(commandListHandle, deviceHandle, dataSegment, dataSegment.byteSize(), ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION()));
498             MemorySegment pDataSegment = arena.allocateFrom(ADDRESS, dataSegment);
499             check(zeKernelSetArgumentValue(kernelHandle, ordinal, ADDRESS.byteSize(), pDataSegment));
500         }
501         else if (cls == Short.class) {
502             MemorySegment pArgValue = arena.allocateFrom(JAVA_SHORT, (short)arg.value());
503             check(zeKernelSetArgumentValue(kernelHandle, ordinal, Short.BYTES, pArgValue));
504         }
505         else if (VectorSpecies.class.isAssignableFrom(cls)) {
506             MemorySegment pArgValue = arena.allocateFrom(JAVA_INT, FloatVector.SPECIES_256.length());
507             check(zeKernelSetArgumentValue(kernelHandle, ordinal, Integer.BYTES, pArgValue));
508         }
509         else if (cls == Integer.class || cls == Boolean.class) {
510             MemorySegment pArgValue = arena.allocateFrom(JAVA_INT, (int)arg.value());
511             check(zeKernelSetArgumentValue(kernelHandle, ordinal, Integer.BYTES, pArgValue));
512         }
513         else if (cls == Long.class) {
514             MemorySegment pArgValue = arena.allocateFrom(JAVA_LONG, (long)arg.value());
515             check(zeKernelSetArgumentValue(kernelHandle, ordinal, Long.BYTES, pArgValue));
516         }
517         else if (cls == Float.class) {
518             MemorySegment pArgValue = arena.allocateFrom(JAVA_LONG, Float.floatToIntBits((float)arg.value()));
519             check(zeKernelSetArgumentValue(kernelHandle, ordinal, Float.BYTES, pArgValue));
520         }
521         else if (cls == GPU.Index.class) {
522             MemorySegment pDataSegment = arena.allocateFrom(ADDRESS, dataSegment);
523             check(zeKernelSetArgumentValue(kernelHandle, ordinal, 24, pDataSegment));
524         }
525         else throw new RuntimeException("unsupported type: " + cls);
526     }
527 
528     private MemorySegment createModule(String moduleName, MemorySegment spirvCode) {
529         MemorySegment pModuleHandle = arena.allocate(ze_module_handle_t);
530         MemorySegment moduleDesc = arena.allocate(ze_module_desc_t.layout());
531         ze_module_desc_t.stype(moduleDesc, ZE_STRUCTURE_TYPE_MODULE_DESC());
532         ze_module_desc_t.format(moduleDesc, ZE_MODULE_FORMAT_IL_SPIRV());
533         ze_module_desc_t.pInputModule(moduleDesc, spirvCode);
534         ze_module_desc_t.inputSize(moduleDesc, spirvCode.byteSize());
535         ze_module_desc_t.pBuildFlags(moduleDesc, arena.allocateFrom(""));
536         MemorySegment buildLogHandle = arena.allocate(ze_module_build_log_handle_t);
537         check(zeModuleCreate(contextHandle, deviceHandle, moduleDesc, pModuleHandle, buildLogHandle));
538         MemorySegment moduleHandle = pModuleHandle.get(ADDRESS, 0);
539         return moduleHandle;
540     }
541 
542     public MemorySegment allocateSharedSegment(long byteSize) {
543         return allocateSharedSegment(contextHandle(), deviceHandle(), byteSize, Arena.global());
544     }
545 
546     public static MemorySegment allocateSharedSegment(MemorySegment contextHandle, MemorySegment deviceHandle, long byteSize, Arena arena) {
547         MemorySegment pDeviceMemAllocDesc = arena.allocate(ze_device_mem_alloc_desc_t.layout());
548         ze_device_mem_alloc_desc_t.stype(pDeviceMemAllocDesc, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC());
549         ze_device_mem_alloc_desc_t.ordinal(pDeviceMemAllocDesc, 0);
550         MemorySegment pHostMemAllocDesc = arena.allocate(ze_host_mem_alloc_desc_t.layout());
551         ze_host_mem_alloc_desc_t.stype(pHostMemAllocDesc, ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC());
552         MemorySegment pBuffer = arena.allocate(ADDRESS);
553         check(zeMemAllocShared(contextHandle, pDeviceMemAllocDesc, pHostMemAllocDesc, byteSize, 1, deviceHandle, pBuffer));
554         long address = pBuffer.get(JAVA_LONG, 0);
555         return MemorySegment.ofAddress(address).reinterpret(byteSize);
556     }
557 
558     private static record KernelGeometry(int[] globalSizes, int[] localSizes) {
559         public KernelGeometry() {
560             this(new int[3], new int[] {512, 1, 1});
561         }
562 
563         @Override
564         public String toString() {
565             return String.format("global: %s, local: %s", Arrays.toString(globalSizes), Arrays.toString(localSizes));
566         }
567     }
568 
569     public void benchAddKernel() {
570         String jsonFileName = cacheDir + addKernelCache + "/add_kernel.json";
571         String moduleName = cacheDir + addKernelCache + "/add_kernel.spv";
572         JSONObject jsonObject = loadJson(jsonFileName);
573         String kernelName = jsonObject.getString("name");
574         int threads_per_warp = jsonObject.getInt("threads_per_warp");
575         int num_warps = jsonObject.getInt("num_warps");
576         int shared = jsonObject.getInt("shared");
577         Random rand = new Random();
578 
579         Writer writer = new Writer("benchmark/vector_add_benchmark.txt");
580         writer.write("elementSize timeElapsed timeElapsedForRerun RTT gb/s \n");
581 
582         for (int elementSize = (1 << 12); elementSize <= (1 << 28); elementSize <<= 1) {
583             int BLOCK_SIZE = 1024;
584             int gridSize = (elementSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
585 
586             float[] input1 = new float[elementSize];
587             float[] input2 = new float[elementSize];
588             float[] output = new float[elementSize];
589             for (int i = 0; i < elementSize; i++) {
590                 input1[i] = rand.nextFloat();
591                 input2[i] = rand.nextFloat();
592             }
593             Object[] args = new Object[] {input1, input2, output, elementSize};
594 
595             // warmup
596             run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
597             this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0;
598 
599             int nTimes = 10;
600             Instant start = Instant.now();
601             for (int i = 0; i < nTimes; ++i)
602                 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
603             Instant finish = Instant.now();
604             Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes;
605             Double timeElapsedForRun = this.timeElapsedForRun / nTimes;
606             Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes;
607             writer.write(String.format("%d %.4f %.4f %.4f %.4f\n", elementSize, timeElapsedForRun, timeElapsedForRerun, RTT, (4f * 3f * elementSize / timeElapsedForRerun * 1e-6)));
608         }
609         writer.close();
610     }
611 
612     public void benchSoftmaxKernel() {
613         String jsonFileName = cacheDir + softmaxKernelCache + "/softmax_kernel.json";
614         String moduleName = cacheDir + softmaxKernelCache + "/softmax_kernel.spv";
615         JSONObject jsonObject = loadJson(jsonFileName);
616         String kernelName = jsonObject.getString("name");
617         int threads_per_warp = jsonObject.getInt("threads_per_warp");
618         int num_warps = jsonObject.getInt("num_warps");
619         int shared = jsonObject.getInt("shared");
620         Random rand = new Random();
621 
622         Writer writer = new Writer("benchmark/softmax_benchmark.txt");
623         writer.write("elementSizeX elementSizeY timeElapsed timeElapsedForRerun RTT gb/s \n");
624 
625         for (int i = 2; i < 50; i++) {
626             int elementSizeX = 4096;
627             int elementSizeY = 128 * i;
628             int gridSize = elementSizeX;
629             float[] input = new float[elementSizeX * elementSizeY];
630             float[] output = new float[elementSizeX * elementSizeY];
631             byte[] sharedMemory = new byte[shared];
632             for (int j = 0; j < elementSizeX * elementSizeY; j++) {
633                 input[j] = rand.nextFloat();
634             }
635             Object[] args = new Object[] {output, input, elementSizeY, elementSizeY, elementSizeY, sharedMemory};
636 
637             // warmup
638             run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
639             this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0;
640 
641             int nTimes = 10;
642             Instant start = Instant.now();
643             for (int j = 0; j < nTimes; ++j)
644                 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
645             Instant finish = Instant.now();
646             Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes;
647             Double timeElapsedForRun = this.timeElapsedForRun / nTimes;
648             Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes;
649             writer.write(String.format("%d %d %.4f %.4f %.4f %.4f\n", elementSizeX, elementSizeY, timeElapsedForRun, timeElapsedForRerun, RTT, (4 * 2 * 1e-9 * elementSizeX * elementSizeY / (timeElapsedForRerun * 1e-3))));
650         }
651         writer.close();
652     }
653 
654     public void benchMatmulKernel() {
655         String jsonFileName = cacheDir + matmulKernelCache + "/matmul_kernel.json";
656         String moduleName = cacheDir + matmulKernelCache + "/matmul_kernel.spv";
657 
658         JSONObject jsonObject = loadJson(jsonFileName);
659         String kernelName = jsonObject.getString("name");
660         int threads_per_warp = jsonObject.getInt("threads_per_warp");
661         int num_warps = jsonObject.getInt("num_warps");
662         int shared = jsonObject.getInt("shared");
663         Random rand = new Random();
664         int BLOCK_SIZE_M = 128, BLOCK_SIZE_N = 64;
665 
666 
667         Writer writer = new Writer("benchmark/matmul_benchmark.txt");
668         writer.write("M N K timeElapsed timeElapsedForRerun RTT TFLOPS \n");
669 
670         for (int i = 2; i <= 64; i++) {
671             int M = 128 * i;
672             int N = 128 * i;
673             int K = 128 * i;
674             int gridSize = ((M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N);
675 
676             float[] a = new float[M * K];
677             float[] b = new float[K * N];
678             float[] c = new float[M * N];
679             byte[] sharedMemory = new byte[shared];
680             for (int j = 0; j < M * K; j++) {
681                 a[j] = rand.nextFloat();
682             }
683             for (int j = 0; j < K * N; j++) {
684                 b[j] = rand.nextFloat();
685             }
686             Object[] args = new Object[] {a, b, c, M, N, K, K, N, N, sharedMemory};
687 
688             // warmup
689             run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
690             this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0;
691 
692             int nTimes = 10;
693             Instant start = Instant.now();
694             for (int j = 0; j < nTimes; ++j)
695                 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
696             Instant finish = Instant.now();
697             Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes;
698             Double timeElapsedForRun = this.timeElapsedForRun / nTimes;
699             Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes;
700             writer.write(String.format("%d %d %d %.4f %.4f %.4f %.4f\n", M, N, K, timeElapsedForRun, timeElapsedForRerun, RTT, (2 * 1e-12 * M * N * K / (timeElapsedForRerun * 1e-3))));
701         }
702         writer.close();
703     }
704 
705     public static void main(String[] args) {
706         LevelZero lz = new LevelZero();
707         lz.test("add");
708         lz.test("softmax");
709         lz.test("matmul");
710         lz.benchAddKernel();
711         lz.benchSoftmaxKernel();
712         lz.benchMatmulKernel();
713         lz.clear();
714     }
715 
716 
717     public static class Arg {
718         private final String name;
719         private final Object value;
720         private final Class<?> cls;
721         private int size;
722         private boolean needsCleanup;
723         private MemorySegment dataSegment;
724 
725         public static Arg createArg(LevelZero lz, String name, Object value) {
726             Arg arg = new Arg(name, value);
727             lz.provisionArg(arg);
728             return arg;
729         }
730 
731         private Arg(String name, Object value) {
732             this.name = name;
733             this.cls = value.getClass();
734             this.value = value;
735         }
736 
737         public String name() {
738             return name;
739         }
740 
741         public Object value() {
742             return value;
743         }
744 
745         public Class<?> cls() {
746             return cls;
747         }
748 
749         public void setSize(int size) {
750             this.size = size;
751         }
752 
753         public int size() {
754             return size;
755         }
756 
757         public void setDataSegment(MemorySegment segment) {
758             dataSegment = segment;
759         }
760 
761         public MemorySegment dataSegment() {
762             return dataSegment;
763         }
764 
765         public void setNeedsCleanup(boolean needsCleanup) {
766             this.needsCleanup = needsCleanup;
767         }
768 
769         public boolean needsCleanup() {
770             return needsCleanup;
771         }
772 
773         public String toString() {
774             return String.format("name = %s, cls = %s", name, cls);
775         }
776     }
777 
778     private JSONObject loadJson(String fileName) {
779         StringBuilder jsonString = new StringBuilder();
780         try (BufferedReader br = new BufferedReader(new FileReader(fileName)))
781         {
782             String line;
783             while ((line = br.readLine()) != null) {
784                 jsonString.append(line);
785             }
786         } catch (IOException e) {
787             e.printStackTrace();
788         }
789         JSONObject jsonObject = new JSONObject(jsonString.toString());
790         return jsonObject;
791     }
792 
793     private class Test {
794         public static float[] add(float[] a, float[] b, int SIZE) {
795             float[] output = new float[SIZE];
796             for (int i = 0; i < SIZE; ++i)
797                 output[i] = a[i] + b[i];
798             return output;
799         }
800         public static float[] softmax(float[] a, int X, int Y) {
801             float[] output = new float[X * Y];
802             for (int i = 0; i < X; ++i) {
803                 float max = Float.MIN_VALUE;
804                 for (int j = 0; j < Y; ++j) {
805                     max = Math.max(max, a[i * Y + j]);
806                 }
807                 float sum = 0;
808                 for (int j = 0; j < Y; ++j) {
809                     output[i * Y + j] = (float)Math.exp(a[i * Y + j] - max);
810                     sum += output[i * Y + j];
811                 }
812                 for (int j = 0; j < Y; ++j) {
813                     output[i * Y + j] /= sum;
814                 }
815             }
816             return output;
817         }
818         public static float[] matmul(float[] a, float[] b, int M, int N, int K) {
819             float[] output = new float[M * N];
820             for (int i = 0; i < M; i++) {
821                 for (int j = 0; j < N; j++) {
822                     float tmp = 0;
823                     for (int k = 0; k < K; k++) {
824                         tmp += a[i * K + k] * b[k * N + j];
825                     }
826                     output[i * N + j] = tmp;
827                 }
828             }
829             return output;
830         }
831         public static void check(float[] expected, float[] output) {
832             for (int i = 0; i < expected.length; i++) {
833                 if (Math.abs(expected[i] - output[i]) > 1e-2) {
834                     System.out.printf("Mismatch at %d: %f != %f%n", i, expected[i], output[i]);
835                     throw new RuntimeException("Mismatch");
836                 }
837             }
838             System.out.println("Test passed");
839         }
840     }
841 
842     private class Writer {
843         private final String fileName;
844         private final FileWriter writer;
845 
846         public Writer(String fileName) {
847             this.fileName = fileName;
848             try {
849                 writer = new FileWriter(fileName, false);
850             } catch (IOException e) {
851                 throw new RuntimeException(e);
852             }
853         }
854 
855         public void write(String line) {
856             try {
857                 writer.write(line);
858             } catch (IOException e) {
859                 throw new RuntimeException(e);
860             }
861         }
862 
863         public void close() {
864             try {
865                 writer.close();
866             } catch (IOException e) {
867                 throw new RuntimeException(e);
868             }
869         }
870     }
871 }