diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index a533bea6f1e6..cf9b7ed15e67 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -574,7 +574,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { mlir::Value createPtrBitcast(mlir::Value src, mlir::Type newPointeeTy) { assert(mlir::isa(src.getType()) && "expected ptr src"); - return createBitcast(src, getPointerTo(newPointeeTy)); + auto srcPtrTy = mlir::cast(src.getType()); + mlir::Type newPtrTy = getPointerTo(newPointeeTy, srcPtrTy.getAddrSpace()); + return createBitcast(src, newPtrTy); } mlir::Value createAddrSpaceCast(mlir::Location loc, mlir::Value src, @@ -586,6 +588,29 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { return createAddrSpaceCast(src.getLoc(), src, newTy); } + mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Location loc, + mlir::Value src, + mlir::Type newPointerTy) { + assert(mlir::isa(src.getType()) && + "expected source pointer"); + assert(mlir::isa(newPointerTy) && + "expected destination pointer type"); + + auto srcPtrTy = mlir::cast(src.getType()); + auto dstPtrTy = mlir::cast(newPointerTy); + + mlir::Value addrSpaceCasted = src; + if (srcPtrTy.getAddrSpace() != dstPtrTy.getAddrSpace()) + addrSpaceCasted = createAddrSpaceCast(loc, src, dstPtrTy); + + return createPtrBitcast(addrSpaceCasted, dstPtrTy.getPointee()); + } + + mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Value src, + mlir::Type newPointerTy) { + return createPointerBitCastOrAddrSpaceCast(src.getLoc(), src, newPointerTy); + } + mlir::Value createPtrIsNull(mlir::Value ptr) { return createNot(createPtrToBoolCast(ptr)); } diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index b6b114f0e4b9..17241b806c56 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -820,7 +820,81 @@ class ScalarExprEmitter : public StmtVisitor { mlir::Value VisitObjCDictionaryLiteral(ObjCDictionaryLiteral *E) { llvm_unreachable("NYI"); } - mlir::Value VisitAsTypeExpr(AsTypeExpr *E) { llvm_unreachable("NYI"); } + + // Create cast instructions for converting LLVM value Src to MLIR type DstTy. + // Src has the same size as DstTy. Both are single value types + // but could be scalar or vectors of different lengths, and either can be + // pointer. + mlir::Value createCastsForTypeOfSameSize(mlir::Value Src, mlir::Type DstTy) { + auto SrcTy = Src.getType(); + + // Case 1. + if (!isa(SrcTy) && !isa(DstTy)) + return Builder.createBitcast(Src, DstTy); + + // Case 2. + if (isa(SrcTy) && isa(DstTy)) + return Builder.createPointerBitCastOrAddrSpaceCast(Src, DstTy); + + // Case 3. + if (isa(SrcTy) && !isa(DstTy)) { + // Case 3b. + if (!Builder.isInt(DstTy)) + llvm_unreachable("NYI"); + // Cases 3a and 3b. + llvm_unreachable("NYI"); + } + + // Case 4b. + if (!Builder.isInt(SrcTy)) + llvm_unreachable("NYI"); + + // Cases 4a and 4b. + llvm_unreachable("NYI"); + } + + mlir::Value VisitAsTypeExpr(AsTypeExpr *E) { + unsigned numSrcElems = 0; + QualType qualSrcTy = E->getSrcExpr()->getType(); + mlir::Type srcTy = CGF.convertType(qualSrcTy); + if (auto v = dyn_cast(srcTy)) { + assert(!cir::MissingFeatures::scalableVectors() && + "NYI: non-fixed (scalable) vector src"); + numSrcElems = v.getSize(); + } + + unsigned numDstElems = 0; + QualType qualDstTy = E->getType(); + mlir::Type dstTy = CGF.convertType(qualDstTy); + if (auto v = dyn_cast(dstTy)) { + assert(!cir::MissingFeatures::scalableVectors() && + "NYI: non-fixed (scalable) vector dst"); + numDstElems = v.getSize(); + } + + // Use bit vector expansion for ext_vector_type boolean vectors. + if (qualDstTy->isExtVectorBoolType()) { + llvm_unreachable("NYI"); + } + + // Going from vec3 to non-vec3 is a special case and requires a shuffle + // vector to get a vec4, then a bitcast if the target type is different. + if (numSrcElems == 3 && numDstElems != 3) { + llvm_unreachable("NYI"); + } + + // Going from non-vec3 to vec3 is a special case and requires a bitcast + // to vec4 if the original type is not vec4, then a shuffle vector to + // get a vec3. + if (numSrcElems != 3 && numDstElems == 3) { + llvm_unreachable("NYI"); + } + + // Otherwise, fallback to bitcast of same size + mlir::Value src = CGF.emitScalarExpr(E->getSrcExpr()); + return createCastsForTypeOfSameSize(src, dstTy); + } + mlir::Value VisitAtomicExpr(AtomicExpr *E) { return CGF.emitAtomicExpr(E).getScalarVal(); } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 4387142ac8c5..cecaade6a4dc 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -662,8 +662,10 @@ LogicalResult cir::CastOp::verify() { auto resPtrTy = mlir::dyn_cast(resType); if (!srcPtrTy || !resPtrTy) return emitOpError() << "requires !cir.ptr type for source and result"; - if (srcPtrTy.getPointee() != resPtrTy.getPointee()) - return emitOpError() << "requires two types differ in addrspace only"; + // Address space verification is sufficient here. The pointee types need not + // be verified as they are handled by bitcast verification logic, which + // ensures address space compatibility. Verifying pointee types would create + // a circular dependency between address space and pointee type casting. return success(); } case cir::CastKind::float_to_complex: { diff --git a/clang/test/CIR/CodeGen/OpenCL/as_type.cl b/clang/test/CIR/CodeGen/OpenCL/as_type.cl new file mode 100644 index 000000000000..6fc8104e8a24 --- /dev/null +++ b/clang/test/CIR/CodeGen/OpenCL/as_type.cl @@ -0,0 +1,55 @@ +// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-cir -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=CIR + +// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-llvm -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM + +// RUN: %clang_cc1 %s -cl-std=CL2.0 -emit-llvm -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=OG-LLVM + +typedef __attribute__(( ext_vector_type(4) )) char char4; + +// CIR: cir.func @f4(%{{.*}}: !s32i loc({{.*}})) -> !cir.vector +// CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr +// CIR: cir.cast bitcast %[[x]] : !s32i -> !cir.vector +// LLVM: define spir_func <4 x i8> @f4(i32 %[[x:.*]]) +// LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8> +// LLVM-NOT: shufflevector +// LLVM: ret <4 x i8> %[[astype]] +// OG-LLVM: define spir_func noundef <4 x i8> @f4(i32 noundef %[[x:.*]]) +// OG-LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8> +// OG-LLVM-NOT: shufflevector +// OG-LLVM: ret <4 x i8> %[[astype]] +char4 f4(int x) { + return __builtin_astype(x, char4); +} + +// CIR: cir.func @f6(%{{.*}}: !cir.vector loc({{.*}})) -> !s32i +// CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr, addrspace(offload_private)>, !cir.vector +// CIR: cir.cast bitcast %[[x]] : !cir.vector -> !s32i +// LLVM: define{{.*}} spir_func i32 @f6(<4 x i8> %[[x:.*]]) +// LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32 +// LLVM-NOT: shufflevector +// LLVM: ret i32 %[[astype]] +// OG-LLVM: define{{.*}} spir_func noundef i32 @f6(<4 x i8> noundef %[[x:.*]]) +// OG-LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32 +// OG-LLVM-NOT: shufflevector +// OG-LLVM: ret i32 %[[astype]] +int f6(char4 x) { + return __builtin_astype(x, int); +} + +// CIR: cir.func @f4_ptr(%{{.*}}: !cir.ptr loc({{.*}})) -> !cir.ptr, addrspace(offload_local)> +// CIR: %[[x:.*]] = cir.load align(8) %{{.*}} : !cir.ptr, addrspace(offload_private)>, !cir.ptr +// CIR: cir.cast address_space %[[x]] : !cir.ptr -> !cir.ptr, addrspace(offload_local)> +// LLVM: define spir_func ptr addrspace(3) @f4_ptr(ptr addrspace(1) readnone captures(ret: address, provenance) %[[x:.*]]) +// LLVM: %[[astype:.*]] = addrspacecast ptr addrspace(1) %[[x]] to ptr addrspace(3) +// LLVM-NOT: shufflevector +// LLVM: ret ptr addrspace(3) %[[astype]] +// OG-LLVM: define spir_func ptr addrspace(3) @f4_ptr(ptr addrspace(1) noundef readnone captures(ret: address, provenance) %[[x:.*]]) +// OG-LLVM: %[[astype:.*]] = addrspacecast ptr addrspace(1) %[[x]] to ptr addrspace(3) +// OG-LLVM-NOT: shufflevector +// OG-LLVM: ret ptr addrspace(3) %[[astype]] +__local char4* f4_ptr(__global int* x) { + return __builtin_astype(x, __local char4*); +} \ No newline at end of file diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index f9a7bb92c656..70846ac264cd 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -300,15 +300,6 @@ cir.func @cast24(%p : !u32i) { // ----- -!u32i = !cir.int -!u64i = !cir.int -cir.func @cast25(%p : !cir.ptr)>) { - %0 = cir.cast address_space %p : !cir.ptr)> -> !cir.ptr)> // expected-error {{requires two types differ in addrspace only}} - cir.return -} - -// ----- - !u64i = !cir.int cir.func @cast26(%p : !cir.ptr)>) { %0 = cir.cast address_space %p : !cir.ptr)> -> !u64i // expected-error {{requires !cir.ptr type for source and result}}