[mlir][gpu] Add support for lowering math.erf to __nv_erf (#79848)

This commit is contained in:
David Majnemer 2024-01-29 19:35:23 +00:00 committed by GitHub
parent d9f1791a0a
commit 0039a2ff4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 0 deletions

View File

@@ -367,6 +367,7 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf", populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
"__nv_ceil"); "__nv_ceil");
populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos"); populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp"); populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f", populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
"__nv_exp2"); "__nv_exp2");

View File

@@ -627,6 +627,19 @@ gpu.module @test_module_31 {
} }
} }
// Test module for math.erf: the directives below expect the f32 form to
// become a call to @__nv_erff and the f64 form a call to @__nv_erf, with
// both callees declared as llvm.func ahead of the function itself.
gpu.module @test_module_32 {
  // Expected declarations of the two callees.
  // CHECK: llvm.func @__nv_erff(f32) -> f32
  // CHECK: llvm.func @__nv_erf(f64) -> f64
  // CHECK-LABEL: func @gpu_erf
  func.func @gpu_erf(%x_f32 : f32, %x_f64 : f64) -> (f32, f64) {
    // Single-precision operand: expected to map to the "f"-suffixed callee.
    %erf_f32 = math.erf %x_f32 : f32
    // CHECK: llvm.call @__nv_erff(%{{.*}}) : (f32) -> f32
    // Double-precision operand: expected to map to the unsuffixed callee.
    %erf_f64 = math.erf %x_f64 : f64
    // CHECK: llvm.call @__nv_erf(%{{.*}}) : (f64) -> f64
    // Both results are returned from the function.
    func.return %erf_f32, %erf_f64 : f32, f64
  }
}
gpu.module @gpumodule { gpu.module @gpumodule {
// CHECK-LABEL: func @kernel_with_block_size() // CHECK-LABEL: func @kernel_with_block_size()
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>} // CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}