@@ -10,6 +10,9 @@ SPDX-License-Identifier: MIT
1010
1111#include " llvmWrapper/IR/Type.h"
1212#include " llvm/IR/Attributes.h"
13+ #include < llvm/Demangle/Demangle.h>
14+ #include < llvm/IR/CallingConv.h>
15+ #include < llvm/IR/DerivedTypes.h>
1316#include " Compiler/IGCPassSupport.h"
1417#include " Compiler/CodeGenPublic.h"
1518#include " Probe/Assertion.h"
@@ -37,8 +40,10 @@ static inline const char *toString(Accuracy a) {
3740 return " high accuracy" ;
3841 case LOW_ACCURACY:
3942 return " low accuracy" ;
40- case ENHANCED_PRECISION:
41- return " enhanced precision" ;
43+ case ENHANCED_PERFORMANCE:
44+ return " enhanced performance" ;
45+ case CORRECTLY_ROUNDED:
46+ return " correctly rounded" ;
4247 }
4348}
4449
@@ -63,38 +68,43 @@ bool AccuracyDecoratedCallsBiFResolution::runOnModule(Module &M) {
6368}
6469
6570void AccuracyDecoratedCallsBiFResolution::visitBinaryOperator (BinaryOperator &inst) {
66- // not supported by BiFModule yet;
67- return ;
71+ // other binary operators have to be correctly rounded only
72+ if (inst.getOpcode () != Instruction::FDiv)
73+ return ;
6874
69- if (!inst.getType ()->isFloatingPointTy ())
75+ // only float type is supported for now
76+ if (!inst.getType ()->isFloatTy ())
7077 return ;
7178
7279 MDNode *MD = inst.getMetadata (" fpbuiltin-max-error" );
7380 if (!MD)
7481 return ;
82+ StringRef maxErrorStr = cast<MDString>(MD->getOperand (0 ))->getString ();
83+ double maxError = 0 ;
84+ maxErrorStr.getAsDouble (maxError);
7585
76- IGC_ASSERT_MESSAGE (!IGCLLVM::isBFloatTy (inst.getType ()),
77- " bfloat type is not supported with fpbuiltin-max-error decoration" );
78- if (IGCLLVM::isBFloatTy (inst.getType ()))
86+ // no need to change anything for max error >= 2.5 ULP
87+ if (maxError >= 2.5 )
7988 return ;
8089
8190 std::vector<Value *> args{};
8291 args.push_back (inst.getOperand (0 ));
8392 args.push_back (inst.getOperand (1 ));
8493
85- StringRef maxErrorStr = cast<MDString>(MD-> getOperand ( 0 ))-> getString ();
86- const std::string oldFuncName = inst.getOpcodeName ( );
94+ FunctionType *FT =
95+ FunctionType::get (inst. getType (), {inst. getOperand ( 0 )-> getType (), inst.getOperand ( 1 )-> getType ()}, false );
8796
88- Instruction *currInst = cast<Instruction>(&inst);
8997 Function *newFunc =
90- getOrInsertNewFunc (oldFuncName, inst. getType (), args , {}, CallingConv::SPIR_FUNC, maxErrorStr, currInst );
91-
98+ cast<Function>(m_Module-> getOrInsertFunction ( " __builtin_spirv_divide_cr_f32_f32 " , FT , {}). getCallee () );
99+ newFunc-> setCallingConv (CallingConv::SPIR_FUNC);
92100 CallInst *newCall = CallInst::Create (newFunc, args, inst.getName (), &inst);
101+
93102 llvm::Attribute attr = llvm::Attribute::get (inst.getContext (), " fpbuiltin-max-error" , maxErrorStr);
94103 newCall->addFnAttr (attr);
95104
96105 inst.replaceAllUsesWith (newCall);
97106 inst.eraseFromParent ();
107+ m_changed = true ;
98108}
99109
100110void AccuracyDecoratedCallsBiFResolution::visitCallInst (CallInst &callInst) {
@@ -180,26 +190,43 @@ std::string AccuracyDecoratedCallsBiFResolution::getFunctionName(const StringRef
180190 default :
181191 IGC_ASSERT_MESSAGE (false , " Unreachable, NameToBuiltinDef.hpp is likely broken" );
182192 return oldFuncName.str ();
183- case ENHANCED_PRECISION :
193+ case ENHANCED_PERFORMANCE :
184194 return getFunctionName (oldFuncName, LOW_ACCURACY, currInst);
185195 case LOW_ACCURACY:
186196 return getFunctionName (oldFuncName, HIGH_ACCURACY, currInst);
187197 }
188198 }
189199 return m_nameToBuiltin.at (oldFuncName.str ()).at (accuracy);
190200}
201+ namespace {
202+
203+ bool isSqrt (const llvm::Instruction *inst) {
204+ if (const llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(inst)) {
205+ if (const llvm::Function *calledFunc = callInst->getCalledFunction ()) {
206+ std::string demangledName = llvm::demangle (calledFunc->getName ().str ());
207+ if (demangledName.find (" __spirv_ocl_sqrt" ) != std::string::npos) {
208+ return true ;
209+ }
210+ }
211+ }
212+ return false ;
213+ }
214+ } // namespace
191215
192216Accuracy AccuracyDecoratedCallsBiFResolution::getAccuracy (double maxError, double cutOff,
193217 const Instruction *currInst) const {
194- if (maxError < 1.0 )
195- getAnalysis<CodeGenContextWrapper>(). getCodeGenContext ()-> EmitError (
196- " fpbuiltin-max-error can't have values below 1.0 " , currInst) ;
218+ if (maxError < 1.0 ) {
219+ if ( isSqrt (currInst))
220+ return CORRECTLY_ROUNDED ;
197221
222+ getAnalysis<CodeGenContextWrapper>().getCodeGenContext ()->EmitError (
223+ " fpbuiltin-max-error with values below 1.0 is only supported for sqrt and division" , currInst);
224+ }
198225 if (maxError < 4.0 )
199226 return HIGH_ACCURACY;
200227
201228 if (maxError < cutOff) // 2^12 (for f32) or 2^26 (for f64)
202229 return LOW_ACCURACY;
203230
204- return ENHANCED_PRECISION ;
231+ return ENHANCED_PERFORMANCE ;
205232}
0 commit comments