Skip to content

Commit 28daf39

Browse files
KanclerzPiotr authored and igcbot committed
Extend SPV_INTEL_fp_max_error support for fdiv and sqrt
This PR extends SPV_INTEL_fp_max_error support to correctly rounded (maximum error below 1 ULP) fdiv and sqrt operations.
1 parent 43f0394 commit 28daf39

File tree

7 files changed

+230
-150
lines changed

7 files changed

+230
-150
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/AccuracyDecoratedCallsBiFResolution/AccuracyDecoratedCallsBiFResolution.cpp

Lines changed: 45 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@ SPDX-License-Identifier: MIT
1010

1111
#include "llvmWrapper/IR/Type.h"
1212
#include "llvm/IR/Attributes.h"
13+
#include <llvm/Demangle/Demangle.h>
14+
#include <llvm/IR/CallingConv.h>
15+
#include <llvm/IR/DerivedTypes.h>
1316
#include "Compiler/IGCPassSupport.h"
1417
#include "Compiler/CodeGenPublic.h"
1518
#include "Probe/Assertion.h"
@@ -37,8 +40,10 @@ static inline const char *toString(Accuracy a) {
3740
return "high accuracy";
3841
case LOW_ACCURACY:
3942
return "low accuracy";
40-
case ENHANCED_PRECISION:
41-
return "enhanced precision";
43+
case ENHANCED_PERFORMANCE:
44+
return "enhanced performance";
45+
case CORRECTLY_ROUNDED:
46+
return "correctly rounded";
4247
}
4348
}
4449

@@ -63,38 +68,43 @@ bool AccuracyDecoratedCallsBiFResolution::runOnModule(Module &M) {
6368
}
6469

6570
void AccuracyDecoratedCallsBiFResolution::visitBinaryOperator(BinaryOperator &inst) {
66-
// not supported by BiFModule yet;
67-
return;
71+
// other binary operators have to be correctly rounded only
72+
if (inst.getOpcode() != Instruction::FDiv)
73+
return;
6874

69-
if (!inst.getType()->isFloatingPointTy())
75+
// only float type is supported for now
76+
if (!inst.getType()->isFloatTy())
7077
return;
7178

7279
MDNode *MD = inst.getMetadata("fpbuiltin-max-error");
7380
if (!MD)
7481
return;
82+
StringRef maxErrorStr = cast<MDString>(MD->getOperand(0))->getString();
83+
double maxError = 0;
84+
maxErrorStr.getAsDouble(maxError);
7585

76-
IGC_ASSERT_MESSAGE(!IGCLLVM::isBFloatTy(inst.getType()),
77-
"bfloat type is not supported with fpbuiltin-max-error decoration");
78-
if (IGCLLVM::isBFloatTy(inst.getType()))
86+
// no need to change anything for max error >= 2.5 ULP
87+
if (maxError >= 2.5)
7988
return;
8089

8190
std::vector<Value *> args{};
8291
args.push_back(inst.getOperand(0));
8392
args.push_back(inst.getOperand(1));
8493

85-
StringRef maxErrorStr = cast<MDString>(MD->getOperand(0))->getString();
86-
const std::string oldFuncName = inst.getOpcodeName();
94+
FunctionType *FT =
95+
FunctionType::get(inst.getType(), {inst.getOperand(0)->getType(), inst.getOperand(1)->getType()}, false);
8796

88-
Instruction *currInst = cast<Instruction>(&inst);
8997
Function *newFunc =
90-
getOrInsertNewFunc(oldFuncName, inst.getType(), args, {}, CallingConv::SPIR_FUNC, maxErrorStr, currInst);
91-
98+
cast<Function>(m_Module->getOrInsertFunction("__builtin_spirv_divide_cr_f32_f32", FT, {}).getCallee());
99+
newFunc->setCallingConv(CallingConv::SPIR_FUNC);
92100
CallInst *newCall = CallInst::Create(newFunc, args, inst.getName(), &inst);
101+
93102
llvm::Attribute attr = llvm::Attribute::get(inst.getContext(), "fpbuiltin-max-error", maxErrorStr);
94103
newCall->addFnAttr(attr);
95104

96105
inst.replaceAllUsesWith(newCall);
97106
inst.eraseFromParent();
107+
m_changed = true;
98108
}
99109

100110
void AccuracyDecoratedCallsBiFResolution::visitCallInst(CallInst &callInst) {
@@ -180,26 +190,43 @@ std::string AccuracyDecoratedCallsBiFResolution::getFunctionName(const StringRef
180190
default:
181191
IGC_ASSERT_MESSAGE(false, "Unreachable, NameToBuiltinDef.hpp is likely broken");
182192
return oldFuncName.str();
183-
case ENHANCED_PRECISION:
193+
case ENHANCED_PERFORMANCE:
184194
return getFunctionName(oldFuncName, LOW_ACCURACY, currInst);
185195
case LOW_ACCURACY:
186196
return getFunctionName(oldFuncName, HIGH_ACCURACY, currInst);
187197
}
188198
}
189199
return m_nameToBuiltin.at(oldFuncName.str()).at(accuracy);
190200
}
201+
namespace {
202+
203+
// Reports whether `inst` is a direct call to a function whose demangled name
// contains "__spirv_ocl_sqrt".
//
// Returns false when `inst` is not a CallInst, when getCalledFunction() yields
// null, or when the demangled callee name does not contain that text.
bool isSqrt(const llvm::Instruction *inst) {
204+
if (const llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(inst)) {
205+
if (const llvm::Function *calledFunc = callInst->getCalledFunction()) {
206+
// The check runs on the demangled name and is a substring test, so any
// callee whose demangled name contains the text below is accepted.
std::string demangledName = llvm::demangle(calledFunc->getName().str());
207+
if (demangledName.find("__spirv_ocl_sqrt") != std::string::npos) {
208+
return true;
209+
}
210+
}
211+
}
212+
return false;
213+
}
214+
} // namespace
191215

192216
Accuracy AccuracyDecoratedCallsBiFResolution::getAccuracy(double maxError, double cutOff,
193217
const Instruction *currInst) const {
194-
if (maxError < 1.0)
195-
getAnalysis<CodeGenContextWrapper>().getCodeGenContext()->EmitError(
196-
"fpbuiltin-max-error can't have values below 1.0", currInst);
218+
if (maxError < 1.0) {
219+
if (isSqrt(currInst))
220+
return CORRECTLY_ROUNDED;
197221

222+
getAnalysis<CodeGenContextWrapper>().getCodeGenContext()->EmitError(
223+
"fpbuiltin-max-error with values below 1.0 is only supported for sqrt and division", currInst);
224+
}
198225
if (maxError < 4.0)
199226
return HIGH_ACCURACY;
200227

201228
if (maxError < cutOff) // 2^12 (for f32) or 2^26 (for f64)
202229
return LOW_ACCURACY;
203230

204-
return ENHANCED_PRECISION;
231+
return ENHANCED_PERFORMANCE;
205232
}

IGC/Compiler/Optimizer/OpenCLPasses/AccuracyDecoratedCallsBiFResolution/AccuracyDecoratedCallsBiFResolution.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ SPDX-License-Identifier: MIT
1616
#include <unordered_map>
1717

1818
namespace IGC {
19-
enum Accuracy { HIGH_ACCURACY, LOW_ACCURACY, ENHANCED_PRECISION };
19+
enum Accuracy { HIGH_ACCURACY, LOW_ACCURACY, ENHANCED_PERFORMANCE, CORRECTLY_ROUNDED };
2020

2121
class AccuracyDecoratedCallsBiFResolution : public llvm::ModulePass,
2222
public llvm::InstVisitor<AccuracyDecoratedCallsBiFResolution> {
@@ -43,7 +43,7 @@ class AccuracyDecoratedCallsBiFResolution : public llvm::ModulePass,
4343
private:
4444
bool m_changed = false;
4545
llvm::Module *m_Module = nullptr;
46-
// m_nameToBuiltin["_Z15__spirv_ocl_sinf"][ENHANCED_PRECISION] --> "__ocl_svml_sinf_ep"
46+
// m_nameToBuiltin["_Z15__spirv_ocl_sinf"][ENHANCED_PERFORMANCE] --> "__ocl_svml_sinf_ep"
4747
std::unordered_map<std::string, AccurateBuiltins> m_nameToBuiltin{};
4848

4949
llvm::Function *getOrInsertNewFunc(const llvm::StringRef oldFuncName, llvm::Type *funcType,

0 commit comments

Comments
 (0)