1 import java.io.File; 2 import java.io.FileInputStream; 3 import java.util.List; 4 import java.util.Map; 5 import java.util.HashMap; 6 import java.util.ArrayList; 7 import java.util.Deque; 8 import java.util.ArrayDeque; 9 import java.util.Set; 10 import java.util.HashSet; 11 import java.util.Iterator; 12 import java.util.Arrays; 13 import java.util.Queue; 14 import java.util.Optional; 15 import java.util.stream.Stream; 16 import java.util.function.Consumer; 17 import java.util.function.Function; 18 import java.util.function.Supplier; 19 import java.time.*; 20 import java.nio.file.Files; 21 import java.nio.file.Paths; 22 import java.io.IOException; 23 import java.io.BufferedReader; 24 import java.io.FileWriter; 25 import java.io.FileReader; 26 import java.lang.reflect.Method; 27 import java.lang.ref.Cleaner; 28 import java.lang.foreign.MemorySegment; 29 import java.lang.foreign.ValueLayout; 30 import java.lang.foreign.AddressLayout; 31 import java.lang.foreign.Arena; 32 import java.lang.foreign.MemorySegment.Scope; 33 import static java.lang.foreign.ValueLayout.*; 34 import jdk.incubator.vector.VectorSpecies; 35 import jdk.incubator.vector.FloatVector; 36 import static oneapi.levelzero.ze_api_h.*; 37 import oneapi.levelzero.ze_api_h; 38 import oneapi.levelzero.ze_context_desc_t; 39 import oneapi.levelzero.ze_kernel_desc_t; 40 import oneapi.levelzero.ze_command_queue_desc_t; 41 import oneapi.levelzero.ze_command_list_desc_t; 42 import oneapi.levelzero.ze_command_queue_group_properties_t; 43 import oneapi.levelzero.ze_event_pool_desc_t; 44 import oneapi.levelzero.ze_event_desc_t; 45 import oneapi.levelzero.ze_fence_desc_t; 46 import oneapi.levelzero.ze_module_desc_t; 47 import oneapi.levelzero.ze_group_count_t; 48 import oneapi.levelzero.ze_host_mem_alloc_desc_t; 49 import oneapi.levelzero.ze_device_mem_alloc_desc_t; 50 import oneapi.levelzero.ze_device_properties_t; 51 import oneapi.levelzero.ze_device_compute_properties_t; 52 import oneapi.levelzero.ze_driver_properties_t; 53 import oneapi.levelzero.ze_driver_extension_properties_t; 54 import org.json.JSONArray; 55 import org.json.JSONObject; 56 57 import java.util.Random; 58 59 public class LevelZero { 60 public static final AddressLayout driver_handle_t = AddressLayout.ADDRESS; 61 private final Arena arena; 62 private final MemorySegment driverHandle; 63 private final MemorySegment contextHandle; 64 private final MemorySegment deviceHandle; 65 private final MemorySegment queueHandle; 66 private final MemorySegment eventPoolDescription; 67 private final String homeDir = System.getProperty("user.home"); 68 private final String cacheDir = homeDir + "/.triton/cache/"; 69 private final String addKernelCache = "7961f2e8b433c656051d8638d6a3bb65f43f6cb885525c05d611100dd905aa31"; 70 private final String softmaxKernelCache = "f0c32acd1173759227ef8e0e8d197c94493b90ebf8d1fc254399ffac6b527d6a"; 71 private final String matmulKernelCache = "07e17c2833c9c9efea8ccd782af1c3ee05dcac3efb2cb75f2f8a6eecffe381ef"; 72 private final static VectorSpecies<Float> SPECIES = FloatVector.SPECIES_256; 73 private Double timeElapsedForRun, timeElapsedForRerun; 74 75 static { 76 System.loadLibrary("ze_loader"); 77 } 78 79 static void debug(String format, Object... args) { 80 System.out.printf(format + "%n", args); 81 } 82 83 private static void check(int result) { 84 if (result != ZE_RESULT_SUCCESS()) { 85 throw new RuntimeException(String.format("Call failed: 0x%x (%d)", result, result)); 86 } 87 } 88 89 MemorySegment contextHandle() { 90 return contextHandle; 91 } 92 93 MemorySegment deviceHandle() { 94 return deviceHandle; 95 } 96 97 public LevelZero() { 98 arena = Arena.ofShared(); 99 100 // get driver 101 check(zeInit(ZE_INIT_FLAG_GPU_ONLY())); 102 MemorySegment driverCount = arena.allocate(Integer.BYTES); 103 check(zeDriverGet(driverCount, MemorySegment.NULL)); 104 debug("driverCount = %d", driverCount.get(JAVA_INT, 0)); 105 MemorySegment driverHandles = arena.allocate(driverCount.get(JAVA_INT, 0) * driver_handle_t.byteSize(), 8); 106 check(zeDriverGet(driverCount, driverHandles)); 107 driverHandle = driverHandles.get(ADDRESS, 0); 108 109 // create context 110 MemorySegment pContextDesc = arena.allocate(ze_context_desc_t.layout()); 111 ze_context_desc_t.stype(pContextDesc, ZE_STRUCTURE_TYPE_CONTEXT_DESC()); 112 MemorySegment pContextHandle = arena.allocate(ze_context_handle_t); 113 check(zeContextCreate(driverHandle, pContextDesc, pContextHandle)); 114 contextHandle = pContextHandle.get(ADDRESS, 0); 115 116 // get device 117 MemorySegment pDeviceCount = arena.allocate(Integer.BYTES); 118 check(zeDeviceGet(driverHandle, pDeviceCount, MemorySegment.NULL)); 119 int deviceCount = pDeviceCount.get(JAVA_INT, 0); 120 assert deviceCount > 0; 121 debug("deviceCount = %d", deviceCount); 122 MemorySegment deviceHandles = arena.allocate(deviceCount * ze_device_handle_t.byteSize(), 8); 123 check(zeDeviceGet(driverHandle, pDeviceCount, deviceHandles)); 124 for (int i = 0; i < deviceCount; i++) { 125 debug("device #%d: %s", i, deviceHandles.get(ze_device_handle_t, i * ze_device_handle_t.byteSize())); 126 } 127 deviceHandle = deviceHandles.get(ze_device_handle_t, 0 * ze_device_handle_t.byteSize()); 128 MemorySegment pDeviceProperties = arena.allocate(ze_device_properties_t.layout()); 129 ze_device_properties_t.stype(pDeviceProperties, ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES()); 130 check(zeDeviceGetProperties(deviceHandle, pDeviceProperties)); 131 debug("deviceProperties:\n\ttype = %d\n\tvendorId = %d\n\tmaxMemAllocSize = %d\n\tdeviceId = %d\n\tcoreClockRate = %d", 132 ze_device_properties_t.type(pDeviceProperties), 133 ze_device_properties_t.vendorId(pDeviceProperties), 134 ze_device_properties_t.maxMemAllocSize(pDeviceProperties), 135 ze_device_properties_t.deviceId(pDeviceProperties), 136 ze_device_properties_t.coreClockRate(pDeviceProperties)); 137 138 MemorySegment pDeviceComputeProperties = arena.allocate(ze_device_compute_properties_t.layout()); 139 ze_device_compute_properties_t.stype(pDeviceComputeProperties, ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES()); 140 check(zeDeviceGetComputeProperties(deviceHandle, pDeviceComputeProperties)); 141 debug("deviceProperties:\n\tshared = %d\n\tmaxTotalGroupSize = %d", 142 ze_device_compute_properties_t.maxSharedLocalMemory(pDeviceComputeProperties), 143 ze_device_compute_properties_t.maxTotalGroupSize(pDeviceComputeProperties)); 144 145 // create queue 146 MemorySegment pNumQueueGroups = arena.allocate(JAVA_INT, 1); 147 check(zeDeviceGetCommandQueueGroupProperties(deviceHandle, pNumQueueGroups, MemorySegment.NULL)); 148 debug("#Queue Groups: %d", pNumQueueGroups.get(JAVA_INT, 0)); 149 MemorySegment pGroupProperties = arena.allocate(ze_command_queue_group_properties_t.layout(), pNumQueueGroups.get(JAVA_INT, 0)); 150 check(zeDeviceGetCommandQueueGroupProperties(deviceHandle, pNumQueueGroups, pGroupProperties)); 151 152 MemorySegment pQueueDesc = arena.allocate(ze_command_queue_desc_t.layout()); 153 ze_command_queue_desc_t.stype(pQueueDesc, ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC()); 154 ze_command_queue_desc_t.index(pQueueDesc, 0); 155 ze_command_queue_desc_t.mode(pQueueDesc, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS()); 156 ze_command_queue_desc_t.ordinal(pQueueDesc, 0); 157 MemorySegment pQueueHandle = arena.allocate(ze_command_queue_handle_t); 158 check(zeCommandQueueCreate(contextHandle, deviceHandle, pQueueDesc, pQueueHandle)); 159 queueHandle = pQueueHandle.get(ADDRESS, 0); 160 161 eventPoolDescription = arena.allocate(ze_event_pool_desc_t.layout()); 162 ze_event_pool_desc_t.stype(eventPoolDescription, ZE_STRUCTURE_TYPE_EVENT_POOL_DESC()); 163 ze_event_pool_desc_t.count(eventPoolDescription, 20); 164 ze_event_pool_desc_t.flags(eventPoolDescription, ZE_EVENT_POOL_FLAG_HOST_VISIBLE()); 165 166 timeElapsedForRun = timeElapsedForRerun = 0.0; 167 } 168 169 public void clear() { 170 check(zeCommandQueueDestroy(queueHandle)); 171 check(zeContextDestroy(contextHandle)); 172 } 173 174 public void test(String testName) { 175 Object[] args = {}; 176 Random rand = new Random(); 177 if (testName.equals("add")) { 178 String jsonFileName = cacheDir + addKernelCache + "/add_kernel.json"; 179 String moduleName = cacheDir + addKernelCache + "/add_kernel.spv"; 180 181 int BLOCK_SIZE = 64; 182 int elementSize = 4096; 183 int gridSize = (elementSize + BLOCK_SIZE - 1) / BLOCK_SIZE; 184 185 JSONObject jsonObject = loadJson(jsonFileName); 186 String kernelName = jsonObject.getString("name"); 187 int threads_per_warp = jsonObject.getInt("threads_per_warp"); 188 int num_warps = jsonObject.getInt("num_warps"); 189 int shared = jsonObject.getInt("shared"); 190 191 float[] input1 = new float[elementSize]; 192 float[] input2 = new float[elementSize]; 193 float[] output = new float[elementSize]; 194 for (int i = 0; i < elementSize; i++) { 195 input1[i] = rand.nextFloat(); 196 input2[i] = rand.nextFloat(); 197 } 198 args = new Object[] {input1, input2, output, elementSize}; 199 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 200 201 float[] expected = Test.add(input1, input2, elementSize); 202 Test.check(expected, output); 203 } else if (testName.equals("softmax")) { 204 String jsonFileName = cacheDir + softmaxKernelCache + "/softmax_kernel.json"; 205 String moduleName = cacheDir + softmaxKernelCache + "/softmax_kernel.spv"; 206 207 JSONObject jsonObject = loadJson(jsonFileName); 208 String kernelName = jsonObject.getString("name"); 209 int threads_per_warp = jsonObject.getInt("threads_per_warp"); 210 int num_warps = jsonObject.getInt("num_warps"); 211 int shared = jsonObject.getInt("shared"); 212 213 int elementSizeX = 4096, elementSizeY = 64; 214 int gridSize = elementSizeX; 215 float[] input = new float[elementSizeX * elementSizeY]; 216 float[] output = new float[elementSizeX * elementSizeY]; 217 byte[] sharedMemory = new byte[shared]; // use for storing temporary value of max element and sum of exp 218 for (int i = 0; i < elementSizeX * elementSizeY; i++) { 219 input[i] = rand.nextFloat(); 220 } 221 args = new Object[] {output, input, elementSizeY, elementSizeY, elementSizeY, sharedMemory}; 222 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 223 224 float[] expected = Test.softmax(input, elementSizeX, elementSizeY); 225 Test.check(expected, output); 226 } else if (testName.equals("matmul")) { 227 String jsonFileName = cacheDir + matmulKernelCache + "/matmul_kernel.json"; 228 String moduleName = cacheDir + matmulKernelCache + "/matmul_kernel.spv"; 229 230 JSONObject jsonObject = loadJson(jsonFileName); 231 String kernelName = jsonObject.getString("name"); 232 int threads_per_warp = jsonObject.getInt("threads_per_warp"); 233 int num_warps = jsonObject.getInt("num_warps"); 234 int shared = jsonObject.getInt("shared"); 235 236 int M = 1024, N = 1024, K = 1024; 237 int BLOCK_SIZE_M = 32, BLOCK_SIZE_N = 64; 238 int gridSize = ((M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N); 239 float[] a = new float[M * K]; 240 float[] b = new float[K * N]; 241 float[] c = new float[M * N]; 242 byte[] sharedMemory = new byte[shared]; 243 244 for (int i = 0; i < M * K; i++) { 245 a[i] = rand.nextFloat(); 246 } 247 for (int i = 0; i < K * N; i++) { 248 b[i] = rand.nextFloat(); 249 } 250 args = new Object[] {a, b, c, M, N, K, K, N, N, sharedMemory}; 251 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 252 253 float[] expected = Test.matmul(a, b, M, N, K); 254 Test.check(expected, c); 255 } else { 256 throw new RuntimeException("Unsupported test: " + testName); 257 } 258 } 259 260 public void run(String kernelName, String fileName, Object[] args, int threads_per_warp, int num_warps, int shared, int gridSize) { 261 debug("=========== run %s ===========", kernelName); 262 MemorySegment spirvBinary = loadModule(fileName); 263 List<Arg> kernelArgs = collectArgs(args); 264 int[] globalSizes = new int[] {gridSize * threads_per_warp * num_warps, 1, 1}; 265 int[] localSizes = new int[] {threads_per_warp * num_warps, 1, 1}; 266 KernelGeometry geometry = new KernelGeometry(globalSizes, localSizes); 267 debug("geometry = %s", geometry); 268 MemorySegment commandListHandle = createCommandList(spirvBinary, kernelName, geometry, kernelArgs, shared, false); 269 executeCommandList(commandListHandle); 270 check(zeCommandQueueSynchronize(queueHandle, -1L)); 271 for (int i = 0; i < kernelArgs.size(); i++) { 272 copyArgToHost(kernelArgs.get(i), contextHandle); 273 } 274 275 for (int i = 0; i < kernelArgs.size(); i++) { 276 Arg arg = kernelArgs.get(i); 277 MemorySegment dataSegment = arg.dataSegment(); 278 if (dataSegment != null) { 279 check(zeMemFree(contextHandle, dataSegment)); 280 } 281 } 282 check(zeCommandListDestroy(commandListHandle)); 283 } 284 285 public void runRefMatmul(String kernelName, String fileName, Object[] args, int size) { 286 debug("=========== run %s ===========", kernelName); 287 MemorySegment spirvBinary = loadModule(fileName); 288 List<Arg> kernelArgs = collectArgs(args); 289 int[] globalSizes = new int[] {size, size, 1}; 290 int[] localSizes = new int[] {512, 1, 1}; 291 KernelGeometry geometry = new KernelGeometry(globalSizes, localSizes); 292 debug("geometry = %s", geometry); 293 MemorySegment commandListHandle = createCommandList(spirvBinary, kernelName, geometry, kernelArgs, 0, true); 294 executeCommandList(commandListHandle); 295 check(zeCommandQueueSynchronize(queueHandle, -1L)); 296 for (int i = 0; i < kernelArgs.size(); i++) { 297 copyArgToHost(kernelArgs.get(i), contextHandle); 298 } 299 300 for (int i = 0; i < kernelArgs.size(); i++) { 301 Arg arg = kernelArgs.get(i); 302 MemorySegment dataSegment = arg.dataSegment(); 303 if (dataSegment != null) { 304 check(zeMemFree(contextHandle, dataSegment)); 305 } 306 } 307 check(zeCommandListDestroy(commandListHandle)); 308 } 309 310 private List<Arg> collectArgs(Object[] values) { 311 List<Arg> args = new ArrayList<>(); 312 for (int i = 0; i < values.length; i++) { 313 args.add(Arg.createArg(this, "arg" + i, values[i])); 314 } 315 debug("args = %s", args); 316 return args; 317 } 318 319 MemorySegment loadModule(String fileName) { 320 byte[] data = readBytes(fileName); 321 MemorySegment segment = arena.allocate(data.length); 322 segment.copyFrom(MemorySegment.ofArray(data)); 323 return segment; 324 } 325 326 byte[] readBytes(String filename) { 327 File file = new File(filename); 328 try (FileInputStream fis = new FileInputStream(file)) { 329 byte[] data = new byte[(int) file.length()]; 330 fis.read(data); 331 return data; 332 } catch (IOException e) { 333 throw new RuntimeException(e); 334 } 335 } 336 337 void provisionArg(Arg arg) { 338 if (arg.cls() == byte[].class) { 339 byte[] array = (byte[])arg.value(); 340 int segmentSize = array.length; 341 arg.setDataSegment(allocateSharedSegment(segmentSize)); 342 arg.dataSegment().copyFrom(MemorySegment.ofArray(array)); 343 arg.setSize(8); 344 arg.setNeedsCleanup(true); 345 } 346 else if (arg.cls() == short[].class) { 347 short[] array = (short[])arg.value(); 348 int segmentSize = array.length * Short.BYTES; 349 arg.setDataSegment(allocateSharedSegment(segmentSize)); 350 arg.dataSegment().copyFrom(MemorySegment.ofArray(array)); 351 arg.setSize(8); 352 arg.setNeedsCleanup(true); 353 } 354 else if (arg.cls() == int[].class) { 355 int[] array = (int[])arg.value(); 356 int segmentSize = array.length * Integer.BYTES; 357 arg.setDataSegment(allocateSharedSegment(segmentSize)); 358 arg.dataSegment().copyFrom(MemorySegment.ofArray(array)); 359 arg.setSize(8); 360 arg.setNeedsCleanup(true); 361 } 362 else if (arg.cls() == float[].class) { 363 float[] array = (float[])arg.value(); 364 int segmentSize = array.length * Float.BYTES; 365 arg.setDataSegment(allocateSharedSegment(segmentSize)); 366 arg.dataSegment().copyFrom(MemorySegment.ofArray(array)); 367 arg.setSize(8); 368 arg.setNeedsCleanup(true); 369 } 370 else if (VectorSpecies.class.isAssignableFrom(arg.cls())) { 371 arg.setSize(4); 372 } 373 else if (arg.cls() == Short.class) { 374 arg.setSize(2); 375 } 376 else if (arg.cls() == Integer.class || arg.cls() == Float.class || arg.cls() == Boolean.class) { 377 arg.setSize(4); 378 } 379 else if (arg.cls() == Long.class) { 380 arg.setSize(8); 381 } 382 else if (arg.cls() == GPU.Index.class) { 383 MemorySegment pBuffer = arena.allocate(ADDRESS); 384 arg.setDataSegment(allocateSharedSegment(24)); 385 arg.setSize(24); 386 } 387 else throw new RuntimeException("unsupported type: " + arg.cls()); 388 } 389 390 void copyArgToHost(Arg arg, MemorySegment contextHandle) { 391 if (arg.cls() == short[].class) { 392 short[] array = (short[])arg.value(); 393 MemorySegment arraySegment = MemorySegment.ofArray(array); 394 arraySegment.copyFrom(arg.dataSegment()); 395 } 396 else if (arg.cls() == int[].class) { 397 int[] array = (int[])arg.value(); 398 MemorySegment arraySegment = MemorySegment.ofArray(array); 399 arraySegment.copyFrom(arg.dataSegment()); 400 } 401 else if (arg.cls() == float[].class) { 402 float[] array = (float[])arg.value(); 403 MemorySegment arraySegment = MemorySegment.ofArray(array); 404 arraySegment.copyFrom(arg.dataSegment()); 405 } 406 // else nothing to do 407 } 408 409 private MemorySegment createCommandList(MemorySegment spirvModule, String kernelName, KernelGeometry geometry, List<Arg> args, int shared, boolean suggested) { 410 Arena arena = Arena.ofShared(); 411 MemorySegment pCommandListHandle = arena.allocate(ze_command_list_handle_t); 412 MemorySegment commandListDesc = arena.allocate(ze_command_list_desc_t.layout()); 413 ze_command_list_desc_t.stype(eventPoolDescription, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC()); 414 ze_command_list_desc_t.commandQueueGroupOrdinal(commandListDesc, 0); 415 MemorySegment moduleHandle = createModule(kernelName, spirvModule); 416 check(zeCommandListCreate(contextHandle, deviceHandle, commandListDesc, pCommandListHandle)); 417 MemorySegment commandListHandle = pCommandListHandle.get(ADDRESS, 0); 418 MemorySegment kernelHandle = createKernel(moduleHandle, kernelName, geometry, suggested); 419 for (int i = 0; i < args.size(); i++) { 420 Arg arg = args.get(i); 421 setKernelArg(arg, i, commandListHandle, kernelHandle, (shared != 0) && (i == args.size() - 1)); 422 } 423 MemorySegment groupCount = arena.allocate(ze_group_count_t.layout()); 424 ze_group_count_t.groupCountX(groupCount, (geometry.globalSizes()[0] + geometry.localSizes()[0] - 1) / geometry.localSizes()[0]); 425 ze_group_count_t.groupCountY(groupCount, (geometry.globalSizes()[1] + geometry.localSizes()[1] - 1) / geometry.localSizes()[1]); 426 ze_group_count_t.groupCountZ(groupCount, (geometry.globalSizes()[2] + geometry.localSizes()[2] - 1) / geometry.localSizes()[2]); 427 MemorySegment pKernelWaitHandles = MemorySegment.NULL; 428 check(zeCommandListAppendLaunchKernel(commandListHandle, kernelHandle, groupCount, MemorySegment.NULL, 0, pKernelWaitHandles)); 429 check(zeCommandListClose(commandListHandle)); 430 return commandListHandle; 431 } 432 433 private MemorySegment executeCommandList(MemorySegment commandListHandle) { 434 MemorySegment fenceDesc = arena.allocate(ze_fence_desc_t.layout()); 435 ze_module_desc_t.stype(fenceDesc, ZE_STRUCTURE_TYPE_FENCE_DESC()); 436 ze_fence_desc_t.flags(fenceDesc, ZE_FENCE_FLAG_SIGNALED()); 437 MemorySegment pFenceHandle = arena.allocate(ze_fence_handle_t); 438 check(zeFenceCreate(queueHandle, fenceDesc, pFenceHandle)); 439 MemorySegment fenceHandle = pFenceHandle.get(ADDRESS, 0); 440 MemorySegment pCommandListHandle = arena.allocate(ze_command_list_handle_t); 441 pCommandListHandle.set(ADDRESS, 0, commandListHandle); 442 Instant start = Instant.now(); 443 check(zeCommandQueueExecuteCommandLists(queueHandle, 1, pCommandListHandle, fenceHandle)); 444 check(zeCommandQueueSynchronize(queueHandle, -1L)); 445 Instant finish = Instant.now(); 446 Double timeElapsed = Duration.between(start, finish).toNanos() * 1e-6; 447 timeElapsedForRun += timeElapsed; 448 debug("time: %f %f\n", timeElapsed, timeElapsedForRun); 449 450 start = Instant.now(); 451 check(zeCommandQueueExecuteCommandLists(queueHandle, 1, pCommandListHandle, fenceHandle)); 452 check(zeCommandQueueSynchronize(queueHandle, -1L)); 453 finish = Instant.now(); 454 timeElapsed = Duration.between(start, finish).toNanos() * 1e-6; 455 timeElapsedForRerun += timeElapsed; 456 debug("time for rerun: %f %f\n", timeElapsed, timeElapsedForRerun); 457 return fenceHandle; 458 } 459 460 private MemorySegment createKernel(MemorySegment moduleHandle, String kernelNameString, KernelGeometry geometry, boolean suggested) { 461 MemorySegment kernelDesc = arena.allocate(ze_kernel_desc_t.layout()); 462 MemorySegment kernelName = arena.allocateFrom(kernelNameString); 463 ze_kernel_desc_t.stype(kernelDesc, ZE_STRUCTURE_TYPE_KERNEL_DESC()); 464 ze_kernel_desc_t.pKernelName(kernelDesc, kernelName); 465 debug("name = %s", kernelNameString); 466 MemorySegment pKernelHandle = arena.allocate(ze_kernel_handle_t); 467 check(zeKernelCreate(moduleHandle, kernelDesc, pKernelHandle)); 468 int[] globalSizes = geometry.globalSizes(); 469 int[] localSizes = geometry.localSizes(); 470 MemorySegment kernelHandle = pKernelHandle.get(ADDRESS, 0); 471 if (suggested) { 472 MemorySegment pGroupSizeX = arena.allocate(JAVA_INT, localSizes[0]); 473 MemorySegment pGroupSizeY = arena.allocate(JAVA_INT, localSizes[1]); 474 MemorySegment pGroupSizeZ = arena.allocate(JAVA_INT, localSizes[2]); 475 check(zeKernelSuggestGroupSize(kernelHandle, globalSizes[0], globalSizes[1], globalSizes[2], pGroupSizeX, pGroupSizeY, pGroupSizeZ)); 476 geometry.localSizes()[0] = pGroupSizeX.get(JAVA_INT, 0); 477 geometry.localSizes()[1] = pGroupSizeY.get(JAVA_INT, 0); 478 geometry.localSizes()[2] = pGroupSizeZ.get(JAVA_INT, 0); 479 debug("use suggested group size", geometry.toString()); 480 check(zeKernelSetGroupSize(kernelHandle, pGroupSizeX.get(JAVA_INT, 0), pGroupSizeY.get(JAVA_INT, 0), pGroupSizeZ.get(JAVA_INT, 0))); 481 } else { 482 debug("use localSizes", geometry.toString()); 483 check(zeKernelSetGroupSize(kernelHandle, localSizes[0], localSizes[1], localSizes[2])); 484 } 485 return kernelHandle; 486 } 487 488 private void setKernelArg(Arg arg, int ordinal, MemorySegment commandListHandle, MemorySegment kernelHandle, boolean shared) { 489 MemorySegment dataSegment = arg.dataSegment(); 490 Class<?> cls = arg.cls(); 491 debug("ordinal = %d, cls = %s, data = %s", ordinal, cls.getSimpleName(), dataSegment); 492 if (shared) { // shared memory 493 check(zeKernelSetArgumentValue(kernelHandle, ordinal, dataSegment.byteSize(), dataSegment)); 494 } 495 else if (cls == byte[].class || cls == short[].class || cls == int[].class || cls == float[].class || cls.getSimpleName().equals("NativeMemorySegmentImpl")) { 496 check(zeCommandListAppendMemoryPrefetch(commandListHandle, dataSegment, dataSegment.byteSize())); 497 check(zeCommandListAppendMemAdvise(commandListHandle, deviceHandle, dataSegment, dataSegment.byteSize(), ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION())); 498 MemorySegment pDataSegment = arena.allocateFrom(ADDRESS, dataSegment); 499 check(zeKernelSetArgumentValue(kernelHandle, ordinal, ADDRESS.byteSize(), pDataSegment)); 500 } 501 else if (cls == Short.class) { 502 MemorySegment pArgValue = arena.allocateFrom(JAVA_SHORT, (short)arg.value()); 503 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Short.BYTES, pArgValue)); 504 } 505 else if (VectorSpecies.class.isAssignableFrom(cls)) { 506 MemorySegment pArgValue = arena.allocateFrom(JAVA_INT, FloatVector.SPECIES_256.length()); 507 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Integer.BYTES, pArgValue)); 508 } 509 else if (cls == Integer.class || cls == Boolean.class) { 510 MemorySegment pArgValue = arena.allocateFrom(JAVA_INT, (int)arg.value()); 511 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Integer.BYTES, pArgValue)); 512 } 513 else if (cls == Long.class) { 514 MemorySegment pArgValue = arena.allocateFrom(JAVA_LONG, (long)arg.value()); 515 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Long.BYTES, pArgValue)); 516 } 517 else if (cls == Float.class) { 518 MemorySegment pArgValue = arena.allocateFrom(JAVA_LONG, Float.floatToIntBits((float)arg.value())); 519 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Float.BYTES, pArgValue)); 520 } 521 else if (cls == GPU.Index.class) { 522 MemorySegment pDataSegment = arena.allocateFrom(ADDRESS, dataSegment); 523 check(zeKernelSetArgumentValue(kernelHandle, ordinal, 24, pDataSegment)); 524 } 525 else throw new RuntimeException("unsupported type: " + cls); 526 } 527 528 private MemorySegment createModule(String moduleName, MemorySegment spirvCode) { 529 MemorySegment pModuleHandle = arena.allocate(ze_module_handle_t); 530 MemorySegment moduleDesc = arena.allocate(ze_module_desc_t.layout()); 531 ze_module_desc_t.stype(moduleDesc, ZE_STRUCTURE_TYPE_MODULE_DESC()); 532 ze_module_desc_t.format(moduleDesc, ZE_MODULE_FORMAT_IL_SPIRV()); 533 ze_module_desc_t.pInputModule(moduleDesc, spirvCode); 534 ze_module_desc_t.inputSize(moduleDesc, spirvCode.byteSize()); 535 ze_module_desc_t.pBuildFlags(moduleDesc, arena.allocateFrom("")); 536 MemorySegment buildLogHandle = arena.allocate(ze_module_build_log_handle_t); 537 check(zeModuleCreate(contextHandle, deviceHandle, moduleDesc, pModuleHandle, buildLogHandle)); 538 MemorySegment moduleHandle = pModuleHandle.get(ADDRESS, 0); 539 return moduleHandle; 540 } 541 542 public MemorySegment allocateSharedSegment(long byteSize) { 543 return allocateSharedSegment(contextHandle(), deviceHandle(), byteSize, Arena.global()); 544 } 545 546 public static MemorySegment allocateSharedSegment(MemorySegment contextHandle, MemorySegment deviceHandle, long byteSize, Arena arena) { 547 MemorySegment pDeviceMemAllocDesc = arena.allocate(ze_device_mem_alloc_desc_t.layout()); 548 ze_device_mem_alloc_desc_t.stype(pDeviceMemAllocDesc, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC()); 549 ze_device_mem_alloc_desc_t.ordinal(pDeviceMemAllocDesc, 0); 550 MemorySegment pHostMemAllocDesc = arena.allocate(ze_host_mem_alloc_desc_t.layout()); 551 ze_host_mem_alloc_desc_t.stype(pHostMemAllocDesc, ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC()); 552 MemorySegment pBuffer = arena.allocate(ADDRESS); 553 check(zeMemAllocShared(contextHandle, pDeviceMemAllocDesc, pHostMemAllocDesc, byteSize, 1, deviceHandle, pBuffer)); 554 long address = pBuffer.get(JAVA_LONG, 0); 555 return MemorySegment.ofAddress(address).reinterpret(byteSize); 556 } 557 558 private static record KernelGeometry(int[] globalSizes, int[] localSizes) { 559 public KernelGeometry() { 560 this(new int[3], new int[] {512, 1, 1}); 561 } 562 563 @Override 564 public String toString() { 565 return String.format("global: %s, local: %s", Arrays.toString(globalSizes), Arrays.toString(localSizes)); 566 } 567 } 568 569 public void benchAddKernel() { 570 String jsonFileName = cacheDir + addKernelCache + "/add_kernel.json"; 571 String moduleName = cacheDir + addKernelCache + "/add_kernel.spv"; 572 JSONObject jsonObject = loadJson(jsonFileName); 573 String kernelName = jsonObject.getString("name"); 574 int threads_per_warp = jsonObject.getInt("threads_per_warp"); 575 int num_warps = jsonObject.getInt("num_warps"); 576 int shared = jsonObject.getInt("shared"); 577 Random rand = new Random(); 578 579 Writer writer = new Writer("benchmark/vector_add_benchmark.txt"); 580 writer.write("elementSize timeElapsed timeElapsedForRerun RTT gb/s \n"); 581 582 for (int elementSize = (1 << 12); elementSize <= (1 << 28); elementSize <<= 1) { 583 int BLOCK_SIZE = 1024; 584 int gridSize = (elementSize + BLOCK_SIZE - 1) / BLOCK_SIZE; 585 586 float[] input1 = new float[elementSize]; 587 float[] input2 = new float[elementSize]; 588 float[] output = new float[elementSize]; 589 for (int i = 0; i < elementSize; i++) { 590 input1[i] = rand.nextFloat(); 591 input2[i] = rand.nextFloat(); 592 } 593 Object[] args = new Object[] {input1, input2, output, elementSize}; 594 595 // warmup 596 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 597 this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0; 598 599 int nTimes = 10; 600 Instant start = Instant.now(); 601 for (int i = 0; i < nTimes; ++i) 602 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 603 Instant finish = Instant.now(); 604 Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes; 605 Double timeElapsedForRun = this.timeElapsedForRun / nTimes; 606 Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes; 607 writer.write(String.format("%d %.4f %.4f %.4f %.4f\n", elementSize, timeElapsedForRun, timeElapsedForRerun, RTT, (4f * 3f * elementSize / timeElapsedForRerun * 1e-6))); 608 } 609 writer.close(); 610 } 611 612 public void benchSoftmaxKernel() { 613 String jsonFileName = cacheDir + softmaxKernelCache + "/softmax_kernel.json"; 614 String moduleName = cacheDir + softmaxKernelCache + "/softmax_kernel.spv"; 615 JSONObject jsonObject = loadJson(jsonFileName); 616 String kernelName = jsonObject.getString("name"); 617 int threads_per_warp = jsonObject.getInt("threads_per_warp"); 618 int num_warps = jsonObject.getInt("num_warps"); 619 int shared = jsonObject.getInt("shared"); 620 Random rand = new Random(); 621 622 Writer writer = new Writer("benchmark/softmax_benchmark.txt"); 623 writer.write("elementSizeX elementSizeY timeElapsed timeElapsedForRerun RTT gb/s \n"); 624 625 for (int i = 2; i < 50; i++) { 626 int elementSizeX = 4096; 627 int elementSizeY = 128 * i; 628 int gridSize = elementSizeX; 629 float[] input = new float[elementSizeX * elementSizeY]; 630 float[] output = new float[elementSizeX * elementSizeY]; 631 byte[] sharedMemory = new byte[shared]; 632 for (int j = 0; j < elementSizeX * elementSizeY; j++) { 633 input[j] = rand.nextFloat(); 634 } 635 Object[] args = new Object[] {output, input, elementSizeY, elementSizeY, elementSizeY, sharedMemory}; 636 637 // warmup 638 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 639 this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0; 640 641 int nTimes = 10; 642 Instant start = Instant.now(); 643 for (int j = 0; j < nTimes; ++j) 644 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 645 Instant finish = Instant.now(); 646 Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes; 647 Double timeElapsedForRun = this.timeElapsedForRun / nTimes; 648 Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes; 649 writer.write(String.format("%d %d %.4f %.4f %.4f %.4f\n", elementSizeX, elementSizeY, timeElapsedForRun, timeElapsedForRerun, RTT, (4 * 2 * 1e-9 * elementSizeX * elementSizeY / (timeElapsedForRerun * 1e-3)))); 650 } 651 writer.close(); 652 } 653 654 public void benchMatmulKernel() { 655 String jsonFileName = cacheDir + matmulKernelCache + "/matmul_kernel.json"; 656 String moduleName = cacheDir + matmulKernelCache + "/matmul_kernel.spv"; 657 658 JSONObject jsonObject = loadJson(jsonFileName); 659 String kernelName = jsonObject.getString("name"); 660 int threads_per_warp = jsonObject.getInt("threads_per_warp"); 661 int num_warps = jsonObject.getInt("num_warps"); 662 int shared = jsonObject.getInt("shared"); 663 Random rand = new Random(); 664 int BLOCK_SIZE_M = 128, BLOCK_SIZE_N = 64; 665 666 667 Writer writer = new Writer("benchmark/matmul_benchmark.txt"); 668 writer.write("M N K timeElapsed timeElapsedForRerun RTT TFLOPS \n"); 669 670 for (int i = 2; i <= 64; i++) { 671 int M = 128 * i; 672 int N = 128 * i; 673 int K = 128 * i; 674 int gridSize = ((M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N); 675 676 float[] a = new float[M * K]; 677 float[] b = new float[K * N]; 678 float[] c = new float[M * N]; 679 byte[] sharedMemory = new byte[shared]; 680 for (int j = 0; j < M * K; j++) { 681 a[j] = rand.nextFloat(); 682 } 683 for (int j = 0; j < K * N; j++) { 684 b[j] = rand.nextFloat(); 685 } 686 Object[] args = new Object[] {a, b, c, M, N, K, K, N, N, sharedMemory}; 687 688 // warmup 689 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 690 this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0; 691 692 int nTimes = 10; 693 Instant start = Instant.now(); 694 for (int j = 0; j < nTimes; ++j) 695 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize); 696 Instant finish = Instant.now(); 697 Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes; 698 Double timeElapsedForRun = this.timeElapsedForRun / nTimes; 699 Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes; 700 writer.write(String.format("%d %d %d %.4f %.4f %.4f %.4f\n", M, N, K, timeElapsedForRun, timeElapsedForRerun, RTT, (2 * 1e-12 * M * N * K / (timeElapsedForRerun * 1e-3)))); 701 } 702 writer.close(); 703 } 704 705 public static void main(String[] args) { 706 LevelZero lz = new LevelZero(); 707 lz.test("add"); 708 lz.test("softmax"); 709 lz.test("matmul"); 710 lz.benchAddKernel(); 711 lz.benchSoftmaxKernel(); 712 lz.benchMatmulKernel(); 713 lz.clear(); 714 } 715 716 717 public static class Arg { 718 private final String name; 719 private final Object value; 720 private final Class<?> cls; 721 private int size; 722 private boolean needsCleanup; 723 private MemorySegment dataSegment; 724 725 public static Arg createArg(LevelZero lz, String name, Object value) { 726 Arg arg = new Arg(name, value); 727 lz.provisionArg(arg); 728 return arg; 729 } 730 731 private Arg(String name, Object value) { 732 this.name = name; 733 this.cls = value.getClass(); 734 this.value = value; 735 } 736 737 public String name() { 738 return name; 739 } 740 741 public Object value() { 742 return value; 743 } 744 745 public Class<?> cls() { 746 return cls; 747 } 748 749 public void setSize(int size) { 750 this.size = size; 751 } 752 753 public int size() { 754 return size; 755 } 756 757 public void setDataSegment(MemorySegment segment) { 758 dataSegment = segment; 759 } 760 761 public MemorySegment dataSegment() { 762 return dataSegment; 763 } 764 765 public void setNeedsCleanup(boolean needsCleanup) { 766 this.needsCleanup = needsCleanup; 767 } 768 769 public boolean needsCleanup() { 770 return needsCleanup; 771 } 772 773 public String toString() { 774 return String.format("name = %s, cls = %s", name, cls); 775 } 776 } 777 778 private JSONObject loadJson(String fileName) { 779 StringBuilder jsonString = new StringBuilder(); 780 try (BufferedReader br = new BufferedReader(new FileReader(fileName))) 781 { 782 String line; 783 while ((line = br.readLine()) != null) { 784 jsonString.append(line); 785 } 786 } catch (IOException e) { 787 e.printStackTrace(); 788 } 789 JSONObject jsonObject = new JSONObject(jsonString.toString()); 790 return jsonObject; 791 } 792 793 private class Test { 794 public static float[] add(float[] a, float[] b, int SIZE) { 795 float[] output = new float[SIZE]; 796 for (int i = 0; i < SIZE; ++i) 797 output[i] = a[i] + b[i]; 798 return output; 799 } 800 public static float[] softmax(float[] a, int X, int Y) { 801 float[] output = new float[X * Y]; 802 for (int i = 0; i < X; ++i) { 803 float max = Float.MIN_VALUE; 804 for (int j = 0; j < Y; ++j) { 805 max = Math.max(max, a[i * Y + j]); 806 } 807 float sum = 0; 808 for (int j = 0; j < Y; ++j) { 809 output[i * Y + j] = (float)Math.exp(a[i * Y + j] - max); 810 sum += output[i * Y + j]; 811 } 812 for (int j = 0; j < Y; ++j) { 813 output[i * Y + j] /= sum; 814 } 815 } 816 return output; 817 } 818 public static float[] matmul(float[] a, float[] b, int M, int N, int K) { 819 float[] output = new float[M * N]; 820 for (int i = 0; i < M; i++) { 821 for (int j = 0; j < N; j++) { 822 float tmp = 0; 823 for (int k = 0; k < K; k++) { 824 tmp += a[i * K + k] * b[k * N + j]; 825 } 826 output[i * N + j] = tmp; 827 } 828 } 829 return output; 830 } 831 public static void check(float[] expected, float[] output) { 832 for (int i = 0; i < expected.length; i++) { 833 if (Math.abs(expected[i] - output[i]) > 1e-2) { 834 System.out.printf("Mismatch at %d: %f != %f%n", i, expected[i], output[i]); 835 throw new RuntimeException("Mismatch"); 836 } 837 } 838 System.out.println("Test passed"); 839 } 840 } 841 842 private class Writer { 843 private final String fileName; 844 private final FileWriter writer; 845 846 public Writer(String fileName) { 847 this.fileName = fileName; 848 try { 849 writer = new FileWriter(fileName, false); 850 } catch (IOException e) { 851 throw new RuntimeException(e); 852 } 853 } 854 855 public void write(String line) { 856 try { 857 writer.write(line); 858 } catch (IOException e) { 859 throw new RuntimeException(e); 860 } 861 } 862 863 public void close() { 864 try { 865 writer.close(); 866 } catch (IOException e) { 867 throw new RuntimeException(e); 868 } 869 } 870 } 871 }