//===- KnowledgeRetention.h - utilities to preserve informations *- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/KnowledgeRetention.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/CommandLine.h" #include "gtest/gtest.h" using namespace llvm; extern cl::opt ShouldPreserveAllAttributes; static void RunTest( StringRef Head, StringRef Tail, std::vector>> &Tests) { std::string IR; IR.append(Head.begin(), Head.end()); for (auto &Elem : Tests) IR.append(Elem.first.begin(), Elem.first.end()); IR.append(Tail.begin(), Tail.end()); LLVMContext C; SMDiagnostic Err; std::unique_ptr Mod = parseAssemblyString(IR, Err, C); if (!Mod) Err.print("AssumeQueryAPI", errs()); unsigned Idx = 0; for (Instruction &I : (*Mod->getFunction("test")->begin())) { if (Idx < Tests.size()) Tests[Idx].second(&I); Idx++; } } void AssertMatchesExactlyAttributes(CallInst *Assume, Value *WasOn, StringRef AttrToMatch) { Regex Reg(AttrToMatch); SmallVector Matches; for (StringRef Attr : { #define GET_ATTR_NAMES #define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME), #include "llvm/IR/Attributes.inc" }) { bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr; if (ShouldHaveAttr != hasAttributeInAssume(*Assume, WasOn, Attr)) { ASSERT_TRUE(false); } } } void AssertHasTheRightValue(CallInst *Assume, Value *WasOn, Attribute::AttrKind Kind, unsigned Value, bool Both, AssumeQuery AQ = AssumeQuery::Highest) { if (!Both) { uint64_t ArgVal = 0; ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal, AQ)); ASSERT_EQ(ArgVal, Value); return; } uint64_t ArgValLow = 0; uint64_t ArgValHigh = 0; bool ResultLow = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValLow, AssumeQuery::Lowest); bool ResultHigh = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValHigh, AssumeQuery::Highest); if (ResultLow != ResultHigh || ResultHigh == false) { ASSERT_TRUE(false); } if (ArgValLow != Value || ArgValLow != ArgValHigh) { ASSERT_TRUE(false); } } TEST(AssumeQueryAPI, Basic) { StringRef Head = "declare void @llvm.assume(i1)\n" "declare void @func(i32*, i32*)\n" "declare void @func1(i32*, i32*, i32*, i32*)\n" "declare void @func_many(i32*) \"no-jump-tables\" nounwind " "\"less-precise-fpmad\" willreturn norecurse\n" "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n"; StringRef Tail = "ret void\n" "}"; std::vector>> Tests; Tests.push_back(std::make_pair( "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align " "8 noalias %P1)\n", [](Instruction *I) { CallInst *Assume = BuildAssumeFromInst(I); Assume->insertBefore(I); AssertMatchesExactlyAttributes(Assume, I->getOperand(0), "(nonnull|align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(1), "(noalias|align)"); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable, 16, true); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Alignment, 4, true); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Alignment, 4, true); })); Tests.push_back(std::make_pair( "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* " "nonnull " "align 8 dereferenceable(28) %P, i32* nonnull align 64 " "dereferenceable(4) " "%P, i32* nonnull align 16 dereferenceable(12) %P)\n", [](Instruction *I) { CallInst *Assume = BuildAssumeFromInst(I); Assume->insertBefore(I); AssertMatchesExactlyAttributes(Assume, I->getOperand(0), "(nonnull|align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(1), "(nonnull|align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(2), "(nonnull|align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(3), "(nonnull|align|dereferenceable)"); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable, 48, false, AssumeQuery::Highest); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Alignment, 64, false, AssumeQuery::Highest); AssertHasTheRightValue(Assume, I->getOperand(1), Attribute::AttrKind::Alignment, 64, false, AssumeQuery::Highest); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable, 4, false, AssumeQuery::Lowest); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Alignment, 8, false, AssumeQuery::Lowest); AssertHasTheRightValue(Assume, I->getOperand(1), Attribute::AttrKind::Alignment, 8, false, AssumeQuery::Lowest); })); Tests.push_back(std::make_pair( "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) { ShouldPreserveAllAttributes.setValue(true); CallInst *Assume = BuildAssumeFromInst(I); Assume->insertBefore(I); AssertMatchesExactlyAttributes( Assume, nullptr, "(align|no-jump-tables|less-precise-fpmad|" "nounwind|norecurse|willreturn|cold)"); ShouldPreserveAllAttributes.setValue(false); })); Tests.push_back( std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) { CallInst *Assume = cast(I); AssertMatchesExactlyAttributes(Assume, nullptr, ""); })); Tests.push_back(std::make_pair( "call void @func1(i32* readnone align 32 " "dereferenceable(48) noalias %P, i32* " "align 8 dereferenceable(28) %P1, i32* align 64 " "dereferenceable(4) " "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n", [](Instruction *I) { CallInst *Assume = BuildAssumeFromInst(I); Assume->insertBefore(I); AssertMatchesExactlyAttributes( Assume, I->getOperand(0), "(readnone|align|dereferenceable|noalias)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(1), "(align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(2), "(align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, I->getOperand(3), "(nonnull|align|dereferenceable)"); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Alignment, 32, true); AssertHasTheRightValue(Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable, 48, true); AssertHasTheRightValue(Assume, I->getOperand(1), Attribute::AttrKind::Dereferenceable, 28, true); AssertHasTheRightValue(Assume, I->getOperand(1), Attribute::AttrKind::Alignment, 8, true); AssertHasTheRightValue(Assume, I->getOperand(2), Attribute::AttrKind::Alignment, 64, true); AssertHasTheRightValue(Assume, I->getOperand(2), Attribute::AttrKind::Dereferenceable, 4, true); AssertHasTheRightValue(Assume, I->getOperand(3), Attribute::AttrKind::Alignment, 16, true); AssertHasTheRightValue(Assume, I->getOperand(3), Attribute::AttrKind::Dereferenceable, 12, true); })); /// Keep this test last as it modifies the function. Tests.push_back(std::make_pair( "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align " "8 noalias %P1)\n", [](Instruction *I) { CallInst *Assume = BuildAssumeFromInst(I); Assume->insertBefore(I); Value *New = I->getFunction()->getArg(3); Value *Old = I->getOperand(0); AssertMatchesExactlyAttributes(Assume, New, ""); AssertMatchesExactlyAttributes(Assume, Old, "(nonnull|align|dereferenceable)"); Old->replaceAllUsesWith(New); AssertMatchesExactlyAttributes(Assume, New, "(nonnull|align|dereferenceable)"); AssertMatchesExactlyAttributes(Assume, Old, ""); })); RunTest(Head, Tail, Tests); }