diff --git a/3rdparty/spirv-tools/include/generated/build-version.inc b/3rdparty/spirv-tools/include/generated/build-version.inc index 81c26ace3..3977361de 100644 --- a/3rdparty/spirv-tools/include/generated/build-version.inc +++ b/3rdparty/spirv-tools/include/generated/build-version.inc @@ -1 +1 @@ -"v2025.5", "SPIRV-Tools v2025.5 v2025.4-64-gd2a11ec9" +"v2025.5", "SPIRV-Tools v2025.5 v2025.5.rc1-32-g6e7423bc" diff --git a/3rdparty/spirv-tools/source/diff/diff.cpp b/3rdparty/spirv-tools/source/diff/diff.cpp index 7fd21f352..d548aeada 100644 --- a/3rdparty/spirv-tools/source/diff/diff.cpp +++ b/3rdparty/spirv-tools/source/diff/diff.cpp @@ -1219,6 +1219,7 @@ bool Differ::DoDebugAndAnnotationInstructionsMatch( case spv::Op::OpMemberDecorate: return DoOperandsMatch(src_inst, dst_inst, 0, 3); case spv::Op::OpExtInst: + return DoOperandsMatch(src_inst, dst_inst, 0, 2); case spv::Op::OpDecorationGroup: case spv::Op::OpGroupDecorate: case spv::Op::OpGroupMemberDecorate: @@ -2612,6 +2613,9 @@ void Differ::MatchExtInstDebugInfo() { // This section includes OpExtInst for DebugInfo extension MatchDebugAndAnnotationInstructions(src_->ext_inst_debuginfo(), dst_->ext_inst_debuginfo()); + // OpExtInst can exist in other sections too, such as with non-semantic info. + MatchDebugAndAnnotationInstructions(src_->types_values(), + dst_->types_values()); } void Differ::MatchAnnotations() { diff --git a/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp b/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp index 51dc68e3a..54b100250 100644 --- a/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/aggressive_dead_code_elim_pass.cpp @@ -44,9 +44,10 @@ constexpr uint32_t kExtInstSetInIdx = 0; constexpr uint32_t kExtInstOpInIdx = 1; constexpr uint32_t kInterpolantInIdx = 2; constexpr uint32_t kCooperativeMatrixLoadSourceAddrInIdx = 0; -constexpr uint32_t kDebugValueLocalVariable = 2; -constexpr uint32_t kDebugValueValue = 3; -constexpr uint32_t kDebugValueExpression = 4; +constexpr uint32_t kDebugDeclareVariableInIdx = 3; +constexpr uint32_t kDebugValueLocalVariableInIdx = 2; +constexpr uint32_t kDebugValueValueInIdx = 3; +constexpr uint32_t kDebugValueExpressionInIdx = 4; // Sorting functor to present annotation instructions in an easy-to-process // order. The functor orders by opcode first and falls back on unique id @@ -290,40 +291,95 @@ Pass::Status AggressiveDCEPass::ProcessDebugInformation( std::list& structured_order) { for (auto bi = structured_order.begin(); bi != structured_order.end(); bi++) { bool succeeded = (*bi)->WhileEachInst([this](Instruction* inst) { - // DebugDeclare is not dead. It must be converted to DebugValue in a - // later pass - if (inst->IsNonSemanticInstruction() && - inst->GetShader100DebugOpcode() == - NonSemanticShaderDebugInfo100DebugDeclare) { - AddToWorklist(inst); - return true; - } + if (!inst->IsNonSemanticInstruction()) return true; - // If the Value of a DebugValue is killed, set Value operand to Undef - if (inst->IsNonSemanticInstruction() && - inst->GetShader100DebugOpcode() == - NonSemanticShaderDebugInfo100DebugValue) { - uint32_t id = inst->GetSingleWordInOperand(kDebugValueValue); - auto def = get_def_use_mgr()->GetDef(id); - if (!IsLive(def)) { + if (inst->GetShader100DebugOpcode() == + NonSemanticShaderDebugInfo100DebugDeclare) { + if (IsLive(inst)) return true; + + uint32_t var_id = + inst->GetSingleWordInOperand(kDebugDeclareVariableInIdx); + auto var_def = get_def_use_mgr()->GetDef(var_id); + + if (IsLive(var_def)) { AddToWorklist(inst); - uint32_t undef_id = Type2Undef(def->type_id()); - if (undef_id == 0) { - return false; - } - inst->SetInOperand(kDebugValueValue, {undef_id}); - context()->get_def_use_mgr()->UpdateDefUse(inst); - id = inst->GetSingleWordInOperand(kDebugValueLocalVariable); - auto localVar = get_def_use_mgr()->GetDef(id); - AddToWorklist(localVar); - context()->get_def_use_mgr()->UpdateDefUse(localVar); - AddOperandsToWorkList(localVar); - id = inst->GetSingleWordInOperand(kDebugValueExpression); - auto expression = get_def_use_mgr()->GetDef(id); - AddToWorklist(expression); - context()->get_def_use_mgr()->UpdateDefUse(expression); return true; } + + // DebugDeclare Variable is not live. Find the value that was being + // stored to this variable. If it's live then create a new DebugValue + // with this value. Otherwise let it die in peace. + get_def_use_mgr()->ForEachUser(var_id, [this, var_id, + inst](Instruction* user) { + if (user->opcode() == spv::Op::OpStore) { + uint32_t stored_value_id = 0; + const uint32_t kStoreValueInIdx = 1; + stored_value_id = user->GetSingleWordInOperand(kStoreValueInIdx); + if (!IsLive(get_def_use_mgr()->GetDef(stored_value_id))) { + return true; + } + + // value being stored is still live + Instruction* next_inst = inst->NextNode(); + bool added = + context()->get_debug_info_mgr()->AddDebugValueForVariable( + user, var_id, stored_value_id, inst); + if (added && next_inst) { + auto new_debug_value = next_inst->PreviousNode(); + live_insts_.Set(new_debug_value->unique_id()); + } + } + return true; + }); + } else if (inst->GetShader100DebugOpcode() == + NonSemanticShaderDebugInfo100DebugValue) { + uint32_t var_operand_idx = kDebugValueValueInIdx; + uint32_t id = inst->GetSingleWordInOperand(var_operand_idx); + auto def = get_def_use_mgr()->GetDef(id); + + if (IsLive(def)) { + AddToWorklist(inst); + return true; + } + + // Value operand of DebugValue is not live + // Set Value to Undef of appropriate type + live_insts_.Set(inst->unique_id()); + + uint32_t type_id = def->type_id(); + auto type_def = get_def_use_mgr()->GetDef(type_id); + AddToWorklist(type_def); + + uint32_t undef_id = Type2Undef(type_id); + if (undef_id == 0) return false; + + auto undef_inst = get_def_use_mgr()->GetDef(undef_id); + live_insts_.Set(undef_inst->unique_id()); + inst->SetInOperand(var_operand_idx, {undef_id}); + context()->get_def_use_mgr()->AnalyzeInstUse(inst); + + id = inst->GetSingleWordInOperand(kDebugValueLocalVariableInIdx); + auto localVar = get_def_use_mgr()->GetDef(id); + AddToWorklist(localVar); + + uint32_t expr_idx = kDebugValueExpressionInIdx; + id = inst->GetSingleWordInOperand(expr_idx); + auto expression = get_def_use_mgr()->GetDef(id); + AddToWorklist(expression); + + for (uint32_t i = expr_idx + 1; i < inst->NumInOperands(); ++i) { + id = inst->GetSingleWordInOperand(i); + auto index_def = get_def_use_mgr()->GetDef(id); + if (index_def) { + AddToWorklist(index_def); + } + } + + for (auto& line_inst : inst->dbg_line_insts()) { + if (line_inst.IsDebugLineInst()) { + AddToWorklist(&line_inst); + } + } } return true; }); @@ -731,13 +787,16 @@ Pass::Status AggressiveDCEPass::InitializeModuleScopeLiveInstructions() { AddToWorklist(dbg_none); } - // Add top level DebugInfo to worklist + // Add DebugInfo which should never be eliminated to worklist for (auto& dbg : get_module()->ext_inst_debuginfo()) { auto op = dbg.GetShader100DebugOpcode(); if (op == NonSemanticShaderDebugInfo100DebugCompilationUnit || op == NonSemanticShaderDebugInfo100DebugEntryPoint || op == NonSemanticShaderDebugInfo100DebugSource || - op == NonSemanticShaderDebugInfo100DebugSourceContinued) { + op == NonSemanticShaderDebugInfo100DebugSourceContinued || + op == NonSemanticShaderDebugInfo100DebugLocalVariable || + op == NonSemanticShaderDebugInfo100DebugExpression || + op == NonSemanticShaderDebugInfo100DebugOperation) { AddToWorklist(&dbg); } } @@ -813,7 +872,9 @@ Pass::Status AggressiveDCEPass::ProcessImpl() { // Cleanup all CFG including all unreachable blocks. for (Function& fp : *context()->module()) { - modified |= CFGCleanup(&fp); + auto status = CFGCleanup(&fp); + if (status == Status::Failure) return Status::Failure; + if (status == Status::SuccessWithChange) modified = true; } return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; diff --git a/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp b/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp index 26fed89fb..6cd047961 100644 --- a/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/cfg_cleanup_pass.cpp @@ -25,8 +25,17 @@ namespace opt { Pass::Status CFGCleanupPass::Process() { // Process all entry point functions. - ProcessFunction pfn = [this](Function* fp) { return CFGCleanup(fp); }; + bool failure = false; + ProcessFunction pfn = [this, &failure](Function* fp) { + auto status = CFGCleanup(fp); + if (status == Status::Failure) { + failure = true; + return false; + } + return status == Status::SuccessWithChange; + }; bool modified = context()->ProcessReachableCallTree(pfn); + if (failure) return Pass::Status::Failure; return modified ? Pass::Status::SuccessWithChange : Pass::Status::SuccessWithoutChange; } diff --git a/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp b/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp index 734c96715..ec90d97c9 100644 --- a/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp +++ b/3rdparty/spirv-tools/source/opt/combine_access_chains.cpp @@ -27,36 +27,48 @@ Pass::Status CombineAccessChains::Process() { bool modified = false; for (auto& function : *get_module()) { - modified |= ProcessFunction(function); + auto status = ProcessFunction(function); + if (status == Status::Failure) return Status::Failure; + if (status == Status::SuccessWithChange) modified = true; } return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } -bool CombineAccessChains::ProcessFunction(Function& function) { +Pass::Status CombineAccessChains::ProcessFunction(Function& function) { if (function.IsDeclaration()) { - return false; + return Status::SuccessWithoutChange; } bool modified = false; + bool failure = false; cfg()->ForEachBlockInReversePostOrder( - function.entry().get(), [&modified, this](BasicBlock* block) { - block->ForEachInst([&modified, this](Instruction* inst) { + function.entry().get(), [&modified, &failure, this](BasicBlock* block) { + if (failure) return; + block->ForEachInst([&modified, &failure, this](Instruction* inst) { + if (failure) return; switch (inst->opcode()) { case spv::Op::OpAccessChain: case spv::Op::OpInBoundsAccessChain: case spv::Op::OpPtrAccessChain: - case spv::Op::OpInBoundsPtrAccessChain: - modified |= CombineAccessChain(inst); + case spv::Op::OpInBoundsPtrAccessChain: { + auto status = CombineAccessChain(inst); + if (status == Status::Failure) { + failure = true; + } else if (status == Status::SuccessWithChange) { + modified = true; + } break; + } default: break; } }); }); - return modified; + if (failure) return Status::Failure; + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } uint32_t CombineAccessChains::GetConstantValue( @@ -121,9 +133,9 @@ const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { return type; } -bool CombineAccessChains::CombineIndices(Instruction* ptr_input, - Instruction* inst, - std::vector* new_operands) { +Pass::Status CombineAccessChains::CombineIndices( + Instruction* ptr_input, Instruction* inst, + std::vector* new_operands) { analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); @@ -150,28 +162,30 @@ bool CombineAccessChains::CombineIndices(Instruction* ptr_input, GetConstantValue(element_constant); const analysis::Constant* new_value_constant = constant_mgr->GetConstant(last_index_constant->type(), {new_value}); + if (!new_value_constant) return Status::Failure; Instruction* new_value_inst = constant_mgr->GetDefiningInstruction(new_value_constant); + if (!new_value_inst) return Status::Failure; new_value_id = new_value_inst->result_id(); } else if (!type->AsStruct() || combining_element_operands) { // Generate an addition of the two indices. InstructionBuilder builder( context(), inst, IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); - // TODO(1841): Handle id overflow. Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), last_index_inst->result_id(), element_inst->result_id()); + if (!addition) return Status::Failure; new_value_id = addition->result_id(); } else { // Indexing into structs must be constant, so bail out here. - return false; + return Status::SuccessWithoutChange; } new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); - return true; + return Status::SuccessWithChange; } -bool CombineAccessChains::CreateNewInputOperands( +Pass::Status CombineAccessChains::CreateNewInputOperands( Instruction* ptr_input, Instruction* inst, std::vector* new_operands) { // Start by copying all the input operands of the feeder access chain. @@ -183,7 +197,8 @@ bool CombineAccessChains::CreateNewInputOperands( if (IsPtrAccessChain(inst->opcode())) { // The last index of the feeder should be combined with the element operand // of |inst|. - if (!CombineIndices(ptr_input, inst, new_operands)) return false; + auto status = CombineIndices(ptr_input, inst, new_operands); + if (status != Status::SuccessWithChange) return status; } else { // The indices aren't being combined so now add the last index operand of // |ptr_input|. @@ -197,10 +212,10 @@ bool CombineAccessChains::CreateNewInputOperands( new_operands->push_back(inst->GetInOperand(i)); } - return true; + return Status::SuccessWithChange; } -bool CombineAccessChains::CombineAccessChain(Instruction* inst) { +Pass::Status CombineAccessChains::CombineAccessChain(Instruction* inst) { assert((inst->opcode() == spv::Op::OpPtrAccessChain || inst->opcode() == spv::Op::OpAccessChain || inst->opcode() == spv::Op::OpInBoundsAccessChain || @@ -213,10 +228,11 @@ bool CombineAccessChains::CombineAccessChain(Instruction* inst) { ptr_input->opcode() != spv::Op::OpInBoundsAccessChain && ptr_input->opcode() != spv::Op::OpPtrAccessChain && ptr_input->opcode() != spv::Op::OpInBoundsPtrAccessChain) { - return false; + return Status::SuccessWithoutChange; } - if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; + if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) + return Status::SuccessWithoutChange; // Handles the following cases: // 1. |ptr_input| is an index-less access chain. Replace the pointer @@ -238,7 +254,7 @@ bool CombineAccessChains::CombineAccessChain(Instruction* inst) { // size/alignment of the type and converting the stride into an element // index. uint32_t array_stride = GetArrayStride(ptr_input); - if (array_stride != 0) return false; + if (array_stride != 0) return Status::SuccessWithoutChange; if (ptr_input->NumInOperands() == 1) { // The input is effectively a no-op. @@ -250,14 +266,15 @@ bool CombineAccessChains::CombineAccessChain(Instruction* inst) { inst->SetOpcode(spv::Op::OpCopyObject); } else { std::vector new_operands; - if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; + auto status = CreateNewInputOperands(ptr_input, inst, &new_operands); + if (status != Status::SuccessWithChange) return status; // Update the instruction. inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); inst->SetInOperands(std::move(new_operands)); context()->AnalyzeUses(inst); } - return true; + return Status::SuccessWithChange; } spv::Op CombineAccessChains::UpdateOpcode(spv::Op base_opcode, diff --git a/3rdparty/spirv-tools/source/opt/combine_access_chains.h b/3rdparty/spirv-tools/source/opt/combine_access_chains.h index 32ee50d30..1872720d7 100644 --- a/3rdparty/spirv-tools/source/opt/combine_access_chains.h +++ b/3rdparty/spirv-tools/source/opt/combine_access_chains.h @@ -40,12 +40,12 @@ class CombineAccessChains : public Pass { private: // Combine access chains in |function|. Blocks are processed in reverse // post-order. Returns true if the function is modified. - bool ProcessFunction(Function& function); + Status ProcessFunction(Function& function); // Combines an access chain (normal, in bounds or pointer) |inst| if its base // pointer is another access chain. Returns true if the access chain was // modified. - bool CombineAccessChain(Instruction* inst); + Status CombineAccessChain(Instruction* inst); // Returns the value of |constant_inst| as a uint32_t. uint32_t GetConstantValue(const analysis::Constant* constant_inst); @@ -59,13 +59,13 @@ class CombineAccessChains : public Pass { // Populates |new_operands| with the operands for the combined access chain. // Returns false if the access chains cannot be combined. - bool CreateNewInputOperands(Instruction* ptr_input, Instruction* inst, - std::vector* new_operands); + Status CreateNewInputOperands(Instruction* ptr_input, Instruction* inst, + std::vector* new_operands); // Combines the last index of |ptr_input| with the element operand of |inst|. // Adds the combined operand to |new_operands|. - bool CombineIndices(Instruction* ptr_input, Instruction* inst, - std::vector* new_operands); + Status CombineIndices(Instruction* ptr_input, Instruction* inst, + std::vector* new_operands); // Returns the opcode to use for the combined access chain. spv::Op UpdateOpcode(spv::Op base_opcode, spv::Op input_opcode); diff --git a/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp b/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp index 0f4e44045..b7a69bc3c 100644 --- a/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp +++ b/3rdparty/spirv-tools/source/opt/const_folding_rules.cpp @@ -1126,6 +1126,26 @@ ConstantFoldingRule FoldFUnordGreaterThanEqual() { return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); } +ConstantFoldingRule FoldInvariantSelect() { + return [](IRContext*, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + assert(inst->opcode() == spv::Op::OpSelect); + (void)inst; + + if (!constants[1] || !constants[2]) { + return nullptr; + } + if (constants[1] == constants[2]) { + return constants[1]; + } + if (constants[1]->IsZero() && constants[2]->IsZero()) { + return constants[1]; + } + return nullptr; + }; +} + // Folds an OpDot where all of the inputs are constants to a // constant. A new constant is created if necessary. ConstantFoldingRule FoldOpDotWithConstants() { @@ -1435,6 +1455,18 @@ ConstantFoldingRule FoldFMix() { }; } +template +static bool NegZeroAwareLessThan(FloatType a, FloatType b) { + if (a == 0.0 && b == 0.0) { + bool sba = std::signbit(a); + bool sbb = std::signbit(b); + if (sba && !sbb) { + return true; + } + } + return a < b; +} + const analysis::Constant* FoldMin(const analysis::Type* result_type, const analysis::Constant* a, const analysis::Constant* b, @@ -1480,11 +1512,11 @@ const analysis::Constant* FoldMin(const analysis::Type* result_type, if (float_type->width() == 32) { float va = a->GetFloat(); float vb = b->GetFloat(); - return (va < vb ? a : b); + return NegZeroAwareLessThan(va, vb) ? a : b; } else if (float_type->width() == 64) { double va = a->GetDouble(); double vb = b->GetDouble(); - return (va < vb ? a : b); + return NegZeroAwareLessThan(va, vb) ? a : b; } } return nullptr; @@ -1535,11 +1567,71 @@ const analysis::Constant* FoldMax(const analysis::Type* result_type, if (float_type->width() == 32) { float va = a->GetFloat(); float vb = b->GetFloat(); - return (va > vb ? a : b); + return NegZeroAwareLessThan(vb, va) ? a : b; } else if (float_type->width() == 64) { double va = a->GetDouble(); double vb = b->GetDouble(); - return (va > vb ? a : b); + return NegZeroAwareLessThan(vb, va) ? a : b; + } + } + return nullptr; +} + +const analysis::Constant* FoldNMin(const analysis::Type* result_type, + const analysis::Constant* a, + const analysis::Constant* b, + analysis::ConstantManager*) { + if (const analysis::Float* float_type = result_type->AsFloat()) { + if (float_type->width() == 32) { + float va = a->GetFloat(); + float vb = b->GetFloat(); + if (std::isnan(va)) { + return b; + } + if (std::isnan(vb)) { + return a; + } + return NegZeroAwareLessThan(va, vb) ? a : b; + } else if (float_type->width() == 64) { + double va = a->GetDouble(); + double vb = b->GetDouble(); + if (std::isnan(va)) { + return b; + } + if (std::isnan(vb)) { + return a; + } + return NegZeroAwareLessThan(va, vb) ? a : b; + } + } + return nullptr; +} + +const analysis::Constant* FoldNMax(const analysis::Type* result_type, + const analysis::Constant* a, + const analysis::Constant* b, + analysis::ConstantManager*) { + if (const analysis::Float* float_type = result_type->AsFloat()) { + if (float_type->width() == 32) { + float va = a->GetFloat(); + float vb = b->GetFloat(); + if (std::isnan(va)) { + return b; + } + if (std::isnan(vb)) { + return a; + } + return NegZeroAwareLessThan(vb, va) ? a : b; + } else if (float_type->width() == 64) { + double va = a->GetDouble(); + double vb = b->GetDouble(); + if (std::isnan(va)) { + return b; + } + if (std::isnan(vb)) { + return a; + } + return NegZeroAwareLessThan(vb, va) ? a : b; } } return nullptr; @@ -1627,6 +1719,88 @@ const analysis::Constant* FoldClamp3( return nullptr; } +// Fold an clamp instruction when all three operands are constant. +const analysis::Constant* FoldNClamp1( + IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpExtInst && + "Expecting an extended instruction."); + assert(inst->GetSingleWordInOperand(0) == + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && + "Expecting a GLSLstd450 extended instruction."); + + // Make sure all Clamp operands are constants. + for (uint32_t i = 1; i < 4; i++) { + if (constants[i] == nullptr) { + return nullptr; + } + } + + const analysis::Constant* temp = FoldFPBinaryOp( + FoldNMax, inst->type_id(), {constants[1], constants[2]}, context); + if (temp == nullptr) { + return nullptr; + } + return FoldFPBinaryOp(FoldNMin, inst->type_id(), {temp, constants[3]}, + context); +} + +// Fold a clamp instruction when |x <= min_val|. +const analysis::Constant* FoldNClamp2( + IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpExtInst && + "Expecting an extended instruction."); + assert(inst->GetSingleWordInOperand(0) == + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && + "Expecting a GLSLstd450 extended instruction."); + + const analysis::Constant* x = constants[1]; + const analysis::Constant* min_val = constants[2]; + + if (x == nullptr || min_val == nullptr) { + return nullptr; + } + + const analysis::Constant* temp = + FoldFPBinaryOp(FoldNMax, inst->type_id(), {x, min_val}, context); + if (temp == min_val) { + // We can assume that |min_val| is less than |max_val|. Therefore, if the + // result of the max operation is |min_val|, we know the result of the min + // operation, even if |max_val| is not a constant. + return min_val; + } + return nullptr; +} + +// Fold a clamp instruction when |x >= max_val|. +const analysis::Constant* FoldNClamp3( + IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpExtInst && + "Expecting an extended instruction."); + assert(inst->GetSingleWordInOperand(0) == + context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && + "Expecting a GLSLstd450 extended instruction."); + + const analysis::Constant* x = constants[1]; + const analysis::Constant* max_val = constants[3]; + + if (x == nullptr || max_val == nullptr) { + return nullptr; + } + + const analysis::Constant* temp = + FoldFPBinaryOp(FoldNMin, inst->type_id(), {x, max_val}, context); + if (temp == max_val) { + // We can assume that |min_val| is less than |max_val|. Therefore, if the + // result of the max operation is |min_val|, we know the result of the min + // operation, even if |max_val| is not a constant. + return max_val; + } + return nullptr; +} + UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) { return [fp](const analysis::Type* result_type, const analysis::Constant* a, @@ -1775,6 +1949,8 @@ void ConstantFoldingRules::AddFoldingRules() { rules_[spv::Op::OpFMul].push_back(FoldFMul()); rules_[spv::Op::OpFSub].push_back(FoldFSub()); + rules_[spv::Op::OpSelect].push_back(FoldInvariantSelect()); + rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual()); rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual()); @@ -1878,12 +2054,16 @@ void ConstantFoldingRules::AddFoldingRules() { FoldFPBinaryOp(FoldMin)); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back( FoldFPBinaryOp(FoldMin)); + ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NMin}].push_back( + FoldFPBinaryOp(FoldNMin)); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back( FoldFPBinaryOp(FoldMax)); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back( FoldFPBinaryOp(FoldMax)); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back( FoldFPBinaryOp(FoldMax)); + ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NMax}].push_back( + FoldFPBinaryOp(FoldNMax)); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( FoldClamp1); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( @@ -1902,6 +2082,12 @@ void ConstantFoldingRules::AddFoldingRules() { FoldClamp2); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( FoldClamp3); + ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back( + FoldNClamp1); + ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back( + FoldNClamp2); + ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back( + FoldNClamp3); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back( FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin))); ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back( diff --git a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp index 3078a7c3b..547a5e479 100644 --- a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp +++ b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.cpp @@ -104,10 +104,17 @@ Pass::Status CopyPropagateArrays::Process() { continue; } - if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) { + uint32_t pointer_type_id = source_object->GetPointerTypeId(this); + if (pointer_type_id == 0) { + return Status::Failure; + } + + if (CanUpdateUses(&*var_inst, pointer_type_id)) { modified = true; - PropagateObject(&*var_inst, source_object.get(), store_inst); + if (!PropagateObject(&*var_inst, source_object.get(), store_inst)) { + return Status::Failure; + } } } @@ -170,15 +177,16 @@ Instruction* CopyPropagateArrays::FindStoreInstruction( return store_inst; } -void CopyPropagateArrays::PropagateObject(Instruction* var_inst, +bool CopyPropagateArrays::PropagateObject(Instruction* var_inst, MemoryObject* source, Instruction* insertion_point) { assert(var_inst->opcode() == spv::Op::OpVariable && "This function propagates variables."); Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source); + if (!new_access_chain) return false; context()->KillNamesAndDecorates(var_inst); - UpdateUses(var_inst, new_access_chain); + return UpdateUses(var_inst, new_access_chain); } Instruction* CopyPropagateArrays::BuildNewAccessChain( @@ -192,7 +200,7 @@ Instruction* CopyPropagateArrays::BuildNewAccessChain( return source->GetVariable(); } - source->BuildConstants(); + if (!source->BuildConstants()) return nullptr; std::vector access_ids(source->AccessChain().size()); std::transform( source->AccessChain().cbegin(), source->AccessChain().cend(), @@ -642,7 +650,7 @@ bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst, }); } -void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, +bool CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, Instruction* new_ptr_inst) { analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); @@ -699,6 +707,7 @@ void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1)); auto* deref_expr_instr = context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr); + if (!deref_expr_instr) return false; use->SetOperand(index + 1, {deref_expr_instr->result_id()}); context()->AnalyzeUses(deref_expr_instr); @@ -783,6 +792,8 @@ void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, uint32_t new_pointer_type_id = type_mgr->FindPointerToType(new_pointee_type_id, storage_class); + if (new_pointer_type_id == 0) return false; + if (new_pointer_type_id != use->type_id()) { use->SetResultType(new_pointer_type_id); context()->AnalyzeUses(use); @@ -829,8 +840,7 @@ void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, uint32_t pointee_type_id = pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx); uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use); - assert(copy != 0 && - "Should not be updating uses unless we know it can be done."); + if (copy == 0) return false; context()->ForgetUses(use); use->SetInOperand(index, {copy}); @@ -852,6 +862,7 @@ void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, break; } } + return true; } uint32_t CopyPropagateArrays::GetMemberTypeId( @@ -955,7 +966,7 @@ bool CopyPropagateArrays::MemoryObject::Contains( return true; } -void CopyPropagateArrays::MemoryObject::BuildConstants() { +bool CopyPropagateArrays::MemoryObject::BuildConstants() { for (auto& entry : access_chain_) { if (entry.is_result_id) { continue; @@ -968,10 +979,13 @@ void CopyPropagateArrays::MemoryObject::BuildConstants() { analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Constant* index_const = const_mgr->GetConstant(uint32_type, {entry.immediate}); - entry.result_id = - const_mgr->GetDefiningInstruction(index_const)->result_id(); + if (!index_const) return false; + Instruction* constant_inst = const_mgr->GetDefiningInstruction(index_const); + if (!constant_inst) return false; + entry.result_id = constant_inst->result_id(); entry.is_result_id = true; } + return true; } } // namespace opt diff --git a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h index bf4bfb5c5..cb04a1435 100644 --- a/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h +++ b/3rdparty/spirv-tools/source/opt/copy_prop_arrays.h @@ -118,7 +118,8 @@ class CopyPropagateArrays : public MemPass { // Converts all immediate values in the AccessChain their OpConstant // equivalent. - void BuildConstants(); + // Returns false if the constants could not be created. + bool BuildConstants(); // Returns the type id of the pointer type that can be used to point to this // memory object. @@ -175,7 +176,8 @@ class CopyPropagateArrays : public MemPass { // Replaces all loads of |var_inst| with a load from |source| instead. // |insertion_pos| is a position where it is possible to construct the // address of |source| and also dominates all of the loads of |var_inst|. - void PropagateObject(Instruction* var_inst, MemoryObject* source, + // Returns false if the propagation failed. + bool PropagateObject(Instruction* var_inst, MemoryObject* source, Instruction* insertion_pos); // Returns true if all of the references to |ptr_inst| can be rewritten and @@ -241,7 +243,7 @@ class CopyPropagateArrays : public MemPass { // types of other instructions as needed. This function should not be called // if |CanUpdateUses(original_ptr_inst, new_pointer_inst->type_id())| returns // false. - void UpdateUses(Instruction* original_ptr_inst, + bool UpdateUses(Instruction* original_ptr_inst, Instruction* new_pointer_inst); // Return true if |UpdateUses| is able to change all of the uses of diff --git a/3rdparty/spirv-tools/source/opt/debug_info_manager.cpp b/3rdparty/spirv-tools/source/opt/debug_info_manager.cpp index c084a6c64..4570113cf 100644 --- a/3rdparty/spirv-tools/source/opt/debug_info_manager.cpp +++ b/3rdparty/spirv-tools/source/opt/debug_info_manager.cpp @@ -331,6 +331,7 @@ Instruction* DebugInfoManager::GetDebugOperationWithDeref() { if (deref_operation_ != nullptr) return deref_operation_; uint32_t result_id = context()->TakeNextId(); + if (result_id == 0) return nullptr; std::unique_ptr deref_operation; if (context()->get_feature_mgr()->GetExtInstImportId_OpenCL100DebugInfo()) { @@ -374,10 +375,13 @@ Instruction* DebugInfoManager::GetDebugOperationWithDeref() { Instruction* DebugInfoManager::DerefDebugExpression(Instruction* dbg_expr) { assert(dbg_expr->GetCommonDebugOpcode() == CommonDebugInfoDebugExpression); std::unique_ptr deref_expr(dbg_expr->Clone(context())); - deref_expr->SetResultId(context()->TakeNextId()); - deref_expr->InsertOperand( - kDebugExpressOperandOperationIndex, - {SPV_OPERAND_TYPE_ID, {GetDebugOperationWithDeref()->result_id()}}); + uint32_t result_id = context()->TakeNextId(); + if (result_id == 0) return nullptr; + deref_expr->SetResultId(result_id); + Instruction* deref_op = GetDebugOperationWithDeref(); + if (!deref_op) return nullptr; + deref_expr->InsertOperand(kDebugExpressOperandOperationIndex, + {SPV_OPERAND_TYPE_ID, {deref_op->result_id()}}); auto* deref_expr_instr = context()->ext_inst_debuginfo_end()->InsertBefore(std::move(deref_expr)); AnalyzeDebugInst(deref_expr_instr); diff --git a/3rdparty/spirv-tools/source/opt/folding_rules.cpp b/3rdparty/spirv-tools/source/opt/folding_rules.cpp index 669038166..ecdbf85a3 100644 --- a/3rdparty/spirv-tools/source/opt/folding_rules.cpp +++ b/3rdparty/spirv-tools/source/opt/folding_rules.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include "ir_builder.h" @@ -933,6 +934,42 @@ FoldingRule MergeMulNegateArithmetic() { }; } +// Returns true if |inst| is negation op and is safe to fold. +static bool IsFoldableNegation(const Instruction* inst) { + return (inst->opcode() == spv::Op::OpSNegate || + (inst->opcode() == spv::Op::OpFNegate && + inst->IsFloatingPointFoldingAllowed())); +} + +// Merges multiplies of two negations. +// Cases: +// (-x) * (-y) = x * y +FoldingRule MergeMulDoubleNegative() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == spv::Op::OpFMul || + inst->opcode() == spv::Op::OpIMul); + + const analysis::Type* type = + context->get_type_mgr()->GetType(inst->type_id()); + + bool uses_float = HasFloatingPoint(type); + if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; + + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); + + if (IsFoldableNegation(lhs) && IsFoldableNegation(rhs)) { + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {lhs->GetSingleWordInOperand(0u)}}, + {SPV_OPERAND_TYPE_ID, {rhs->GetSingleWordInOperand(0u)}}}); + return true; + } + return false; + }; +} + // Merges consecutive divides if each instruction contains one constant operand. // Does not support integer division. // Cases: @@ -1125,13 +1162,12 @@ FoldingRule MergeDivNegateArithmetic() { }; } -// Folds addition of a constant and a negation. -// Cases: -// (-x) + 2 = 2 - x -// 2 + (-x) = 2 - x +// Folds addition, where one side is a negation. +// (-x) + y = y - x +// y + (-x) = y - x FoldingRule MergeAddNegateArithmetic() { return [](IRContext* context, Instruction* inst, - const std::vector& constants) { + const std::vector&) { assert(inst->opcode() == spv::Op::OpFAdd || inst->opcode() == spv::Op::OpIAdd); const analysis::Type* type = @@ -1139,73 +1175,65 @@ FoldingRule MergeAddNegateArithmetic() { bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; - const analysis::Constant* const_input1 = ConstInput(constants); - if (!const_input1) return false; - Instruction* other_inst = NonConstInput(context, constants[0], inst); - if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) - return false; + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); - if (other_inst->opcode() == spv::Op::OpSNegate || - other_inst->opcode() == spv::Op::OpFNegate) { - inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub - : spv::Op::OpISub); - uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u) - : inst->GetSingleWordInOperand(1u); - inst->SetInOperands( - {{SPV_OPERAND_TYPE_ID, {const_id}}, - {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}}); - return true; - } - return false; + auto TrySubstitute = [inst, uses_float](Instruction* first, + Instruction* second) { + if (IsFoldableNegation(first)) { + inst->SetOpcode(uses_float ? spv::Op::OpFSub : spv::Op::OpISub); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {second->result_id()}}, + {SPV_OPERAND_TYPE_ID, {first->GetSingleWordInOperand(0u)}}}); + return true; + } + return false; + }; + + return TrySubstitute(lhs, rhs) || TrySubstitute(rhs, lhs); }; } -// Folds subtraction of a constant and a negation. +// Folds subtraction, where one side is a negation. // Cases: // (-x) - 2 = -2 - x -// 2 - (-x) = x + 2 +// y - (-x) = x + y FoldingRule MergeSubNegateArithmetic() { return [](IRContext* context, Instruction* inst, const std::vector& constants) { assert(inst->opcode() == spv::Op::OpFSub || inst->opcode() == spv::Op::OpISub); - analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); + bool uses_float = HasFloatingPoint(type); + if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; + + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); + Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); + + if (IsFoldableNegation(rhs)) { + inst->SetOpcode(uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {lhs->result_id()}}, + {SPV_OPERAND_TYPE_ID, {rhs->GetSingleWordInOperand(0)}}}); + return true; + } + if (IsCooperativeMatrix(type)) { return false; } - bool uses_float = HasFloatingPoint(type); - if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; - uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; - const analysis::Constant* const_input1 = ConstInput(constants); - if (!const_input1) return false; - Instruction* other_inst = NonConstInput(context, constants[0], inst); - if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) - return false; - - if (other_inst->opcode() == spv::Op::OpSNegate || - other_inst->opcode() == spv::Op::OpFNegate) { - uint32_t op1 = 0; - uint32_t op2 = 0; - spv::Op opcode = inst->opcode(); - if (constants[0] != nullptr) { - op1 = other_inst->GetSingleWordInOperand(0u); - op2 = inst->GetSingleWordInOperand(0u); - opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd; - } else { - op1 = NegateConstant(const_mgr, const_input1); - op2 = other_inst->GetSingleWordInOperand(0u); - } - - inst->SetOpcode(opcode); + if (constants[1] && IsFoldableNegation(lhs)) { inst->SetInOperands( - {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); + {{SPV_OPERAND_TYPE_ID, + {NegateConstant(context->get_constant_mgr(), constants[1])}}, + {SPV_OPERAND_TYPE_ID, {lhs->GetSingleWordInOperand(0)}}}); return true; } return false; @@ -1530,11 +1558,13 @@ FoldingRule MergeGenericAddSubArithmetic() { }; } -// Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|, -// generate |factor0_0| * (|factor0_1| + |factor1_1|). -bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1, - uint32_t factor1_0, uint32_t factor1_1, - Instruction* inst) { +// Helper function for FactorAddSubMuls. +// If |factor0_0| is the same as |factor1_0|, generate: +// |factor0_0| * (|factor0_1| + |factor1_1|) +// |factor0_0| * (|factor0_1| - |factor1_1|) +bool FactorAddSubMulsOpnds(uint32_t factor0_0, uint32_t factor0_1, + uint32_t factor1_0, uint32_t factor1_1, + Instruction* inst) { IRContext* context = inst->context(); if (factor0_0 != factor1_0) return false; InstructionBuilder ir_builder( @@ -1545,8 +1575,10 @@ bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1, if (!new_add_inst) { return false; } - inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul - : spv::Op::OpIMul); + + bool is_float = + inst->opcode() == spv::Op::OpFAdd || inst->opcode() == spv::Op::OpFSub; + inst->SetOpcode(is_float ? spv::Op::OpFMul : spv::Op::OpIMul); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}}, {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}}); context->UpdateDefUse(inst); @@ -1554,12 +1586,16 @@ bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1, } // Perform the following factoring identity, handling all operand order -// combinations: (a * b) + (a * c) = a * (b + c) -FoldingRule FactorAddMuls() { +// combinations: +// (a * b) + (a * c) = a * (b + c) +// (a * b) - (a * c) = a * (b - c) +FoldingRule FactorAddSubMuls() { return [](IRContext* context, Instruction* inst, const std::vector&) { assert(inst->opcode() == spv::Op::OpFAdd || - inst->opcode() == spv::Op::OpIAdd); + inst->opcode() == spv::Op::OpFSub || + inst->opcode() == spv::Op::OpIAdd || + inst->opcode() == spv::Op::OpISub); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); @@ -1590,11 +1626,11 @@ FoldingRule FactorAddMuls() { for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { // Check if operand i in add_op0_inst matches operand j in add_op1_inst. - if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i), - add_op0_inst->GetSingleWordInOperand(1 - i), - add_op1_inst->GetSingleWordInOperand(j), - add_op1_inst->GetSingleWordInOperand(1 - j), - inst)) + if (FactorAddSubMulsOpnds(add_op0_inst->GetSingleWordInOperand(i), + add_op0_inst->GetSingleWordInOperand(1 - i), + add_op1_inst->GetSingleWordInOperand(j), + add_op1_inst->GetSingleWordInOperand(1 - j), + inst)) return true; } } @@ -2296,6 +2332,39 @@ FoldingRule BitCastScalarOrVector() { }; } +// Remove indirect bitcasts which have no effect. +// uint32 x; asuint32(x) => x +// uint32 x; asuint32(asint32(x)) => x +// float32 x; asuint32(asint32(x)) => asuint32(x) +FoldingRule RedundantBitcast() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == spv::Op::OpBitcast); + + analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); + Instruction* child = def_mgr->GetDef(inst->GetSingleWordInOperand(0)); + + if (inst->type_id() == child->type_id()) { + inst->SetOpcode(spv::Op::OpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {child->result_id()}}}); + return true; + } + + if (child->opcode() != spv::Op::OpBitcast) { + return false; + } + + if (def_mgr->GetDef(child->GetSingleWordInOperand(0))->type_id() == + inst->type_id()) { + inst->SetOpcode(spv::Op::OpCopyObject); + } + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}}); + + return true; + }; +} + FoldingRule BitReverseScalarOrVector() { return [](IRContext* context, Instruction* inst, const std::vector& constants) { @@ -2410,6 +2479,250 @@ FoldingRule RedundantSelect() { }; } +std::optional GetBoolConstantKind(const analysis::Constant* c) { + if (!c) { + return {}; + } + if (auto composite = c->AsCompositeConstant()) { + auto& components = composite->GetComponents(); + if (components.empty()) { + return {}; + } + auto first = GetBoolConstantKind(components[0]); + if (!first) { + return {}; + } + if (std::all_of(std::begin(components) + 1, std::end(components), + [first](const analysis::Constant* c2) { + return GetBoolConstantKind(c2) == first; + })) { + return first; + } + return {}; + } else if (c->AsNullConstant()) { + return false; + } else if (c->AsBoolConstant()) { + return c->AsBoolConstant()->value(); + } + return {}; +} + +// Fold OpSelect instructions which have constant booleans as their result. +// x ? true : false = x +// x ? false : true = !x +FoldingRule FoldConstantBooleanSelect() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpSelect); + assert(inst->NumInOperands() == 3); + assert(constants.size() == 3); + + if (!constants[1] || !constants[2]) { + return false; + } + + analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); + if (inst->type_id() != + def_mgr->GetDef(inst->GetSingleWordInOperand(0))->type_id()) { + return false; + } + + std::optional uniform_true = GetBoolConstantKind(constants[1]); + std::optional uniform_false = GetBoolConstantKind(constants[2]); + + if (!uniform_true || !uniform_false) { + return false; + } + + if (uniform_true.value() && !uniform_false.value()) { + inst->SetOpcode(spv::Op::OpCopyObject); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); + return true; + } else if (!uniform_true.value() && uniform_false.value()) { + inst->SetOpcode(spv::Op::OpLogicalNot); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); + return true; + } + return false; + }; +} + +// Fold OpLogicalAnd instructions which have a constant true on one side. +// x && true = x +// true && x = x +FoldingRule RedundantLogicalAnd() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpLogicalAnd); + + if (GetBoolConstantKind(ConstInput(constants)) == + std::optional(true)) { + Instruction* other_inst = NonConstInput(context, constants[0], inst); + inst->SetOpcode(spv::Op::OpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_inst->result_id()}}}); + return true; + } + return false; + }; +} + +// Fold OpLogicalOr instructions which have a constant false on one side. +// x || false = x +// false || x = x +FoldingRule RedundantLogicalOr() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) { + assert(inst->opcode() == spv::Op::OpLogicalOr); + + if (GetBoolConstantKind(ConstInput(constants)) == + std::optional(false)) { + Instruction* other_inst = NonConstInput(context, constants[0], inst); + inst->SetOpcode(spv::Op::OpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_inst->result_id()}}}); + return true; + } + return false; + }; +} + +// Fold concurrent OpLogicalNot instructions: +// !!x = x +FoldingRule RedundantLogicalNot() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == spv::Op::OpLogicalNot); + Instruction* child = + context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); + if (child->opcode() == spv::Op::OpLogicalNot) { + inst->SetOpcode(spv::Op::OpCopyObject); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}}); + return true; + } + return false; + }; +} + +// Fold OpLogicalNot instructions that follow a comparison, +// if the comparison is only used by that instruction. +// +// !(a == b) = (a != b) +// !(a != b) = (a == b) +// !(a < b) = (a >= b) +// !(a >= b) = (a < b) +// !(a > b) = (a <= b) +// !(a <= b) = (a > b) +FoldingRule FoldLogicalNotComparison() { + return [](IRContext* context, Instruction* inst, + const std::vector&) { + assert(inst->opcode() == spv::Op::OpLogicalNot); + analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); + Instruction* child = + context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); + + if (def_mgr->NumUses(child) > 1) { + return false; + } + + spv::Op new_opcode = spv::Op::OpNop; + switch (child->opcode()) { + // (a == b) <=> (a != b) + case spv::Op::OpIEqual: + new_opcode = spv::Op::OpINotEqual; + break; + case spv::Op::OpINotEqual: + new_opcode = spv::Op::OpIEqual; + break; + case spv::Op::OpFOrdEqual: + new_opcode = spv::Op::OpFUnordNotEqual; + break; + case spv::Op::OpFOrdNotEqual: + new_opcode = spv::Op::OpFUnordEqual; + break; + case spv::Op::OpFUnordEqual: + new_opcode = spv::Op::OpFOrdNotEqual; + break; + case spv::Op::OpFUnordNotEqual: + new_opcode = spv::Op::OpFOrdEqual; + break; + case spv::Op::OpLogicalEqual: + new_opcode = spv::Op::OpLogicalNotEqual; + break; + case spv::Op::OpLogicalNotEqual: + new_opcode = spv::Op::OpLogicalEqual; + break; + + // (a > b) <=> (a <= b) + case spv::Op::OpUGreaterThan: + new_opcode = spv::Op::OpULessThanEqual; + break; + case spv::Op::OpULessThanEqual: + new_opcode = spv::Op::OpUGreaterThan; + break; + case spv::Op::OpSGreaterThan: + new_opcode = spv::Op::OpSLessThanEqual; + break; + case spv::Op::OpSLessThanEqual: + new_opcode = spv::Op::OpSGreaterThan; + break; + case spv::Op::OpFOrdGreaterThan: + new_opcode = spv::Op::OpFUnordLessThanEqual; + break; + case spv::Op::OpFOrdLessThanEqual: + new_opcode = spv::Op::OpFUnordGreaterThan; + break; + case spv::Op::OpFUnordGreaterThan: + new_opcode = spv::Op::OpFOrdLessThanEqual; + break; + case spv::Op::OpFUnordLessThanEqual: + new_opcode = spv::Op::OpFOrdGreaterThan; + break; + + // (a < b) <=> (a >= b) + case spv::Op::OpULessThan: + new_opcode = spv::Op::OpUGreaterThanEqual; + break; + case spv::Op::OpUGreaterThanEqual: + new_opcode = spv::Op::OpULessThan; + break; + case spv::Op::OpSLessThan: + new_opcode = spv::Op::OpSGreaterThanEqual; + break; + case spv::Op::OpSGreaterThanEqual: + new_opcode = spv::Op::OpSLessThan; + break; + case spv::Op::OpFOrdLessThan: + new_opcode = spv::Op::OpFUnordGreaterThanEqual; + break; + case spv::Op::OpFOrdGreaterThanEqual: + new_opcode = spv::Op::OpFUnordLessThan; + break; + case spv::Op::OpFUnordLessThan: + new_opcode = spv::Op::OpFOrdGreaterThanEqual; + break; + case spv::Op::OpFUnordGreaterThanEqual: + new_opcode = spv::Op::OpFOrdLessThan; + break; + + default: + break; + } + + if (new_opcode == spv::Op::OpNop) { + return false; + } + + inst->SetOpcode(new_opcode); + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}, + {SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(1)}}}); + + return true; + }; +} + enum class FloatConstantKind { Unknown, Zero, One }; FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { @@ -3394,6 +3707,8 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpUMod].push_back(RedundantSUMod()); rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector()); + rules_[spv::Op::OpBitcast].push_back(RedundantBitcast()); + rules_[spv::Op::OpBitReverse].push_back(BitReverseScalarOrVector()); rules_[spv::Op::OpCompositeConstruct].push_back( @@ -3417,7 +3732,7 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic()); rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic()); rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic()); - rules_[spv::Op::OpFAdd].push_back(FactorAddMuls()); + rules_[spv::Op::OpFAdd].push_back(FactorAddSubMuls()); rules_[spv::Op::OpFDiv].push_back(RedundantFDiv()); rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv()); @@ -3431,6 +3746,7 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic()); rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic()); rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic()); + rules_[spv::Op::OpFMul].push_back(MergeMulDoubleNegative()); rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic()); rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic()); @@ -3440,20 +3756,23 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic()); rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic()); rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic()); + rules_[spv::Op::OpFSub].push_back(FactorAddSubMuls()); rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic()); rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic()); rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic()); rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic()); - rules_[spv::Op::OpIAdd].push_back(FactorAddMuls()); + rules_[spv::Op::OpIAdd].push_back(FactorAddSubMuls()); rules_[spv::Op::OpIMul].push_back(IntMultipleBy1()); rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic()); rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic()); + rules_[spv::Op::OpIMul].push_back(MergeMulDoubleNegative()); rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic()); rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic()); rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic()); + rules_[spv::Op::OpISub].push_back(FactorAddSubMuls()); rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndOrXor()); rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndAddSub()); @@ -3466,6 +3785,14 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic()); rules_[spv::Op::OpSelect].push_back(RedundantSelect()); + rules_[spv::Op::OpSelect].push_back(FoldConstantBooleanSelect()); + + rules_[spv::Op::OpLogicalAnd].push_back(RedundantLogicalAnd()); + + rules_[spv::Op::OpLogicalOr].push_back(RedundantLogicalOr()); + + rules_[spv::Op::OpLogicalNot].push_back(RedundantLogicalNot()); + rules_[spv::Op::OpLogicalNot].push_back(FoldLogicalNotComparison()); rules_[spv::Op::OpStore].push_back(StoringUndef()); diff --git a/3rdparty/spirv-tools/source/opt/ir_context.cpp b/3rdparty/spirv-tools/source/opt/ir_context.cpp index 2ce9e8504..88b1f2ec2 100644 --- a/3rdparty/spirv-tools/source/opt/ir_context.cpp +++ b/3rdparty/spirv-tools/source/opt/ir_context.cpp @@ -183,6 +183,8 @@ Instruction* IRContext::KillInst(Instruction* inst) { KillOperandFromDebugInstructions(inst); + KillRelatedDebugScopes(inst); + if (AreAnalysesValid(kAnalysisDefUse)) { analysis::DefUseManager* def_use_mgr = get_def_use_mgr(); def_use_mgr->ClearInst(inst); @@ -532,6 +534,20 @@ void IRContext::KillOperandFromDebugInstructions(Instruction* inst) { } } +void IRContext::KillRelatedDebugScopes(Instruction* inst) { + // Extension has been fully unloaded, remove debug scope from every + // instruction. + if (inst->opcode() == spv::Op::OpExtInstImport) { + const std::string extension_name = inst->GetInOperand(0).AsString(); + if (extension_name == "NonSemantic.Shader.DebugInfo.100" || + extension_name == "OpenCL.DebugInfo.100") { + module()->ForEachInst([](Instruction* child) { + child->SetDebugScope(DebugScope(kNoDebugScope, kNoInlinedAt)); + }); + } + } +} + void IRContext::AddCombinatorsForCapability(uint32_t capability) { spv::Capability cap = spv::Capability(capability); if (cap == spv::Capability::Shader) { diff --git a/3rdparty/spirv-tools/source/opt/ir_context.h b/3rdparty/spirv-tools/source/opt/ir_context.h index 89e8cb052..f4a69fc4b 100644 --- a/3rdparty/spirv-tools/source/opt/ir_context.h +++ b/3rdparty/spirv-tools/source/opt/ir_context.h @@ -508,6 +508,9 @@ class IRContext { // Change operands of debug instruction to DebugInfoNone. void KillOperandFromDebugInstructions(Instruction* inst); + // Remove the debug scope from any instruction related to |inst|. + void KillRelatedDebugScopes(Instruction* inst); + // Returns the next unique id for use by an instruction. inline uint32_t TakeNextUniqueId() { assert(unique_id_ != std::numeric_limits::max()); diff --git a/3rdparty/spirv-tools/source/opt/mem_pass.cpp b/3rdparty/spirv-tools/source/opt/mem_pass.cpp index e4eb751cb..8da06683a 100644 --- a/3rdparty/spirv-tools/source/opt/mem_pass.cpp +++ b/3rdparty/spirv-tools/source/opt/mem_pass.cpp @@ -340,7 +340,7 @@ bool MemPass::IsTargetVar(uint32_t varId) { // %50 = OpUndef %int // [ ... ] // %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15 -void MemPass::RemovePhiOperands( +bool MemPass::RemovePhiOperands( Instruction* phi, const std::unordered_set& reachable_blocks) { std::vector keep_operands; uint32_t type_id = 0; @@ -382,6 +382,7 @@ void MemPass::RemovePhiOperands( if (!undef_id) { type_id = arg_def_instr->type_id(); undef_id = Type2Undef(type_id); + if (undef_id == 0) return false; } keep_operands.push_back( Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id})); @@ -400,6 +401,7 @@ void MemPass::RemovePhiOperands( context()->ForgetUses(phi); phi->ReplaceOperands(keep_operands); context()->AnalyzeUses(phi); + return true; } void MemPass::RemoveBlock(Function::iterator* bi) { @@ -422,8 +424,8 @@ void MemPass::RemoveBlock(Function::iterator* bi) { *bi = bi->Erase(); } -bool MemPass::RemoveUnreachableBlocks(Function* func) { - if (func->IsDeclaration()) return false; +Pass::Status MemPass::RemoveUnreachableBlocks(Function* func) { + if (func->IsDeclaration()) return Status::SuccessWithoutChange; bool modified = false; // Mark reachable all blocks reachable from the function's entry block. @@ -469,9 +471,11 @@ bool MemPass::RemoveUnreachableBlocks(Function* func) { // If the block is reachable and has Phi instructions, remove all // operands from its Phi instructions that reference unreachable blocks. // If the block has no Phi instructions, this is a no-op. - block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) { - RemovePhiOperands(phi, reachable_blocks); - }); + bool success = + block.WhileEachPhiInst([&reachable_blocks, this](Instruction* phi) { + return RemovePhiOperands(phi, reachable_blocks); + }); + if (!success) return Status::Failure; } // Erase unreachable blocks. @@ -484,13 +488,11 @@ bool MemPass::RemoveUnreachableBlocks(Function* func) { } } - return modified; + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; } -bool MemPass::CFGCleanup(Function* func) { - bool modified = false; - modified |= RemoveUnreachableBlocks(func); - return modified; +Pass::Status MemPass::CFGCleanup(Function* func) { + return RemoveUnreachableBlocks(func); } void MemPass::CollectTargetVars(Function* func) { diff --git a/3rdparty/spirv-tools/source/opt/mem_pass.h b/3rdparty/spirv-tools/source/opt/mem_pass.h index aef9e5ffa..496286b5f 100644 --- a/3rdparty/spirv-tools/source/opt/mem_pass.h +++ b/3rdparty/spirv-tools/source/opt/mem_pass.h @@ -114,7 +114,7 @@ class MemPass : public Pass { void DCEInst(Instruction* inst, const std::function&); // Call all the cleanup helper functions on |func|. - bool CFGCleanup(Function* func); + Status CFGCleanup(Function* func); // Return true if |op| is supported decorate. inline bool IsNonTypeDecorate(spv::Op op) const { @@ -142,15 +142,15 @@ class MemPass : public Pass { bool HasOnlySupportedRefs(uint32_t varId); // Remove all the unreachable basic blocks in |func|. - bool RemoveUnreachableBlocks(Function* func); + Status RemoveUnreachableBlocks(Function* func); // Remove the block pointed by the iterator |*bi|. This also removes // all the instructions in the pointed-to block. void RemoveBlock(Function::iterator* bi); // Remove Phi operands in |phi| that are coming from blocks not in - // |reachable_blocks|. - void RemovePhiOperands( + // |reachable_blocks|. Returns false if it fails. + bool RemovePhiOperands( Instruction* phi, const std::unordered_set& reachable_blocks); diff --git a/3rdparty/spirv-tools/source/opt/pass.cpp b/3rdparty/spirv-tools/source/opt/pass.cpp index 08d76b5a0..ce37f3628 100644 --- a/3rdparty/spirv-tools/source/opt/pass.cpp +++ b/3rdparty/spirv-tools/source/opt/pass.cpp @@ -117,6 +117,9 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, // TODO(1841): Handle id overflow. Instruction* extract = ir_builder.AddCompositeExtract( original_element_type_id, object_to_copy->result_id(), {i}); + if (extract == nullptr) { + return 0; + } uint32_t new_id = GenerateCopy(extract, new_element_type_id, insertion_position); if (new_id == 0) { @@ -125,8 +128,12 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, element_ids.push_back(new_id); } - return ir_builder.AddCompositeConstruct(new_type_id, element_ids) - ->result_id(); + Instruction* construct = + ir_builder.AddCompositeConstruct(new_type_id, element_ids); + if (construct == nullptr) { + return 0; + } + return construct->result_id(); } case spv::Op::OpTypeStruct: { std::vector element_ids; @@ -136,6 +143,9 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, // TODO(1841): Handle id overflow. Instruction* extract = ir_builder.AddCompositeExtract( orig_member_type_id, object_to_copy->result_id(), {i}); + if (extract == nullptr) { + return 0; + } uint32_t new_id = GenerateCopy(extract, new_member_type_id, insertion_position); if (new_id == 0) { @@ -143,8 +153,12 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, } element_ids.push_back(new_id); } - return ir_builder.AddCompositeConstruct(new_type_id, element_ids) - ->result_id(); + Instruction* construct = + ir_builder.AddCompositeConstruct(new_type_id, element_ids); + if (construct == nullptr) { + return 0; + } + return construct->result_id(); } default: // If we do not have an aggregate type, then we have a problem. Either we diff --git a/3rdparty/spirv-tools/source/opt/value_number_table.cpp b/3rdparty/spirv-tools/source/opt/value_number_table.cpp index 8c33ab7fb..a93d33cae 100644 --- a/3rdparty/spirv-tools/source/opt/value_number_table.cpp +++ b/3rdparty/spirv-tools/source/opt/value_number_table.cpp @@ -183,29 +183,13 @@ uint32_t ValueNumberTable::AssignValueNumber(Instruction* inst) { } // Apply normal form, so a+b == b+a - switch (value_ins.opcode()) { - case spv::Op::OpIAdd: - case spv::Op::OpFAdd: - case spv::Op::OpIMul: - case spv::Op::OpFMul: - case spv::Op::OpDot: - case spv::Op::OpLogicalEqual: - case spv::Op::OpLogicalNotEqual: - case spv::Op::OpLogicalOr: - case spv::Op::OpLogicalAnd: - case spv::Op::OpIEqual: - case spv::Op::OpINotEqual: - case spv::Op::OpBitwiseOr: - case spv::Op::OpBitwiseXor: - case spv::Op::OpBitwiseAnd: - if (value_ins.GetSingleWordInOperand(0) > - value_ins.GetSingleWordInOperand(1)) { - value_ins.SetInOperands( - {{SPV_OPERAND_TYPE_ID, {value_ins.GetSingleWordInOperand(1)}}, - {SPV_OPERAND_TYPE_ID, {value_ins.GetSingleWordInOperand(0)}}}); - } - default: - break; + if (spvOpcodeIsCommutativeBinaryOperator(value_ins.opcode())) { + if (value_ins.GetSingleWordInOperand(0) > + value_ins.GetSingleWordInOperand(1)) { + value_ins.SetInOperands( + {{SPV_OPERAND_TYPE_ID, {value_ins.GetSingleWordInOperand(1)}}, + {SPV_OPERAND_TYPE_ID, {value_ins.GetSingleWordInOperand(0)}}}); + } } // Otherwise, we check if this value has been computed before. diff --git a/3rdparty/spirv-tools/source/val/validate_atomics.cpp b/3rdparty/spirv-tools/source/val/validate_atomics.cpp index 510960ba1..8cda07e04 100644 --- a/3rdparty/spirv-tools/source/val/validate_atomics.cpp +++ b/3rdparty/spirv-tools/source/val/validate_atomics.cpp @@ -235,7 +235,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) { if (!IsStorageClassAllowedByUniversalRules(storage_class)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << spvOpcodeString(opcode) - << ": storage class forbidden by universal validation rules."; + << ": Can not be used with storage class " + << spvtools::StorageClassToString(storage_class) + << " by universal validation rules"; } // Then Shader rules @@ -249,8 +251,10 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) { (storage_class != spv::StorageClass::PhysicalStorageBuffer) && (storage_class != spv::StorageClass::TaskPayloadWorkgroupEXT)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) - << _.VkErrorID(4686) << spvOpcodeString(opcode) - << ": Vulkan spec only allows storage classes for atomic to " + << _.VkErrorID(4686) << spvOpcodeString(opcode) << ": " + << spvtools::StorageClassToString(storage_class) + << " is not allowed, the Vulkan spec only allows storage " + "classes for atomic to " "be: Uniform, Workgroup, Image, StorageBuffer, " "PhysicalStorageBuffer or TaskPayloadWorkgroupEXT."; } @@ -335,8 +339,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) { (storage_class != spv::StorageClass::CrossWorkgroup) && (storage_class != spv::StorageClass::Generic)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) - << spvOpcodeString(opcode) - << ": storage class must be Function, Workgroup, " + << spvOpcodeString(opcode) << ": storage class is " + << spvtools::StorageClassToString(storage_class) + << ", but must be Function, Workgroup, " "CrossWorkGroup or Generic in the OpenCL environment."; } diff --git a/3rdparty/spirv-tools/source/val/validate_builtins.cpp b/3rdparty/spirv-tools/source/val/validate_builtins.cpp index c8586a7fc..f75a707b1 100644 --- a/3rdparty/spirv-tools/source/val/validate_builtins.cpp +++ b/3rdparty/spirv-tools/source/val/validate_builtins.cpp @@ -2955,9 +2955,24 @@ spv_result_t BuiltInsValidator::ValidateMeshBuiltinInterfaceRules( const Decoration& decoration, const Instruction& inst, spv::Op scalar_type, const Instruction& referenced_from_inst) { if (function_id_) { - if (execution_models_.count(spv::ExecutionModel::MeshEXT)) { + if (!execution_models_.count(spv::ExecutionModel::MeshEXT)) { + return SPV_SUCCESS; + } + + const spv::BuiltIn builtin = decoration.builtin(); + const bool is_topology = + builtin == spv::BuiltIn::PrimitiveTriangleIndicesEXT || + builtin == spv::BuiltIn::PrimitiveLineIndicesEXT || + builtin == spv::BuiltIn::PrimitivePointIndicesEXT; + + // These builtin have the ability to be an array with MeshEXT + // When an array, we need to make sure the array size lines up + std::map entry_interface_id_map; + const bool is_interface_var = + IsMeshInterfaceVar(inst, entry_interface_id_map); + + if (!is_topology) { bool is_block = false; - const spv::BuiltIn builtin = decoration.builtin(); static const std::unordered_map mesh_vuid_map = {{ @@ -2997,12 +3012,7 @@ spv_result_t BuiltInsValidator::ValidateMeshBuiltinInterfaceRules( << " within the MeshEXT Execution Model must also be " << "decorated with the PerPrimitiveEXT decoration. "; } - - // These builtin have the ability to be an array with MeshEXT - // When an array, we need to make sure the array size lines up - std::map entry_interface_id_map; - bool found = IsMeshInterfaceVar(inst, entry_interface_id_map); - if (found) { + if (is_interface_var) { for (const auto& id : entry_interface_id_map) { uint32_t entry_point_id = id.first; uint32_t interface_var_id = id.second; @@ -3025,6 +3035,86 @@ spv_result_t BuiltInsValidator::ValidateMeshBuiltinInterfaceRules( } } } + + if (is_interface_var && is_topology) { + for (const auto& id : entry_interface_id_map) { + uint32_t entry_point_id = id.first; + + uint64_t max_output_primitives = + _.GetOutputPrimitivesEXT(entry_point_id); + uint32_t underlying_type = 0; + if (spv_result_t error = + GetUnderlyingType(_, decoration, inst, &underlying_type)) { + return error; + } + + uint64_t primitive_array_dim = 0; + if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) { + underlying_type = _.FindDef(underlying_type)->word(3u); + if (!_.EvalConstantValUint64(underlying_type, &primitive_array_dim)) { + assert(0 && "Array type definition is corrupt"); + } + } + + const auto* modes = _.GetExecutionModes(entry_point_id); + if (builtin == spv::BuiltIn::PrimitiveTriangleIndicesEXT) { + if (!modes || !modes->count(spv::ExecutionMode::OutputTrianglesEXT)) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << _.VkErrorID(7054) + << "The PrimitiveTriangleIndicesEXT decoration must be used " + "with the OutputTrianglesEXT Execution Mode. "; + } + if (primitive_array_dim && + primitive_array_dim != max_output_primitives) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << _.VkErrorID(7058) + << "The size of the array decorated with " + "PrimitiveTriangleIndicesEXT (" + << primitive_array_dim + << ") must match the value specified " + "by OutputPrimitivesEXT (" + << max_output_primitives << "). "; + } + } else if (builtin == spv::BuiltIn::PrimitiveLineIndicesEXT) { + if (!modes || !modes->count(spv::ExecutionMode::OutputLinesEXT)) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << _.VkErrorID(7048) + << "The PrimitiveLineIndicesEXT decoration must be used " + "with the OutputLinesEXT Execution Mode. "; + } + if (primitive_array_dim && + primitive_array_dim != max_output_primitives) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << _.VkErrorID(7052) + << "The size of the array decorated with " + "PrimitiveLineIndicesEXT (" + << primitive_array_dim + << ") must match the value specified " + "by OutputPrimitivesEXT (" + << max_output_primitives << "). "; + } + + } else if (builtin == spv::BuiltIn::PrimitivePointIndicesEXT) { + if (!modes || !modes->count(spv::ExecutionMode::OutputPoints)) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << _.VkErrorID(7042) + << "The PrimitivePointIndicesEXT decoration must be used " + "with the OutputPoints Execution Mode. "; + } + if (primitive_array_dim && + primitive_array_dim != max_output_primitives) { + return _.diag(SPV_ERROR_INVALID_DATA, &inst) + << _.VkErrorID(7046) + << "The size of the array decorated with " + "PrimitivePointIndicesEXT (" + << primitive_array_dim + << ") must match the value specified " + "by OutputPrimitivesEXT (" + << max_output_primitives << "). "; + } + } + } + } } else { // Propagate this rule to all dependant ids in the global scope. id_to_at_reference_checks_[referenced_from_inst.id()].push_back( @@ -4650,12 +4740,6 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition( } break; case spv::BuiltIn::CullPrimitiveEXT: { - // We know this only allowed for Mesh Execution Model - if (spv_result_t error = ValidateMeshBuiltinInterfaceRules( - decoration, inst, spv::Op::OpTypeBool, inst)) { - return error; - } - for (const uint32_t entry_point : _.entry_points()) { auto* models = _.GetExecutionModels(entry_point); if (models->find(spv::ExecutionModel::MeshEXT) == models->end() && @@ -4683,88 +4767,19 @@ spv_result_t BuiltInsValidator::ValidateMeshShadingEXTBuiltinsAtDefinition( default: assert(0 && "Unexpected mesh EXT builtin"); } - for (const uint32_t entry_point : _.entry_points()) { - // execution modes and builtin are both global, so only check these - // buildit definitions if we know the entrypoint is Mesh - auto* models = _.GetExecutionModels(entry_point); - if (models->find(spv::ExecutionModel::MeshEXT) == models->end() && - models->find(spv::ExecutionModel::MeshNV) == models->end()) { - continue; - } - const auto* modes = _.GetExecutionModes(entry_point); - uint64_t max_output_primitives = _.GetOutputPrimitivesEXT(entry_point); - uint32_t underlying_type = 0; - if (spv_result_t error = - GetUnderlyingType(_, decoration, inst, &underlying_type)) { - return error; - } - - uint64_t primitive_array_dim = 0; - if (_.GetIdOpcode(underlying_type) == spv::Op::OpTypeArray) { - underlying_type = _.FindDef(underlying_type)->word(3u); - if (!_.EvalConstantValUint64(underlying_type, &primitive_array_dim)) { - assert(0 && "Array type definition is corrupt"); - } - } - switch (builtin) { - case spv::BuiltIn::PrimitivePointIndicesEXT: - if (!modes || !modes->count(spv::ExecutionMode::OutputPoints)) { - return _.diag(SPV_ERROR_INVALID_DATA, &inst) - << _.VkErrorID(7042) - << "The PrimitivePointIndicesEXT decoration must be used " - "with " - "the OutputPoints Execution Mode. "; - } - if (primitive_array_dim && - primitive_array_dim != max_output_primitives) { - return _.diag(SPV_ERROR_INVALID_DATA, &inst) - << _.VkErrorID(7046) - << "The size of the array decorated with " - "PrimitivePointIndicesEXT must match the value specified " - "by OutputPrimitivesEXT. "; - } - break; - case spv::BuiltIn::PrimitiveLineIndicesEXT: - if (!modes || !modes->count(spv::ExecutionMode::OutputLinesEXT)) { - return _.diag(SPV_ERROR_INVALID_DATA, &inst) - << _.VkErrorID(7048) - << "The PrimitiveLineIndicesEXT decoration must be used " - "with " - "the OutputLinesEXT Execution Mode. "; - } - if (primitive_array_dim && - primitive_array_dim != max_output_primitives) { - return _.diag(SPV_ERROR_INVALID_DATA, &inst) - << _.VkErrorID(7052) - << "The size of the array decorated with " - "PrimitiveLineIndicesEXT must match the value specified " - "by OutputPrimitivesEXT. "; - } - break; - case spv::BuiltIn::PrimitiveTriangleIndicesEXT: - if (!modes || !modes->count(spv::ExecutionMode::OutputTrianglesEXT)) { - return _.diag(SPV_ERROR_INVALID_DATA, &inst) - << _.VkErrorID(7054) - << "The PrimitiveTriangleIndicesEXT decoration must be used " - "with " - "the OutputTrianglesEXT Execution Mode. "; - } - if (primitive_array_dim && - primitive_array_dim != max_output_primitives) { - return _.diag(SPV_ERROR_INVALID_DATA, &inst) - << _.VkErrorID(7058) - << "The size of the array decorated with " - "PrimitiveTriangleIndicesEXT must match the value " - "specified " - "by OutputPrimitivesEXT. "; - } - break; - default: - break; // no validation rules - } + // - We know this only allowed for Mesh Execution Model. + // - The Scalar type is is boolean for CullPrimitiveEXT, the other 3 builtin + // (topology) don't need this type. + // - It is possible to have multiple mesh + // shaders (https://github.com/KhronosGroup/SPIRV-Tools/issues/6320) and we + // need to validate these at reference time. + if (spv_result_t error = ValidateMeshBuiltinInterfaceRules( + decoration, inst, spv::Op::OpTypeBool, inst)) { + return error; } } + // Seed at reference checks with this built-in. return ValidateMeshShadingEXTBuiltinsAtReference(decoration, inst, inst, inst); diff --git a/3rdparty/spirv-tools/source/val/validate_decorations.cpp b/3rdparty/spirv-tools/source/val/validate_decorations.cpp index 2c8ef958a..5d6dc55d1 100644 --- a/3rdparty/spirv-tools/source/val/validate_decorations.cpp +++ b/3rdparty/spirv-tools/source/val/validate_decorations.cpp @@ -398,24 +398,6 @@ bool IsAlignedTo(uint32_t offset, uint32_t alignment) { return 0 == (offset % alignment); } -std::string getStorageClassString(spv::StorageClass sc) { - switch (sc) { - case spv::StorageClass::Uniform: - return "Uniform"; - case spv::StorageClass::UniformConstant: - return "UniformConstant"; - case spv::StorageClass::PushConstant: - return "PushConstant"; - case spv::StorageClass::Workgroup: - return "Workgroup"; - case spv::StorageClass::PhysicalStorageBuffer: - return "PhysicalStorageBuffer"; - default: - // Only other valid storage class in these checks - return "StorageBuffer"; - } -} - // Returns SPV_SUCCESS if the given struct satisfies standard layout rules for // Block or BufferBlocks in Vulkan. Otherwise emits a diagnostic and returns // something other than SPV_SUCCESS. Matrices inherit the specified column @@ -442,7 +424,7 @@ spv_result_t checkLayout(uint32_t struct_id, spv::StorageClass storage_class, DiagnosticStream ds = std::move( vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(struct_id)) << "Structure id " << struct_id << " decorated as " << decoration_str - << " for variable in " << getStorageClassString(storage_class) + << " for variable in " << StorageClassToString(storage_class) << " storage class must follow " << (scalar_block_layout ? "scalar " @@ -1282,7 +1264,7 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) { if (!entry_points.empty() && !hasDecoration(var_id, spv::Decoration::Binding, vstate)) { return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(var_id)) - << getStorageClassString(storageClass) << " id '" << var_id + << StorageClassToString(storageClass) << " id '" << var_id << "' is missing Binding decoration.\n" << "From ARB_gl_spirv extension:\n" << "Uniform and shader storage block variables must " diff --git a/3rdparty/spirv-tools/source/val/validate_logical_pointers.cpp b/3rdparty/spirv-tools/source/val/validate_logical_pointers.cpp index 152870155..6f510fcff 100644 --- a/3rdparty/spirv-tools/source/val/validate_logical_pointers.cpp +++ b/3rdparty/spirv-tools/source/val/validate_logical_pointers.cpp @@ -50,6 +50,9 @@ bool IsVariablePointer(const ValidationState_t& _, return iter->second; } + // Temporarily mark the instruction as NOT a variable pointer. + variable_pointers[inst->id()] = false; + bool is_var_ptr = false; switch (inst->opcode()) { case spv::Op::OpPtrAccessChain: @@ -625,7 +628,7 @@ spv_result_t TraceVariablePointers( trace_inst->uses()); std::unordered_set store_seen; while (!store_stack.empty()) { - const auto& use = store_stack.back(); + const auto use = store_stack.back(); store_stack.pop_back(); if (!store_seen.insert(use.first).second) { @@ -766,7 +769,7 @@ spv_result_t TraceUnmodifiedVariablePointers( trace_inst->uses()); std::unordered_set store_seen; while (!store_stack.empty()) { - const auto& use = store_stack.back(); + const auto use = store_stack.back(); store_stack.pop_back(); if (!store_seen.insert(use.first).second) { diff --git a/3rdparty/spirv-tools/source/val/validate_memory.cpp b/3rdparty/spirv-tools/source/val/validate_memory.cpp index 9372f5c38..d33c3c289 100644 --- a/3rdparty/spirv-tools/source/val/validate_memory.cpp +++ b/3rdparty/spirv-tools/source/val/validate_memory.cpp @@ -15,6 +15,7 @@ // limitations under the License. #include +#include #include #include @@ -773,16 +774,17 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) { if (spvIsVulkanEnv(_.context()->target_env)) { // OpTypeRuntimeArray should only ever be in a container like OpTypeStruct, // so should never appear as a bare variable. - // Unless the module has the RuntimeDescriptorArrayEXT capability. + // Unless the module has the RuntimeDescriptorArray capability. if (value_type && value_type->opcode() == spv::Op::OpTypeRuntimeArray) { - if (!_.HasCapability(spv::Capability::RuntimeDescriptorArrayEXT)) { + if (!_.HasCapability(spv::Capability::RuntimeDescriptorArray)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << _.VkErrorID(4680) << "OpVariable, " << _.getIdName(inst->id()) << ", is attempting to create memory for an illegal type, " << "OpTypeRuntimeArray.\nFor Vulkan OpTypeRuntimeArray can only " << "appear as the final member of an OpTypeStruct, thus cannot " - << "be instantiated via OpVariable"; + << "be instantiated via OpVariable, unless the " + "RuntimeDescriptorArray Capability is declared"; } else { // A bare variable OpTypeRuntimeArray is allowed in this context, but // still need to check the storage class. @@ -791,7 +793,7 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) { storage_class != spv::StorageClass::UniformConstant) { return _.diag(SPV_ERROR_INVALID_ID, inst) << _.VkErrorID(4680) - << "For Vulkan with RuntimeDescriptorArrayEXT, a variable " + << "For Vulkan with RuntimeDescriptorArray, a variable " << "containing OpTypeRuntimeArray must have storage class of " << "StorageBuffer, Uniform, or UniformConstant."; } @@ -1118,6 +1120,29 @@ spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) { } } + // Skip checking if there is zero chance for this having a mesh shader + // entrypoint + if (_.HasCapability(spv::Capability::MeshShadingEXT) && + pointer_type->GetOperandAs(1) == + spv::StorageClass::Output) { + std::string errorVUID = _.VkErrorID(7107); + _.function(inst->function()->id()) + ->RegisterExecutionModelLimitation( + [errorVUID](spv::ExecutionModel model, std::string* message) { + // Seems the NV Mesh extension was less strict and allowed + // writting to outputs + if (model == spv::ExecutionModel::MeshEXT) { + if (message) { + *message = errorVUID + + "The Output Storage Class in a Mesh Execution " + "Model must not be read from"; + } + return false; + } + return true; + }); + } + _.RegisterQCOMImageProcessingTextureConsumer(pointer_id, inst, nullptr); return SPV_SUCCESS; @@ -1822,13 +1847,19 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, // At this point, we have fully walked down from the base using the indeces. // The type being pointed to should be the same as the result type. if (type_pointee->id() != result_type_pointee->id()) { + bool same_type = result_type_pointee->opcode() == type_pointee->opcode(); return _.diag(SPV_ERROR_INVALID_ID, inst) - << "Op" << spvOpcodeString(opcode) << " result type (Op" + << "Op" << spvOpcodeString(opcode) << " result type " + << _.getIdName(result_type_pointee->id()) << " (Op" << spvOpcodeString(result_type_pointee->opcode()) << ") does not match the type that results from indexing into the " "base " - " (Op" - << spvOpcodeString(type_pointee->opcode()) << ")."; + " " + << _.getIdName(type_pointee->id()) << " (Op" + << spvOpcodeString(type_pointee->opcode()) << ")." + << (same_type ? " (The types must be the exact same Id, so the " + "two types referenced are slighlty different)" + : ""); } } diff --git a/3rdparty/spirv-tools/source/val/validate_mesh_shading.cpp b/3rdparty/spirv-tools/source/val/validate_mesh_shading.cpp index 3bd1dbd38..d7352eb8d 100644 --- a/3rdparty/spirv-tools/source/val/validate_mesh_shading.cpp +++ b/3rdparty/spirv-tools/source/val/validate_mesh_shading.cpp @@ -132,9 +132,9 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) { } case spv::Op::OpVariable: { if (_.HasCapability(spv::Capability::MeshShadingEXT)) { - bool meshInterfaceVar = + bool is_mesh_interface_var = IsInterfaceVariable(_, inst, spv::ExecutionModel::MeshEXT); - bool fragInterfaceVar = + bool is_frag_interface_var = IsInterfaceVariable(_, inst, spv::ExecutionModel::Fragment); const spv::StorageClass storage_class = @@ -143,14 +143,14 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) { bool storage_input = (storage_class == spv::StorageClass::Input); if (_.HasDecoration(inst->id(), spv::Decoration::PerPrimitiveEXT)) { - if (fragInterfaceVar && !storage_input) { + if (is_frag_interface_var && !storage_input) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "PerPrimitiveEXT decoration must be applied only to " "variables in the Input Storage Class in the Fragment " "Execution Model."; } - if (meshInterfaceVar && !storage_output) { + if (is_mesh_interface_var && !storage_output) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << _.VkErrorID(4336) << "PerPrimitiveEXT decoration must be applied only to " @@ -158,6 +158,20 @@ spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) { "Storage Class in the MeshEXT Execution Model."; } } + + // This only applies to user interface variables, not built-ins (they + // are validated with the rest of the builtin) + if (is_mesh_interface_var && storage_output && + !_.HasDecoration(inst->id(), spv::Decoration::BuiltIn)) { + const Instruction* pointer_inst = _.FindDef(inst->type_id()); + if (pointer_inst->opcode() == spv::Op::OpTypePointer) { + if (!_.IsArrayType(pointer_inst->word(3))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "In the MeshEXT Execution Mode, all Output Variables " + "must contain an Array."; + } + } + } } break; } diff --git a/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp b/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp index f2b43b651..22d464fe2 100644 --- a/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp +++ b/3rdparty/spirv-tools/source/val/validate_mode_setting.cpp @@ -27,6 +27,48 @@ namespace spvtools { namespace val { namespace { +// TODO - Make a common util if someone else needs it too outside this file +const char* ExecutionModelToString(spv::ExecutionModel value) { + switch (value) { + case spv::ExecutionModel::Vertex: + return "Vertex"; + case spv::ExecutionModel::TessellationControl: + return "TessellationControl"; + case spv::ExecutionModel::TessellationEvaluation: + return "TessellationEvaluation"; + case spv::ExecutionModel::Geometry: + return "Geometry"; + case spv::ExecutionModel::Fragment: + return "Fragment"; + case spv::ExecutionModel::GLCompute: + return "GLCompute"; + case spv::ExecutionModel::Kernel: + return "Kernel"; + case spv::ExecutionModel::TaskNV: + return "TaskNV"; + case spv::ExecutionModel::MeshNV: + return "MeshNV"; + case spv::ExecutionModel::RayGenerationKHR: + return "RayGenerationKHR"; + case spv::ExecutionModel::IntersectionKHR: + return "IntersectionKHR"; + case spv::ExecutionModel::AnyHitKHR: + return "AnyHitKHR"; + case spv::ExecutionModel::ClosestHitKHR: + return "ClosestHitKHR"; + case spv::ExecutionModel::MissKHR: + return "MissKHR"; + case spv::ExecutionModel::CallableKHR: + return "CallableKHR"; + case spv::ExecutionModel::TaskEXT: + return "TaskEXT"; + case spv::ExecutionModel::MeshEXT: + return "MeshEXT"; + default: + return "Unknown"; + } +} + spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) { const auto entry_point_id = inst->GetOperandAs(1); auto entry_point = _.FindDef(entry_point_id); @@ -306,74 +348,79 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) { } if (spvIsVulkanEnv(_.context()->target_env)) { - switch (execution_model) { - case spv::ExecutionModel::GLCompute: - if (!has_mode(spv::ExecutionMode::LocalSize)) { - bool ok = has_workgroup_size || has_local_size_id; - if (!ok && _.HasCapability(spv::Capability::TileShadingQCOM)) { - ok = has_mode(spv::ExecutionMode::TileShadingRateQCOM); - } - if (!ok) { + // SPV_QCOM_tile_shading checks + if (execution_model == spv::ExecutionModel::GLCompute) { + if (_.HasCapability(spv::Capability::TileShadingQCOM)) { + if (has_mode(spv::ExecutionMode::TileShadingRateQCOM) && + (has_mode(spv::ExecutionMode::LocalSize) || + has_mode(spv::ExecutionMode::LocalSizeId))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "If the TileShadingRateQCOM execution mode is used, " + << "LocalSize and LocalSizeId must not be specified."; + } + if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The NonCoherentTileAttachmentQCOM execution mode must " + "not be used in any stage other than fragment."; + } + } else { + if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "If the TileShadingRateQCOM execution mode is used, the " + "TileShadingQCOM capability must be enabled."; + } + } + } else { + if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The TileShadingRateQCOM execution mode must not be used " + "in any stage other than compute."; + } + if (execution_model != spv::ExecutionModel::Fragment) { + if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "The NonCoherentTileAttachmentQCOM execution mode must " + "not be used in any stage other than fragment."; + } + if (_.HasCapability(spv::Capability::TileShadingQCOM)) { + return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) + << "The TileShadingQCOM capability must not be enabled in " + "any stage other than compute or fragment."; + } + } else { + if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) { + if (!_.HasCapability(spv::Capability::TileShadingQCOM)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) - << _.VkErrorID(10685) - << "In the Vulkan environment, GLCompute execution model " - "entry points require either the " - << (_.HasCapability(spv::Capability::TileShadingQCOM) - ? "TileShadingRateQCOM, " - : "") - << "LocalSize or LocalSizeId execution mode or an object " - "decorated with WorkgroupSize must be specified."; + << "If the NonCoherentTileAttachmentReadQCOM execution " + "mode is used, the TileShadingQCOM capability must be " + "enabled."; } } + } + } - if (_.HasCapability(spv::Capability::TileShadingQCOM)) { - if (has_mode(spv::ExecutionMode::TileShadingRateQCOM) && - (has_mode(spv::ExecutionMode::LocalSize) || - has_mode(spv::ExecutionMode::LocalSizeId))) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "If the TileShadingRateQCOM execution mode is used, " - << "LocalSize and LocalSizeId must not be specified."; - } - if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "The NonCoherentTileAttachmentQCOM execution mode must " - "not be used in any stage other than fragment."; - } - } else { - if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "If the TileShadingRateQCOM execution mode is used, the " - "TileShadingQCOM capability must be enabled."; - } + switch (execution_model) { + case spv::ExecutionModel::GLCompute: + case spv::ExecutionModel::MeshEXT: + case spv::ExecutionModel::MeshNV: + case spv::ExecutionModel::TaskEXT: + case spv::ExecutionModel::TaskNV: + if (!has_mode(spv::ExecutionMode::LocalSize) && !has_workgroup_size && + !has_local_size_id && + !has_mode(spv::ExecutionMode::TileShadingRateQCOM)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << _.VkErrorID(10685) << "In the Vulkan environment, " + << ExecutionModelToString(execution_model) + << " execution model " + "entry points require either the " + << (_.HasCapability(spv::Capability::TileShadingQCOM) + ? "TileShadingRateQCOM, " + : "") + << "LocalSize or LocalSizeId execution mode or an object " + "decorated with WorkgroupSize must be specified."; } break; default: - if (has_mode(spv::ExecutionMode::TileShadingRateQCOM)) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "The TileShadingRateQCOM execution mode must not be used " - "in any stage other than compute."; - } - if (execution_model != spv::ExecutionModel::Fragment) { - if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "The NonCoherentTileAttachmentQCOM execution mode must " - "not be used in any stage other than fragment."; - } - if (_.HasCapability(spv::Capability::TileShadingQCOM)) { - return _.diag(SPV_ERROR_INVALID_CAPABILITY, inst) - << "The TileShadingQCOM capability must not be enabled in " - "any stage other than compute or fragment."; - } - } else { - if (has_mode(spv::ExecutionMode::NonCoherentTileAttachmentReadQCOM)) { - if (!_.HasCapability(spv::Capability::TileShadingQCOM)) { - return _.diag(SPV_ERROR_INVALID_DATA, inst) - << "If the NonCoherentTileAttachmentReadQCOM execution " - "mode is used, the TileShadingQCOM capability must be " - "enabled."; - } - } - } break; } } diff --git a/3rdparty/spirv-tools/source/val/validate_type.cpp b/3rdparty/spirv-tools/source/val/validate_type.cpp index 786a22414..12c1ef099 100644 --- a/3rdparty/spirv-tools/source/val/validate_type.cpp +++ b/3rdparty/spirv-tools/source/val/validate_type.cpp @@ -737,9 +737,10 @@ spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _, } } - uint64_t scope_value; - if (_.EvalConstantValUint64(scope_id, &scope_value)) { - if (scope_value == static_cast(spv::Scope::Workgroup)) { + uint64_t scope_raw_value; + if (_.EvalConstantValUint64(scope_id, &scope_raw_value)) { + spv::Scope scope_value = static_cast(scope_raw_value); + if (scope_value == spv::Scope::Workgroup) { for (auto entry_point_id : _.entry_points()) { if (!_.EntryPointHasLocalSizeOrId(entry_point_id)) { return _.diag(SPV_ERROR_INVALID_ID, inst) @@ -766,6 +767,13 @@ spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _, } } } + if (scope_value != spv::Scope::Workgroup && + scope_value != spv::Scope::Subgroup) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << _.VkErrorID(12243) + << "OpTypeCooperativeMatrixKHR Scope is limited to Workgroup and " + "Subgroup"; + } } return SPV_SUCCESS; diff --git a/3rdparty/spirv-tools/source/val/validation_state.cpp b/3rdparty/spirv-tools/source/val/validation_state.cpp index bc9d8358e..cf2ff3bc5 100644 --- a/3rdparty/spirv-tools/source/val/validation_state.cpp +++ b/3rdparty/spirv-tools/source/val/validation_state.cpp @@ -2703,6 +2703,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id, return VUID_WRAP(VUID-ViewportIndex-ViewportIndex-07060); case 7102: return VUID_WRAP(VUID-StandaloneSpirv-MeshEXT-07102); + case 7107: + return VUID_WRAP(VUID-StandaloneSpirv-MeshEXT-07107); case 7290: return VUID_WRAP(VUID-StandaloneSpirv-Input-07290); case 7320: @@ -2813,6 +2815,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id, return VUID_WRAP(VUID-StandaloneSpirv-OpUntypedVariableKHR-11167); case 11805: return VUID_WRAP(VUID-StandaloneSpirv-OpArrayLength-11805); + case 12243: + return VUID_WRAP(VUID-StandaloneSpirv-Scope-12243); default: return ""; // unknown id }